|
29 | 29 | from transformers.cache_utils import Cache, DynamicCache |
30 | 30 | from transformers.modeling_outputs import BaseModelOutputWithPast, SequenceClassifierOutputWithPast |
31 | 31 | from transformers.models.llama.configuration_llama import LlamaConfig |
| 32 | +from transformers.masking_utils import create_bidirectional_mask |
32 | 33 | from transformers.models.llama.modeling_llama import LlamaForSequenceClassification, LlamaModel |
33 | 34 | from transformers.processing_utils import Unpack |
34 | 35 | from transformers.utils import TransformersKwargs |
35 | 36 |
|
36 | | -# Check if native create_bidirectional_mask exists (transformers >= 5.0) |
37 | | -try: |
38 | | - from transformers.masking_utils import create_bidirectional_mask |
39 | | - |
40 | | - _HAS_NATIVE_BIDIRECTIONAL_MASK = True |
41 | | -except ImportError: |
42 | | - from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask |
43 | | - |
44 | | - _HAS_NATIVE_BIDIRECTIONAL_MASK = False |
45 | | - |
46 | 37 | try: |
47 | 38 | from nemo_automodel.shared.import_utils import get_check_model_inputs_decorator |
48 | 39 |
|
@@ -108,42 +99,6 @@ def __init__(self, config: LlamaConfig): |
108 | 99 | for layer in self.layers: |
109 | 100 | layer.self_attn.is_causal = False |
110 | 101 |
|
111 | | - def _create_bidirectional_mask( |
112 | | - self, |
113 | | - input_embeds: torch.Tensor, |
114 | | - attention_mask: Optional[torch.Tensor], |
115 | | - ) -> Optional[torch.Tensor]: |
116 | | - """Create a bidirectional attention mask suitable for the active attention implementation. |
117 | | -
|
118 | | - Args: |
119 | | - input_embeds: Input embeddings (batch_size, seq_len, hidden_size). |
120 | | - attention_mask: 2D padding mask (batch_size, seq_len) with 1 for real |
121 | | - tokens and 0 for padding, or None. |
122 | | -
|
123 | | - Returns: |
124 | | - A 4D float mask for sdpa/eager, a 2D mask for flash_attention_2, |
125 | | - or None when no masking is needed. |
126 | | - """ |
127 | | - if attention_mask is None: |
128 | | - return None |
129 | | - |
130 | | - if _HAS_NATIVE_BIDIRECTIONAL_MASK: |
131 | | - return create_bidirectional_mask( |
132 | | - config=self.config, |
133 | | - input_embeds=input_embeds, |
134 | | - attention_mask=attention_mask, |
135 | | - ) |
136 | | - |
137 | | - # Flash attention handles 2D masks internally; only pass mask if there |
138 | | - # are actually masked tokens (zeros), otherwise return None for efficiency. |
139 | | - if getattr(self.config, "_attn_implementation", None) == "flash_attention_2": |
140 | | - has_masked_tokens = (attention_mask == 0).any() |
141 | | - return attention_mask if has_masked_tokens else None |
142 | | - |
143 | | - # For sdpa / eager: expand to 4D and cast to the model's compute dtype |
144 | | - # so that SDPA receives a float mask matching query dtype. |
145 | | - return _prepare_4d_attention_mask(attention_mask, input_embeds.dtype) |
146 | | - |
147 | 102 | @check_model_inputs |
148 | 103 | def forward( |
149 | 104 | self, |
@@ -174,7 +129,11 @@ def forward( |
174 | 129 | if position_ids is None: |
175 | 130 | position_ids = cache_position.unsqueeze(0) |
176 | 131 |
|
177 | | - bidirectional_mask = self._create_bidirectional_mask(inputs_embeds, attention_mask) |
| 132 | + bidirectional_mask = create_bidirectional_mask( |
| 133 | + config=self.config, |
| 134 | + input_embeds=inputs_embeds, |
| 135 | + attention_mask=attention_mask, |
| 136 | + ) |
178 | 137 |
|
179 | 138 | hidden_states = inputs_embeds |
180 | 139 | position_embeddings = self.rotary_emb(hidden_states, position_ids) |
|
0 commit comments