
Commit f5d45d8

🚨Early-error🚨 config will error out if output_attentions=True and the attn implementation is wrong (#38288)
* Protect ParallelInterface
* early error out on output attention setting for no warning in modeling
* modular update
* fixup
* update model tests
* update
* oups
* set model's config
* more cases
* ??
* properly fix
* fixup
* update
* last ones
* update
* fix?
* fix wrong merge commit
* fix hub test
* nits
* wow I am tired
* updates
* fix pipeline!

---------

Co-authored-by: Lysandre <[email protected]>
1 parent 896833c commit f5d45d8

71 files changed (+157, -144 lines)


docs/source/en/model_doc/jamba.md

Lines changed: 1 addition & 1 deletion
@@ -99,7 +99,7 @@ quantization_config = BitsAndBytesConfig(load_in_8bit=True,
 device_map = {'model.embed_tokens': 0, 'model.layers.0': 0, 'model.layers.1': 0, 'model.layers.2': 0, 'model.layers.3': 0, 'model.layers.4': 0, 'model.layers.5': 0, 'model.layers.6': 0, 'model.layers.7': 0, 'model.layers.8': 0, 'model.layers.9': 1, 'model.layers.10': 1, 'model.layers.11': 1, 'model.layers.12': 1, 'model.layers.13': 1, 'model.layers.14': 1, 'model.layers.15': 1, 'model.layers.16': 1, 'model.layers.17': 1, 'model.layers.18': 2, 'model.layers.19': 2, 'model.layers.20': 2, 'model.layers.21': 2, 'model.layers.22': 2, 'model.layers.23': 2, 'model.layers.24': 2, 'model.layers.25': 2, 'model.layers.26': 2, 'model.layers.27': 3, 'model.layers.28': 3, 'model.layers.29': 3, 'model.layers.30': 3, 'model.layers.31': 3, 'model.layers.32': 3, 'model.layers.33': 3, 'model.layers.34': 3, 'model.layers.35': 3, 'model.layers.36': 4, 'model.layers.37': 4, 'model.layers.38': 4, 'model.layers.39': 4, 'model.layers.40': 4, 'model.layers.41': 4, 'model.layers.42': 4, 'model.layers.43': 4, 'model.layers.44': 4, 'model.layers.45': 5, 'model.layers.46': 5, 'model.layers.47': 5, 'model.layers.48': 5, 'model.layers.49': 5, 'model.layers.50': 5, 'model.layers.51': 5, 'model.layers.52': 5, 'model.layers.53': 5, 'model.layers.54': 6, 'model.layers.55': 6, 'model.layers.56': 6, 'model.layers.57': 6, 'model.layers.58': 6, 'model.layers.59': 6, 'model.layers.60': 6, 'model.layers.61': 6, 'model.layers.62': 6, 'model.layers.63': 7, 'model.layers.64': 7, 'model.layers.65': 7, 'model.layers.66': 7, 'model.layers.67': 7, 'model.layers.68': 7, 'model.layers.69': 7, 'model.layers.70': 7, 'model.layers.71': 7, 'model.final_layernorm': 7, 'lm_head': 7}
 model = AutoModelForCausalLM.from_pretrained("ai21labs/AI21-Jamba-Large-1.6",
                                              torch_dtype=torch.bfloat16,
-                                             attn_implementation="flash_attention_2",
+                                             attn_implementation="flash_attention_2",
                                              quantization_config=quantization_config,
                                              device_map=device_map)

src/transformers/configuration_utils.py

Lines changed: 19 additions & 1 deletion
@@ -214,7 +214,7 @@ def __init__(self, **kwargs):
         # Attributes with defaults
         self.return_dict = kwargs.pop("return_dict", True)
         self.output_hidden_states = kwargs.pop("output_hidden_states", False)
-        self.output_attentions = kwargs.pop("output_attentions", False)
+        self._output_attentions = kwargs.pop("output_attentions", False)
         self.torchscript = kwargs.pop("torchscript", False)  # Only used by PyTorch models
         self.torch_dtype = kwargs.pop("torch_dtype", None)  # Only used by PyTorch models
         self.use_bfloat16 = kwargs.pop("use_bfloat16", False)

@@ -331,6 +331,22 @@ def name_or_path(self) -> str:
     def name_or_path(self, value):
         self._name_or_path = str(value)  # Make sure that name_or_path is a string (for JSON encoding)
 
+    @property
+    def output_attentions(self):
+        """
+        `bool`: Whether or not the model should returns all attentions.
+        """
+        return self._output_attentions
+
+    @output_attentions.setter
+    def output_attentions(self, value):
+        if self._attn_implementation != "eager":
+            raise ValueError(
+                "The `output_attentions` attribute is not supported when using the `attn_implementation` set to "
+                f"{self._attn_implementation}. Please set it to 'eager' instead."
+            )
+        self._output_attentions = value
+
     @property
     def use_return_dict(self) -> bool:
         """

@@ -1004,6 +1020,8 @@ def _remove_keys_not_serialized(self, d: dict[str, Any]) -> None:
 
         if "_auto_class" in d:
             del d["_auto_class"]
+        if "_output_attentions" in d:
+            d["output_attentions"] = d.pop("_output_attentions")
         if "_commit_hash" in d:
            del d["_commit_hash"]
         if "_attn_implementation_internal" in d:
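With the property above, misconfiguration now fails at assignment time instead of surfacing as a per-forward warning. A minimal sketch of the new behaviour, assuming a transformers version that includes this commit and that the checkpoint is available (the `gpt2` name is only illustrative):

```python
from transformers import AutoConfig

# Config whose attention implementation cannot return attention weights.
config = AutoConfig.from_pretrained("gpt2", attn_implementation="sdpa")

try:
    # With this commit, the setter raises immediately rather than letting the
    # model fall back to eager attention later with only a warning.
    config.output_attentions = True
except ValueError as err:
    print(err)

# Requesting attentions still works with the eager implementation.
eager_config = AutoConfig.from_pretrained("gpt2", attn_implementation="eager")
eager_config.output_attentions = True  # no error
```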

src/transformers/models/aria/modeling_aria.py

Lines changed: 1 addition & 8 deletions
@@ -549,15 +549,8 @@ def forward(
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
         attention_interface: Callable = eager_attention_forward
-
         if self.config._attn_implementation != "eager":
-            if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
-                logger.warning_once(
-                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
-                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
-                )
-            else:
-                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
 
         attn_output, attn_weights = attention_interface(
             self,
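The same simplification repeats in the other modeling files in this commit (bamba, csm, emu3, gemma, glm, glm4, granite, and more): the SDPA warn-and-fall-back branch disappears, leaving a plain registry lookup with eager as the default. A standalone sketch of that dispatch pattern, using illustrative stand-ins rather than the real transformers internals:

```python
from typing import Callable

def eager_attention_forward(query, key, value):
    """Reference implementation that can also return attention weights."""
    ...

# Illustrative registry; in transformers this role is played by ALL_ATTENTION_FUNCTIONS.
ATTENTION_REGISTRY: dict[str, Callable] = {
    "sdpa": lambda query, key, value: ...,
    "flash_attention_2": lambda query, key, value: ...,
}

def select_attention(attn_implementation: str) -> Callable:
    # Default to eager; any other registered implementation is looked up
    # directly, with no output_attentions fallback branch.
    attention_interface: Callable = eager_attention_forward
    if attn_implementation != "eager":
        attention_interface = ATTENTION_REGISTRY[attn_implementation]
    return attention_interface
```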

src/transformers/models/bamba/modeling_bamba.py

Lines changed: 1 addition & 8 deletions
@@ -313,15 +313,8 @@ def forward(
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
         attention_interface: Callable = eager_attention_forward
-
         if self.config._attn_implementation != "eager":
-            if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
-                logger.warning_once(
-                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
-                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
-                )
-            else:
-                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
 
         attn_output, attn_weights = attention_interface(
             self,

src/transformers/models/csm/modeling_csm.py

Lines changed: 1 addition & 8 deletions
@@ -337,15 +337,8 @@ def forward(
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
         attention_interface: Callable = eager_attention_forward
-
         if self.config._attn_implementation != "eager":
-            if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
-                logger.warning_once(
-                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
-                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
-                )
-            else:
-                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
 
         attn_output, attn_weights = attention_interface(
             self,

src/transformers/models/emu3/modeling_emu3.py

Lines changed: 1 addition & 8 deletions
@@ -206,15 +206,8 @@ def forward(
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
         attention_interface: Callable = eager_attention_forward
-
         if self.config._attn_implementation != "eager":
-            if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
-                logger.warning_once(
-                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
-                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
-                )
-            else:
-                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
 
         attn_output, attn_weights = attention_interface(
             self,

src/transformers/models/gemma/modeling_gemma.py

Lines changed: 1 addition & 8 deletions
@@ -239,15 +239,8 @@ def forward(
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
         attention_interface: Callable = eager_attention_forward
-
         if self.config._attn_implementation != "eager":
-            if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
-                logger.warning_once(
-                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
-                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
-                )
-            else:
-                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
 
         attn_output, attn_weights = attention_interface(
             self,

src/transformers/models/glm/modeling_glm.py

Lines changed: 1 addition & 8 deletions
@@ -201,15 +201,8 @@ def forward(
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
         attention_interface: Callable = eager_attention_forward
-
         if self.config._attn_implementation != "eager":
-            if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
-                logger.warning_once(
-                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
-                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
-                )
-            else:
-                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
 
         attn_output, attn_weights = attention_interface(
             self,

src/transformers/models/glm4/modeling_glm4.py

Lines changed: 1 addition & 8 deletions
@@ -259,15 +259,8 @@ def forward(
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
         attention_interface: Callable = eager_attention_forward
-
         if self.config._attn_implementation != "eager":
-            if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
-                logger.warning_once(
-                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
-                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
-                )
-            else:
-                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
 
         attn_output, attn_weights = attention_interface(
             self,

src/transformers/models/granite/modeling_granite.py

Lines changed: 1 addition & 8 deletions
@@ -165,15 +165,8 @@ def forward(
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
         attention_interface: Callable = eager_attention_forward
-
         if self.config._attn_implementation != "eager":
-            if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
-                logger.warning_once(
-                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
-                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
-                )
-            else:
-                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
 
         attn_output, attn_weights = attention_interface(
             self,
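End to end, the way to get attention weights after this change is to opt into the eager implementation up front; with the silent SDPA fallback removed, non-eager implementations no longer pretend to honour the flag. A usage sketch (model name illustrative, and assuming the checkpoint is available):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2", attn_implementation="eager")

inputs = tokenizer("Hello world", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs, output_attentions=True)

# One attention tensor per layer, shaped (batch, heads, seq_len, seq_len).
print(len(outputs.attentions), outputs.attentions[0].shape)
```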

0 commit comments
