Skip to content

Commit 7512f73

Browse files
Assume transformers >=5
Signed-off-by: Oliver Holworthy <1216955+oliverholworthy@users.noreply.github.com>
1 parent 2cfea09 commit 7512f73

File tree

1 file changed

+6
-47
lines changed
  • nemo_automodel/components/models/llama_bidirectional/model.py

1 file changed

+6
-47
lines changed

nemo_automodel/components/models/llama_bidirectional/model.py

Lines changed: 6 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -29,20 +29,11 @@
2929
from transformers.cache_utils import Cache, DynamicCache
3030
from transformers.modeling_outputs import BaseModelOutputWithPast, SequenceClassifierOutputWithPast
3131
from transformers.models.llama.configuration_llama import LlamaConfig
32+
from transformers.masking_utils import create_bidirectional_mask
3233
from transformers.models.llama.modeling_llama import LlamaForSequenceClassification, LlamaModel
3334
from transformers.processing_utils import Unpack
3435
from transformers.utils import TransformersKwargs
3536

36-
# Check if native create_bidirectional_mask exists (transformers >= 5.0)
37-
try:
38-
from transformers.masking_utils import create_bidirectional_mask
39-
40-
_HAS_NATIVE_BIDIRECTIONAL_MASK = True
41-
except ImportError:
42-
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
43-
44-
_HAS_NATIVE_BIDIRECTIONAL_MASK = False
45-
4637
try:
4738
from nemo_automodel.shared.import_utils import get_check_model_inputs_decorator
4839

@@ -108,42 +99,6 @@ def __init__(self, config: LlamaConfig):
10899
for layer in self.layers:
109100
layer.self_attn.is_causal = False
110101

111-
def _create_bidirectional_mask(
112-
self,
113-
input_embeds: torch.Tensor,
114-
attention_mask: Optional[torch.Tensor],
115-
) -> Optional[torch.Tensor]:
116-
"""Create a bidirectional attention mask suitable for the active attention implementation.
117-
118-
Args:
119-
input_embeds: Input embeddings (batch_size, seq_len, hidden_size).
120-
attention_mask: 2D padding mask (batch_size, seq_len) with 1 for real
121-
tokens and 0 for padding, or None.
122-
123-
Returns:
124-
A 4D float mask for sdpa/eager, a 2D mask for flash_attention_2,
125-
or None when no masking is needed.
126-
"""
127-
if attention_mask is None:
128-
return None
129-
130-
if _HAS_NATIVE_BIDIRECTIONAL_MASK:
131-
return create_bidirectional_mask(
132-
config=self.config,
133-
input_embeds=input_embeds,
134-
attention_mask=attention_mask,
135-
)
136-
137-
# Flash attention handles 2D masks internally; only pass mask if there
138-
# are actually masked tokens (zeros), otherwise return None for efficiency.
139-
if getattr(self.config, "_attn_implementation", None) == "flash_attention_2":
140-
has_masked_tokens = (attention_mask == 0).any()
141-
return attention_mask if has_masked_tokens else None
142-
143-
# For sdpa / eager: expand to 4D and cast to the model's compute dtype
144-
# so that SDPA receives a float mask matching query dtype.
145-
return _prepare_4d_attention_mask(attention_mask, input_embeds.dtype)
146-
147102
@check_model_inputs
148103
def forward(
149104
self,
@@ -174,7 +129,11 @@ def forward(
174129
if position_ids is None:
175130
position_ids = cache_position.unsqueeze(0)
176131

177-
bidirectional_mask = self._create_bidirectional_mask(inputs_embeds, attention_mask)
132+
bidirectional_mask = create_bidirectional_mask(
133+
config=self.config,
134+
input_embeds=inputs_embeds,
135+
attention_mask=attention_mask,
136+
)
178137

179138
hidden_states = inputs_embeds
180139
position_embeddings = self.rotary_emb(hidden_states, position_ids)

0 commit comments

Comments (0)