@@ -718,14 +718,15 @@ def _mistral_update_causal_mask(
 class MistralModelPatcher(DecoderModelPatcher):
     def __enter__(self):
         super().__enter__()
-        if is_transformers_version(">=", "4.42.0"):
+        if is_transformers_version(">=", "4.42.0") and is_transformers_version("<", "4.48.0"):
             # apply fix https://github.com/huggingface/transformers/commit/57d7594a79a9f5d835abf2d4d384db0e4818e548
             self._model.model._orig_update_causal_mask = self._model.model._update_causal_mask
             self._model.model._update_causal_mask = types.MethodType(_mistral_update_causal_mask, self._model.model)

         else:
             for layer in self._model.model.layers:
-                _reinitialize_cos_sin_cached_fp32(layer.self_attn.rotary_emb)
+                if hasattr(layer.self_attn, "rotary_emb"):
+                    _reinitialize_cos_sin_cached_fp32(layer.self_attn.rotary_emb)

     def __exit__(self, exc_type, exc_value, traceback):
         super().__exit__(exc_type, exc_value, traceback)
@@ -734,7 +735,7 @@ def __exit__(self, exc_type, exc_value, traceback):
             self._model.model._update_causal_mask = self._model.model._orig_update_causal_mask

         for layer in self._model.model.layers:
-            if hasattr(layer.self_attn.rotary_emb, "_orig_forward"):
+            if hasattr(layer.self_attn, "rotary_emb") and hasattr(layer.self_attn.rotary_emb, "_orig_forward"):
                 layer.self_attn.rotary_emb.forward = layer.self_attn.rotary_emb._orig_forward

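The Mistral hunk above relies on the usual instance-level monkey-patching idiom: the original bound method is stashed on the object, a replacement is bound with `types.MethodType`, and `__exit__` restores the original. A minimal sketch of that pattern, with a hypothetical `Dummy` class standing in for the wrapped model:

```python
import types


class Dummy:
    """Hypothetical stand-in for the wrapped model object."""

    def update(self):
        return "original"


def patched_update(self):
    return "patched"


obj = Dummy()
# Keep a handle on the original bound method so it can be restored later,
# mirroring how the patcher stores `_orig_update_causal_mask`.
obj._orig_update = obj.update
# Bind the replacement function to this specific instance only.
obj.update = types.MethodType(patched_update, obj)
assert obj.update() == "patched"
# On exit, put the original method back.
obj.update = obj._orig_update
assert obj.update() == "original"
```

Binding with `types.MethodType` patches only that instance, so other objects created from the same class keep their original behavior.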
@@ -1580,19 +1581,19 @@ def __enter__(self):
         ):
             self._model.config.max_position_embeddings = self._model.config.original_max_position_embeddings

-        if is_transformers_version(">=", "4.42.0"):
+        if is_transformers_version(">=", "4.42.0") and is_transformers_version("<", "4.48.0"):
             self._model.model._orig_forward = self._model.model.forward
             self._model.model.forward = types.MethodType(phi3_442_forward, self._model.model)

         # https://github.com/huggingface/transformers/blob/30ee508c6c92a1c0aa0281d193c7c0fb815b8d2f/src/transformers/models/phi3/modeling_phi3.py#L113
         # init inv_freq for torchscript tracing
         for layer in self._model.model.layers:
-            if is_torch_version(">=", "2.1.0"):
+            if is_torch_version(">=", "2.1.0") and is_transformers_version("<", "4.48.0"):
                 orig_self_attn_fwd = layer.self_attn.forward
                 layer.self_attn.forward = types.MethodType(_phi3_self_attn_sdpa_forward, layer.self_attn)
                 layer.self_attn._orig_forward = orig_self_attn_fwd

-            if layer.self_attn.rotary_emb.inv_freq is None:
+            if hasattr(layer.self_attn, "rotary_emb") and layer.self_attn.rotary_emb.inv_freq is None:
                 rotary_emb = layer.self_attn.rotary_emb
                 layer.self_attn.rotary_emb.inv_freq = 1.0 / (
                     rotary_emb.base ** (torch.arange(0, rotary_emb.dim, 2, dtype=torch.int64).float() / rotary_emb.dim)
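The `inv_freq` re-initialization above uses the standard rotary-embedding formula, computed eagerly so TorchScript tracing does not depend on a lazily built buffer. A small self-contained sketch of the same expression; the `base` and `dim` values below are placeholders, the real ones come from the layer's `rotary_emb`:

```python
import torch


def rotary_inv_freq(base: float, dim: int) -> torch.Tensor:
    # Same expression as in the patch above: inverse frequencies for the
    # even channels of a rotary embedding of width `dim`.
    return 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))


# Example values only; real ones come from the model config.
inv_freq = rotary_inv_freq(base=10000.0, dim=96)
print(inv_freq.shape)  # torch.Size([48])
```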
@@ -2493,7 +2494,9 @@ class UpdateCausalMaskModelPatcher(DecoderModelPatcher):
     def __enter__(self):
         super().__enter__()
         patch_update_causal_mask(self._model, "4.42.0")
-        if hasattr(self._model.model.layers[0].self_attn.rotary_emb, "_set_cos_sin_cache"):
+        if hasattr(self._model.model.layers[0].self_attn, "rotary_emb") and hasattr(
+            self._model.model.layers[0].self_attn.rotary_emb, "_set_cos_sin_cache"
+        ):
             for layer in self._model.model.layers:
                 _reinitialize_cos_sin_cached_fp32(layer.self_attn.rotary_emb)

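As in the other hunks, the added `hasattr` checks assume that on newer transformers releases the attention module may no longer own a `rotary_emb` (it can live on the model instead), so the cache re-initialization is skipped rather than raising `AttributeError`. A sketch of that guard as a standalone helper; `maybe_reinit_rotary` and `reinit_fn` are hypothetical names, not part of the library:

```python
def maybe_reinit_rotary(self_attn, reinit_fn):
    # Only touch `rotary_emb` when the attention module still owns one
    # and it exposes the legacy `_set_cos_sin_cache` hook.
    rotary_emb = getattr(self_attn, "rotary_emb", None)
    if rotary_emb is not None and hasattr(rotary_emb, "_set_cos_sin_cache"):
        reinit_fn(rotary_emb)
```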
@@ -3045,15 +3048,16 @@ def patched_forward(self, fn):
     def __enter__(self):
         if is_torch_version(">=", "2.1.0"):
             if self._model.config.model_type == "qwen2" and self._model.config._attn_implementation != "sdpa":
-                from transformers.models.qwen2.modeling_qwen2 import QWEN2_ATTENTION_CLASSES
+                if is_transformers_version("<", "4.48"):
+                    from transformers.models.qwen2.modeling_qwen2 import QWEN2_ATTENTION_CLASSES

-                sdpa_attn = QWEN2_ATTENTION_CLASSES["sdpa"]
-                self._model.config._orig_attn_implementation = self._model.config._attn_implementation
-                self._model.config._attn_implementation = "sdpa"
+                    sdpa_attn = QWEN2_ATTENTION_CLASSES["sdpa"]
+                    self._model.config._orig_attn_implementation = self._model.config._attn_implementation
+                    self._model.config._attn_implementation = "sdpa"

-                for layer in self._model.model.layers:
-                    layer.self_attn._orig_forward = layer.self_attn.forward
-                    layer.self_attn.forward = types.MethodType(sdpa_attn.forward, layer.self_attn)
+                    for layer in self._model.model.layers:
+                        layer.self_attn._orig_forward = layer.self_attn.forward
+                        layer.self_attn.forward = types.MethodType(sdpa_attn.forward, layer.self_attn)

             if self._model.config.model_type == "llama" and self._model.config._attn_implementation != "sdpa":
                 self._model.config._orig_attn_implementation = self._model.config._attn_implementation
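The qwen2 branch now imports `QWEN2_ATTENTION_CLASSES` and swaps in its SDPA forward only for transformers below 4.48, presumably because the per-model attention-class registries were dropped in the 4.48 attention refactor (an assumption, not stated in the diff). For illustration, a hypothetical version-range helper equivalent to pairing the `is_transformers_version` checks used throughout this commit:

```python
from packaging import version

import transformers


def transformers_version_between(low: str, high: str) -> bool:
    # Hypothetical helper: apply a patch only when the installed
    # transformers version falls in the half-open range [low, high),
    # the same condition the paired is_transformers_version calls express.
    installed = version.parse(transformers.__version__)
    return version.parse(low) <= installed < version.parse(high)


# e.g. guard a pre-4.48 code path the same way the patch above does
if transformers_version_between("4.42.0", "4.48.0"):
    pass  # legacy attention-class registry is still available here
```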