`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
    with the word token indices in the vocabulary
    (see the tokens preprocessing logic in the scripts
    `extract_features.py`, `run_classifier.py` and `run_squad.py`)
`token_type_ids`: an optional torch.LongTensor of shape
    [batch_size, sequence_length] with the token type indices selected in [0, 1].
    Type 0 corresponds to a `sentence A` and type 1 corresponds to
    a `sentence B` token (see BERT paper for more details).
`attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length]
    with indices selected in [0, 1]. It's a mask to be used if the input sequence length
    is smaller than the max input sequence length in the current batch.
    It's the mask that we typically use for attention when
    a batch has varying length sentences.
`output_all_encoded_layers`: boolean which controls the content of the `encoded_layers`
    output as described below. Default: `True`.

Outputs: Tuple of (encoded_layers, pooled_output)
`encoded_layers`: controlled by the `output_all_encoded_layers` argument:
    - `output_all_encoded_layers=True`: outputs a list of the full sequences of
      encoded-hidden-states at the end of each attention block
      (i.e. 12 full sequences for BERT-base, 24 for BERT-large); each
      encoded-hidden-state is a torch.FloatTensor of size
      [batch_size, sequence_length, hidden_size],
    - `output_all_encoded_layers=False`: outputs only the full sequence of
      hidden-states corresponding to the last attention block, of shape
      [batch_size, sequence_length, hidden_size],
`pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size]
    which is the output of a classifier pretrained on top of the hidden state
    associated with the first token of the
    input (`CLS`) to train on the Next-Sentence task (see BERT's paper).

Example usage:
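
The following is a minimal sketch of how these inputs and outputs fit together, not the repository's canonical example: the token ids are placeholders rather than real vocabulary entries, the padding id is assumed to be 0, and the `pytorch_pretrained_bert` import path is assumed from this project.

```python
import torch
from pytorch_pretrained_bert import BertModel  # import path assumed from this project

# Illustrative batch of 2 already-tokenized sequences padded to length 8.
# The ids are placeholders; real ids come from the tokenizer, as in the
# preprocessing of extract_features.py / run_classifier.py.
input_ids = torch.tensor([[101, 1037, 2742, 3231,  102,    0,    0,    0],
                          [101, 2178, 2936, 2742, 3231,  102,    0,    0]])

# Token type 0 for `sentence A`, 1 for `sentence B`; all zeros for single-sentence input.
token_type_ids = torch.zeros_like(input_ids)

# 1 for real tokens, 0 for padding, so attention ignores the padded positions
# when the batch contains sequences of varying length.
attention_mask = (input_ids != 0).long()

model = BertModel.from_pretrained('bert-base-uncased')
model.eval()
with torch.no_grad():
    encoded_layers, pooled_output = model(input_ids, token_type_ids, attention_mask,
                                          output_all_encoded_layers=True)

print(len(encoded_layers))   # 12 for BERT-base: one [2, 8, 768] tensor per attention block
print(pooled_output.shape)   # torch.Size([2, 768])
```

With `output_all_encoded_layers=False`, the first return value is a single [batch_size, sequence_length, hidden_size] tensor instead of a list.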

For `BertForSequenceClassification`:

Inputs:
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
    with the word token indices in the vocabulary.
    Items in the batch should begin with the special "CLS" token.
    (see the tokens preprocessing logic in the scripts
    `extract_features.py`, `run_classifier.py` and `run_squad.py`)
`token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length]
    with the token type indices selected in [0, 1]. Type 0 corresponds to a `sentence A`
    and type 1 corresponds to a `sentence B` token (see BERT paper for more details).
`attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length]
    with indices selected in [0, 1]. It's a mask to be used if the input sequence length
    is smaller than the max input sequence length in the current batch. It's the mask
    that we typically use for attention when a batch has varying length sentences.
`labels`: labels for the classification output: torch.LongTensor of shape [batch_size]
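
A similarly hedged sketch for the classification model (placeholder ids and labels; `num_labels=2` is an assumed binary task, not something fixed by the docstring above):

```python
import torch
from pytorch_pretrained_bert import BertForSequenceClassification

# Each sequence starts with the special "CLS" token (id 101 in the uncased vocabulary);
# the remaining ids are placeholders, and 0 is the padding id.
input_ids = torch.tensor([[101, 2023, 2003, 2204,  102,    0,    0,    0],
                          [101, 2023, 2003, 2919,  102,    0,    0,    0]])
token_type_ids = torch.zeros_like(input_ids)   # single-sentence inputs: all `sentence A`
attention_mask = (input_ids != 0).long()       # mask out the padding positions
labels = torch.tensor([1, 0])                  # one label per item in the batch

# num_labels is forwarded to the model constructor; 2 is assumed here.
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)

# With labels the model returns the classification loss;
# without labels it returns the [batch_size, num_labels] logits.
loss = model(input_ids, token_type_ids, attention_mask, labels)
logits = model(input_ids, token_type_ids, attention_mask)
```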