
Commit 694aaa7

Fix how we compute the final non-padding token for ForSequenceClassification models (#35911)
* Fix how we compute the final non-padding token for Gemma (and probably other models)
* .size() -> .shape[]
* Propagating changes to other models
* Propagating changes to other models
* Change it for all ForSequenceClassification models
* Fix batch dim
* More TF fixes
* Copy the TF fix around as well
* Correct layer name for TFCTRL
* Cleaner .to()
* Clean up the nested if-else
* Use argmax() instead of .max().values
1 parent 531d151 commit 694aaa7

37 files changed: +448, -437 lines changed

src/transformers/models/bloom/modeling_bloom.py

Lines changed: 12 additions & 13 deletions
@@ -1127,21 +1127,20 @@ def forward(
         if self.config.pad_token_id is None and batch_size != 1:
             raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
         if self.config.pad_token_id is None:
-            sequence_lengths = -1
+            last_non_pad_token = -1
+        elif input_ids is not None:
+            # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
+            non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
+            token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
+            last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
         else:
-            if input_ids is not None:
-                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
-                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
-                sequence_lengths = sequence_lengths % input_ids.shape[-1]
-                sequence_lengths = sequence_lengths.to(logits.device)
-            else:
-                sequence_lengths = -1
-                logger.warning_once(
-                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
-                    "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
-                )
+            last_non_pad_token = -1
+            logger.warning_once(
+                f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+                "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+            )

-        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
+        pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]

         loss = None
         if labels is not None:
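The same last-token selection recurs in every PyTorch model touched by this commit, so it is worth seeing in isolation. Below is a minimal standalone sketch with toy tensors (hypothetical pad id and shapes, not the actual model code) showing how the rightmost non-pad index is recovered for both left- and right-padded rows:

import torch

pad_token_id = 0  # hypothetical pad id for this toy example

input_ids = torch.tensor(
    [
        [5, 7, 9, 0, 0],  # right-padded: last real token at index 2
        [0, 0, 3, 4, 6],  # left-padded: last real token at index 4
    ]
)
logits = torch.randn(2, 5, 3)  # (batch, seq_len, num_labels)

# Pad positions contribute 0 to the product, so argmax lands on the largest
# index whose mask value is 1, i.e. the rightmost non-pad token.
non_pad_mask = (input_ids != pad_token_id).to(logits.device, torch.int32)
token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
print(last_non_pad_token)  # tensor([2, 4])

pooled_logits = logits[torch.arange(input_ids.shape[0]), last_non_pad_token]
print(pooled_logits.shape)  # torch.Size([2, 3])

Unlike the replaced torch.eq(...).argmax(-1) - 1 logic, this selection does not depend on where the first pad token happens to sit, which is what makes it tolerant of left padding as well as right padding.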

src/transformers/models/ctrl/modeling_ctrl.py

Lines changed: 12 additions & 14 deletions
@@ -789,23 +789,21 @@ def forward(

         if self.config.pad_token_id is None and batch_size != 1:
             raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
-
         if self.config.pad_token_id is None:
-            sequence_lengths = -1
+            last_non_pad_token = -1
+        elif input_ids is not None:
+            # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
+            non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
+            token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
+            last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
         else:
-            if input_ids is not None:
-                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
-                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
-                sequence_lengths = sequence_lengths % input_ids.shape[-1]
-                sequence_lengths = sequence_lengths.to(logits.device)
-            else:
-                sequence_lengths = -1
-                logger.warning_once(
-                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
-                    "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
-                )
+            last_non_pad_token = -1
+            logger.warning_once(
+                f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+                "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+            )

-        pooled_logits = logits[range(batch_size), sequence_lengths]
+        pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]

         loss = None
         if labels is not None:

src/transformers/models/ctrl/modeling_tf_ctrl.py

Lines changed: 12 additions & 21 deletions
@@ -868,42 +868,33 @@ def call(
             return_dict=return_dict,
             training=training,
         )
-
         hidden_states = transformer_outputs[0]
         logits = self.classifier(hidden_states)
-        in_logits = None
+        logits_shape = shape_list(logits)
+        batch_size = logits_shape[0]
+
         if self.config.pad_token_id is None:
-            sequence_lengths = -1
+            last_non_pad_token = tf.fill((batch_size,), value=logits_shape[1] - 1)
         else:
             if input_ids is not None:
-                sequence_lengths = (
-                    tf.argmax(tf.cast(tf.math.equal(input_ids, self.config.pad_token_id), input_ids.dtype), axis=-1)
-                    - 1
-                )
-                sequence_lengths = tf.where(sequence_lengths >= 0, sequence_lengths, input_ids.shape[-1] - 1)
-                in_logits = tf.gather(logits, sequence_lengths, batch_dims=1, axis=1)
+                token_indices = tf.range(shape_list(input_ids)[-1])
+                non_pad_mask = tf.cast(input_ids != self.config.pad_token_id, token_indices.dtype)
+                last_non_pad_token = tf.reduce_max(token_indices * non_pad_mask, axis=-1)
             else:
-                sequence_lengths = -1
+                last_non_pad_token = tf.fill((batch_size,), value=logits_shape[1] - 1)
                 logger.warning_once(
                     f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
                     "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
                 )
         loss = None

+        pooled_logits = tf.gather(logits, last_non_pad_token, batch_dims=1, axis=1)
+
         if labels is not None:
-            if input_ids is not None:
-                batch_size, sequence_length = shape_list(input_ids)[:2]
-            else:
-                batch_size, sequence_length = shape_list(inputs_embeds)[:2]
-            if self.config.pad_token_id is None and batch_size != 1:
+            if self.config.pad_token_id is None and logits_shape[0] != 1:
                 raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")

-            if not tf.is_tensor(sequence_lengths):
-                in_logits = logits[0:batch_size, sequence_lengths]
-
-            loss = self.hf_compute_loss(tf.reshape(labels, [-1, 1]), tf.reshape(in_logits, [-1, self.num_labels]))
-
-        pooled_logits = in_logits if in_logits is not None else logits
+            loss = self.hf_compute_loss(tf.reshape(labels, [-1]), tf.reshape(pooled_logits, [-1, self.num_labels]))

         if not return_dict:
             output = (pooled_logits,) + transformer_outputs[1:]
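The TensorFlow port follows the same idea with tf.reduce_max over a masked index grid, plus a single tf.gather to pool the logits. A toy sketch under the same assumptions as the PyTorch example above (hypothetical pad id and shapes, tf.shape in place of the library's shape_list helper):

import tensorflow as tf

pad_token_id = 0  # hypothetical pad id for this toy example

input_ids = tf.constant(
    [
        [5, 7, 9, 0, 0],  # right-padded
        [0, 0, 3, 4, 6],  # left-padded
    ]
)
logits = tf.random.normal((2, 5, 3))  # (batch, seq_len, num_labels)

# Rightmost non-pad index per row, mirroring the PyTorch argmax trick.
token_indices = tf.range(tf.shape(input_ids)[-1])
non_pad_mask = tf.cast(input_ids != pad_token_id, token_indices.dtype)
last_non_pad_token = tf.reduce_max(token_indices * non_pad_mask, axis=-1)
print(last_non_pad_token)  # tf.Tensor([2 4], shape=(2,), dtype=int32)

# One gather per row pulls out that token's logits, replacing the old
# in_logits bookkeeping.
pooled_logits = tf.gather(logits, last_non_pad_token, batch_dims=1, axis=1)
print(pooled_logits.shape)  # (2, 3)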

src/transformers/models/diffllama/modeling_diffllama.py

Lines changed: 12 additions & 9 deletions
@@ -1216,17 +1216,20 @@ def forward(
         if self.config.pad_token_id is None and batch_size != 1:
             raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
         if self.config.pad_token_id is None:
-            sequence_lengths = -1
+            last_non_pad_token = -1
+        elif input_ids is not None:
+            # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
+            non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
+            token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
+            last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
         else:
-            if input_ids is not None:
-                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
-                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
-                sequence_lengths = sequence_lengths % input_ids.shape[-1]
-                sequence_lengths = sequence_lengths.to(logits.device)
-            else:
-                sequence_lengths = -1
+            last_non_pad_token = -1
+            logger.warning_once(
+                f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+                "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+            )

-        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
+        pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]

         loss = None
         if labels is not None:

src/transformers/models/falcon/modeling_falcon.py

Lines changed: 12 additions & 13 deletions
@@ -1359,21 +1359,20 @@ def forward(
         if self.config.pad_token_id is None and batch_size != 1:
             raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
         if self.config.pad_token_id is None:
-            sequence_lengths = -1
+            last_non_pad_token = -1
+        elif input_ids is not None:
+            # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
+            non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
+            token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
+            last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
         else:
-            if input_ids is not None:
-                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
-                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
-                sequence_lengths = sequence_lengths % input_ids.shape[-1]
-                sequence_lengths = sequence_lengths.to(logits.device)
-            else:
-                sequence_lengths = -1
-                logger.warning_once(
-                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
-                    "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
-                )
+            last_non_pad_token = -1
+            logger.warning_once(
+                f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+                "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+            )

-        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
+        pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]

         loss = None
         if labels is not None:

src/transformers/models/gemma/modeling_gemma.py

Lines changed: 12 additions & 9 deletions
@@ -948,17 +948,20 @@ def forward(
         if self.config.pad_token_id is None and batch_size != 1:
             raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
         if self.config.pad_token_id is None:
-            sequence_lengths = -1
+            last_non_pad_token = -1
+        elif input_ids is not None:
+            # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
+            non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
+            token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
+            last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
         else:
-            if input_ids is not None:
-                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
-                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
-                sequence_lengths = sequence_lengths % input_ids.shape[-1]
-                sequence_lengths = sequence_lengths.to(logits.device)
-            else:
-                sequence_lengths = -1
+            last_non_pad_token = -1
+            logger.warning_once(
+                f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+                "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+            )

-        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
+        pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]

         loss = None
         if labels is not None:

src/transformers/models/gemma2/modeling_gemma2.py

Lines changed: 12 additions & 9 deletions
@@ -1083,17 +1083,20 @@ def forward(
         if self.config.pad_token_id is None and batch_size != 1:
             raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
         if self.config.pad_token_id is None:
-            sequence_lengths = -1
+            last_non_pad_token = -1
+        elif input_ids is not None:
+            # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
+            non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
+            token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
+            last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
         else:
-            if input_ids is not None:
-                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
-                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
-                sequence_lengths = sequence_lengths % input_ids.shape[-1]
-                sequence_lengths = sequence_lengths.to(logits.device)
-            else:
-                sequence_lengths = -1
+            last_non_pad_token = -1
+            logger.warning_once(
+                f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+                "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+            )

-        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
+        pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]

         loss = None
         if labels is not None:

src/transformers/models/glm/modeling_glm.py

Lines changed: 12 additions & 9 deletions
@@ -958,17 +958,20 @@ def forward(
         if self.config.pad_token_id is None and batch_size != 1:
             raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
         if self.config.pad_token_id is None:
-            sequence_lengths = -1
+            last_non_pad_token = -1
+        elif input_ids is not None:
+            # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
+            non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
+            token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
+            last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
         else:
-            if input_ids is not None:
-                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
-                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
-                sequence_lengths = sequence_lengths % input_ids.shape[-1]
-                sequence_lengths = sequence_lengths.to(logits.device)
-            else:
-                sequence_lengths = -1
+            last_non_pad_token = -1
+            logger.warning_once(
+                f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+                "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+            )

-        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
+        pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]

         loss = None
         if labels is not None:

src/transformers/models/gpt2/modeling_gpt2.py

Lines changed: 14 additions & 16 deletions
@@ -1393,25 +1393,23 @@ def forward(
         else:
             batch_size, sequence_length = inputs_embeds.shape[:2]

-        assert (
-            self.config.pad_token_id is not None or batch_size == 1
-        ), "Cannot handle batch sizes > 1 if no padding token is defined."
+        if self.config.pad_token_id is None and batch_size != 1:
+            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
         if self.config.pad_token_id is None:
-            sequence_lengths = -1
+            last_non_pad_token = -1
+        elif input_ids is not None:
+            # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
+            non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
+            token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
+            last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
         else:
-            if input_ids is not None:
-                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
-                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
-                sequence_lengths = sequence_lengths % input_ids.shape[-1]
-                sequence_lengths = sequence_lengths.to(logits.device)
-            else:
-                sequence_lengths = -1
-                logger.warning_once(
-                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
-                    "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
-                )
+            last_non_pad_token = -1
+            logger.warning_once(
+                f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+                "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+            )

-        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
+        pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]

         loss = None
         if labels is not None:
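Beyond the shared index change, the GPT-2 head also trades its bare assert for an explicit exception. A minimal sketch of the resulting guard, using a SimpleNamespace as a hypothetical stand-in for the model config:

from types import SimpleNamespace

def check_can_pool(config, batch_size):
    # Same check the diff introduces: an explicit ValueError instead of an
    # assert, which would be skipped entirely under `python -O`.
    if config.pad_token_id is None and batch_size != 1:
        raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")

config = SimpleNamespace(pad_token_id=None)  # no pad token configured
check_can_pool(config, batch_size=1)    # fine: a single sequence needs no padding
# check_can_pool(config, batch_size=4)  # would raise ValueError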
