Commit 3c9dbe9

ashmikuz authored and NielsRogge committed
[ModernBert] Prevent the attention mask from being None in ModernBertForSequenceClassification (huggingface#35991)
* [ModernBert] Prevent the attention mask from being None in ModernBertForSequenceClassification
* fix the modular conversion
1 parent ff829a2 commit 3c9dbe9
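
The fix boils down to building a default all-ones mask whenever the caller does not provide one. A minimal standalone sketch of that guard, assuming PyTorch; the helper name `default_attention_mask` is illustrative only (in the patch the same logic sits inline in `forward`):

```python
# Standalone sketch of the guard this commit adds. The helper name is
# illustrative only; in the patch the logic lives inline in forward().
import torch

def default_attention_mask(input_ids=None, inputs_embeds=None, attention_mask=None):
    """Return the caller's mask, or an all-ones boolean mask matching the inputs."""
    if attention_mask is not None:
        return attention_mask
    if inputs_embeds is not None:
        batch_size, seq_len = inputs_embeds.shape[:2]
        device = inputs_embeds.device
    else:
        batch_size, seq_len = input_ids.shape[:2]
        device = input_ids.device
    return torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)

print(default_attention_mask(input_ids=torch.tensor([[101, 2023, 102]])))
# tensor([[True, True, True]])
```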

File tree

2 files changed: +26 -0 lines changed

src/transformers/models/modernbert/modeling_modernbert.py

Lines changed: 13 additions & 0 deletions
@@ -1151,6 +1151,19 @@ def forward(
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
         self._maybe_set_compile()

+        if input_ids is not None:
+            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
+
+        if batch_size is None and seq_len is None:
+            if inputs_embeds is not None:
+                batch_size, seq_len = inputs_embeds.shape[:2]
+            else:
+                batch_size, seq_len = input_ids.shape[:2]
+        device = input_ids.device if input_ids is not None else inputs_embeds.device
+
+        if attention_mask is None:
+            attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)
+
         outputs = self.model(
             input_ids=input_ids,
             attention_mask=attention_mask,
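
With this hunk, calling the classification head with only `input_ids` no longer forwards `attention_mask=None` into the model. A usage sketch, assuming the `answerdotai/ModernBERT-base` checkpoint and `num_labels=2` (both illustrative choices, not part of the commit):

```python
# Usage sketch: no attention_mask passed at all. Checkpoint name and
# num_labels are assumptions for illustration, not part of the patch.
import torch
from transformers import AutoTokenizer, ModernBertForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
model = ModernBertForSequenceClassification.from_pretrained(
    "answerdotai/ModernBERT-base", num_labels=2
)
model.eval()

input_ids = tokenizer("ModernBERT handles long contexts.", return_tensors="pt").input_ids
with torch.no_grad():
    # With the patch, a full-ones mask is created before self.model is called.
    logits = model(input_ids=input_ids).logits
print(logits.shape)  # torch.Size([1, 2])
```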

src/transformers/models/modernbert/modular_modernbert.py

Lines changed: 13 additions & 0 deletions
@@ -1277,6 +1277,19 @@ def forward(
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
         self._maybe_set_compile()

+        if input_ids is not None:
+            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
+
+        if batch_size is None and seq_len is None:
+            if inputs_embeds is not None:
+                batch_size, seq_len = inputs_embeds.shape[:2]
+            else:
+                batch_size, seq_len = input_ids.shape[:2]
+        device = input_ids.device if input_ids is not None else inputs_embeds.device
+
+        if attention_mask is None:
+            attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)
+
         outputs = self.model(
             input_ids=input_ids,
             attention_mask=attention_mask,
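
modeling_modernbert.py is generated from modular_modernbert.py, so the same guard has to land in both files; that is what the "fix the modular conversion" bullet of the commit message refers to. The guard also covers the `inputs_embeds` path, where batch size and sequence length are inferred from the embedding tensor. A hedged sketch of that path (checkpoint name assumed, random embeddings used only to exercise the shape handling):

```python
# Sketch of the inputs_embeds path the guard also covers. Checkpoint name is
# an assumption; embeddings are random and only exercise the shape handling.
import torch
from transformers import ModernBertForSequenceClassification

model = ModernBertForSequenceClassification.from_pretrained(
    "answerdotai/ModernBERT-base", num_labels=2
)
model.eval()

inputs_embeds = torch.randn(1, 8, model.config.hidden_size)  # (batch, seq_len, hidden)
with torch.no_grad():
    out = model(inputs_embeds=inputs_embeds)  # mask defaults to ones of shape (1, 8)
print(out.logits.shape)  # torch.Size([1, 2])
```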
