Skip to content

Commit 2f1a8ad

Browse files
authored
Fix setting attention for multimodal models (#39984)
* fix
* use non-explicit `None`
* keep previously set attn if exists
1 parent a2e76b9 commit 2f1a8ad

File tree

2 files changed

+18
-2
lines changed

2 files changed

+18
-2
lines changed

src/transformers/configuration_utils.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -410,15 +410,17 @@ def _attn_implementation(self):
410410
def _attn_implementation(self, value: Optional[Union[str, dict]]):
411411
"""We set it recursively on the sub-configs as well"""
412412
# Set if for current config
413-
attn_implementation = value if not isinstance(value, dict) else value.get("", self._attn_implementation)
413+
current_attn = getattr(self, "_attn_implementation", None)
414+
attn_implementation = value if not isinstance(value, dict) else value.get("", current_attn)
414415
self._attn_implementation_internal = attn_implementation
415416

416417
# Set it recursively on the subconfigs
417418
for subconfig_key in self.sub_configs:
418419
subconfig = getattr(self, subconfig_key, None)
419420
if subconfig is not None:
421+
current_subconfig_attn = getattr(subconfig, "_attn_implementation", None)
420422
sub_implementation = (
421-
value if not isinstance(value, dict) else value.get(subconfig_key, subconfig._attn_implementation)
423+
value if not isinstance(value, dict) else value.get(subconfig_key, current_subconfig_attn)
422424
)
423425
subconfig._attn_implementation = sub_implementation
424426

tests/test_modeling_common.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3680,6 +3680,20 @@ def test_attn_implementation_composite_models(self):
36803680
model = model_class(config)
36813681
self.assertTrue(model.config.get_text_config(decoder=True)._attn_implementation == "eager")
36823682

3683+
# Test that using `dict` attention implementation works with `from_pretrained`
3684+
# Set all backbones to "eager" because "eager" attention is always available
3685+
with tempfile.TemporaryDirectory() as tmpdirname:
3686+
model.save_pretrained(tmpdirname)
3687+
new_model = model.from_pretrained(tmpdirname, attn_implementation=attn_implementation_per_subconfig)
3688+
self.assertTrue(new_model.config._attn_implementation == "eager")
3689+
for submodule in new_model.modules():
3690+
if (
3691+
submodule is not new_model
3692+
and isinstance(submodule, PreTrainedModel)
3693+
and submodule.config.__class__ != new_model.config.__class__
3694+
):
3695+
self.assertTrue(submodule.config._attn_implementation == "eager")
3696+
36833697
@require_torch_sdpa
36843698
def test_sdpa_can_dispatch_non_composite_models(self):
36853699
"""

0 commit comments

Comments
 (0)