File tree Expand file tree Collapse file tree 2 files changed +18
-2
lines changed Expand file tree Collapse file tree 2 files changed +18
-2
lines changed Original file line number Diff line number Diff line change @@ -410,15 +410,17 @@ def _attn_implementation(self):
410
410
def _attn_implementation (self , value : Optional [Union [str , dict ]]):
411
411
"""We set it recursively on the sub-configs as well"""
412
412
# Set it for current config
413
- attn_implementation = value if not isinstance (value , dict ) else value .get ("" , self ._attn_implementation )
413
+ current_attn = getattr (self , "_attn_implementation" , None )
414
+ attn_implementation = value if not isinstance (value , dict ) else value .get ("" , current_attn )
414
415
self ._attn_implementation_internal = attn_implementation
415
416
416
417
# Set it recursively on the subconfigs
417
418
for subconfig_key in self .sub_configs :
418
419
subconfig = getattr (self , subconfig_key , None )
419
420
if subconfig is not None :
421
+ current_subconfig_attn = getattr (subconfig , "_attn_implementation" , None )
420
422
sub_implementation = (
421
- value if not isinstance (value , dict ) else value .get (subconfig_key , subconfig . _attn_implementation )
423
+ value if not isinstance (value , dict ) else value .get (subconfig_key , current_subconfig_attn )
422
424
)
423
425
subconfig ._attn_implementation = sub_implementation
424
426
Original file line number Diff line number Diff line change @@ -3680,6 +3680,20 @@ def test_attn_implementation_composite_models(self):
3680
3680
model = model_class (config )
3681
3681
self .assertTrue (model .config .get_text_config (decoder = True )._attn_implementation == "eager" )
3682
3682
3683
+ # Test that using `dict` attention implementation works with `from_pretrained`
3684
+ # Set all backbones to "eager" because "eager" attention is always available
3685
+ with tempfile .TemporaryDirectory () as tmpdirname :
3686
+ model .save_pretrained (tmpdirname )
3687
+ new_model = model .from_pretrained (tmpdirname , attn_implementation = attn_implementation_per_subconfig )
3688
+ self .assertTrue (new_model .config ._attn_implementation == "eager" )
3689
+ for submodule in new_model .modules ():
3690
+ if (
3691
+ submodule is not new_model
3692
+ and isinstance (submodule , PreTrainedModel )
3693
+ and submodule .config .__class__ != new_model .config .__class__
3694
+ ):
3695
+ self .assertTrue (submodule .config ._attn_implementation == "eager" )
3696
+
3683
3697
@require_torch_sdpa
3684
3698
def test_sdpa_can_dispatch_non_composite_models (self ):
3685
3699
"""
You can’t perform that action at this time.
0 commit comments