@@ -3051,22 +3051,22 @@ def patched_forward(self, fn):
30513051
30523052    def  __enter__ (self ):
30533053        if  is_torch_version (">=" , "2.1.0" ):
3054-             if  self ._model .config .model_type  ==  "qwen2"  and  self ._model .config ._attn_implementation  !=  "sdpa" :
3055-                 if  is_transformers_version ("<" , "4.48" ):
3054+             if  (
3055+                 self ._model .config .model_type  in  ["qwen2" , "llama" ]
3056+                 and  self ._model .config ._attn_implementation  !=  "sdpa" 
3057+             ):
3058+                 self ._model .config ._orig_attn_implementation  =  self ._model .config ._attn_implementation 
3059+                 self ._model .config ._attn_implementation  =  "sdpa" 
3060+                 if  self ._model .config .model_type  ==  "qwen2"  and  is_transformers_version ("<" , "4.48" ):
30563061                    from  transformers .models .qwen2 .modeling_qwen2  import  QWEN2_ATTENTION_CLASSES 
30573062
30583063                    sdpa_attn  =  QWEN2_ATTENTION_CLASSES ["sdpa" ]
3059-                     self ._model .config ._orig_attn_implementation  =  self ._model .config ._attn_implementation 
3060-                     self ._model .config ._attn_implementation  =  "sdpa" 
30613064
30623065                    for  layer  in  self ._model .model .layers :
30633066                        layer .self_attn ._orig_forward  =  layer .self_attn .forward 
30643067                        layer .self_attn .forward  =  types .MethodType (sdpa_attn .forward , layer .self_attn )
30653068
3066-             if  self ._model .config .model_type  ==  "llama"  and  self ._model .config ._attn_implementation  !=  "sdpa" :
3067-                 self ._model .config ._orig_attn_implementation  =  self ._model .config ._attn_implementation 
3068-                 self ._model .config ._attn_implementation  =  "sdpa" 
3069-                 if  is_transformers_version ("<" , "4.47" ):
3069+                 if  self ._model .config .model_type  ==  "llama"  and  is_transformers_version ("<" , "4.47" ):
30703070                    from  transformers .models .llama .modeling_llama  import  LLAMA_ATTENTION_CLASSES 
30713071
30723072                    sdpa_attn  =  LLAMA_ATTENTION_CLASSES ["sdpa" ]
0 commit comments