@@ -603,6 +603,7 @@ def __exit__(self, exc_type, exc_value, traceback):
             block.self_attention.core_attention.forward = block.self_attention.core_attention._orig_forward


+# Patched rotary embedding forward that takes cos/sin from the precomputed `embed_positions`
+# buffer instead of recalculating them on every step (avoids bf16 accuracy loss and extra compute)
 def llama_gemma_rotary_emb_forward(self, x, position_ids, seq_len=None):
     # adopted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma/modeling_gemma.py#L104
     _seq_len = torch.max(position_ids) + 1 if seq_len is None else seq_len
@@ -626,27 +627,16 @@ def create_sinusoidal_positions(num_pos: int, dim: int, base: int = 10000, inv_f
     return torch.cat((torch.sin(emb), torch.cos(emb)), dim=1)


-def register_sin_cos_buffer(model):
-    max_positions = model.config.max_position_embeddings
-
-    # cos/sin for rotary position embeddings also having issues with bf16 and efficiency due to calculation on each step
-    # use precomputed
+# cos/sin for rotary position embeddings also have bf16 accuracy and efficiency issues when recalculated on each step, so use precomputed values
+def create_embed_positions_buffer(rotary_emb, max_position_embeddings: int = None):
+    inv_freq = getattr(rotary_emb, "inv_freq", None)

-    rotary_emb = model.model.layers[0].self_attn.rotary_emb
     dim, base = None, None
-    inv_freq = getattr(rotary_emb, "inv_freq", None)
     if inv_freq is None:
         base = rotary_emb.base
         dim = rotary_emb.dim
-    embed_positions = create_sinusoidal_positions(max_positions, dim, base, inv_freq)

-    for layer in model.model.layers:
-        layer.self_attn.rotary_emb.register_buffer("embed_positions", embed_positions)
-        layer.self_attn.rotary_emb._orig_forward = layer.self_attn.rotary_emb.forward
-
-        layer.self_attn.rotary_emb.forward = types.MethodType(
-            llama_gemma_rotary_emb_forward, layer.self_attn.rotary_emb
-        )
+    return create_sinusoidal_positions(max_position_embeddings, dim, base, inv_freq)


 # copied from https://github.com/huggingface/transformers/commit/57d7594a79a9f5d835abf2d4d384db0e4818e548 to unblock export with transformers 4.42
@@ -768,15 +758,39 @@ def __enter__(self):
             self._model.model._orig_update_causal_mask = self._model.model._update_causal_mask
             self._model.model._update_causal_mask = types.MethodType(_mistral_update_causal_mask, self._model.model)

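+        # precompute the rotary sin/cos table per layer and switch to the buffer-based forward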
+        if (
+            hasattr(self._model, "model")
+            and hasattr(self._model.model, "layers")
+            and is_transformers_version(">=", "4.41.0")
+        ):
+            for layer in self._model.model.layers:
+                if hasattr(layer.self_attn, "rotary_emb"):
+                    embed_positions = create_embed_positions_buffer(
+                        rotary_emb=layer.self_attn.rotary_emb,
+                        max_position_embeddings=self._model.config.max_position_embeddings,
+                    )
+                    layer.self_attn.rotary_emb.register_buffer("embed_positions", embed_positions)
+                    layer.self_attn.rotary_emb._orig_forward = layer.self_attn.rotary_emb.forward
+                    layer.self_attn.rotary_emb.forward = types.MethodType(
+                        llama_gemma_rotary_emb_forward, layer.self_attn.rotary_emb
+                    )
+
     def __exit__(self, exc_type, exc_value, traceback):
         super().__exit__(exc_type, exc_value, traceback)

-        if hasattr(self._model.model, "_orig_update_causal_mask"):
+        if is_transformers_version(">=", "4.42.0") and is_transformers_version("<", "4.48.0"):
             self._model.model._update_causal_mask = self._model.model._orig_update_causal_mask
+            del self._model.model._orig_update_causal_mask

-        for layer in self._model.model.layers:
-            if hasattr(layer.self_attn, "rotary_emb") and hasattr(layer.self_attn.rotary_emb, "_orig_forward"):
-                layer.self_attn.rotary_emb.forward = layer.self_attn.rotary_emb._orig_forward
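+        # undo the rotary embedding patch applied in __enter__ and drop the saved reference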
+        if (
+            hasattr(self._model, "model")
+            and hasattr(self._model.model, "layers")
+            and is_transformers_version(">=", "4.41.0")
+        ):
+            for layer in self._model.model.layers:
+                if hasattr(layer.self_attn, "rotary_emb"):
+                    layer.self_attn.rotary_emb.forward = layer.self_attn.rotary_emb._orig_forward
+                    del layer.self_attn.rotary_emb._orig_forward


 SUPPORT_SDPA = is_torch_version(">", "2.1.0")
@@ -4877,7 +4891,6 @@ def __init__(
         # Difference from original:
         # uses Dynamic cache from legacy cache instead of HybridCache
         # calculate causal mask from multimodal
-        model.__orig_forward = model.forward

         def forward(
             self, attention_mask, position_ids, past_key_values, token_type_ids, inputs_embeds, use_cache=True
@@ -4913,31 +4926,40 @@ def forward(
             result["past_key_values"] = upd_pkv.to_legacy_cache()
             return result

-        model.forward = types.MethodType(forward, model)
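+        # override forward only for transformers < 4.53; newer releases use the model's own forward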
+        if is_transformers_version("<", "4.53.0"):
+            model.__orig_forward = model.forward
+            model.forward = types.MethodType(forward, model)
+
         super().__init__(config, model, model_kwargs)

     def __enter__(self):
         super().__enter__()

-        if hasattr(self._model, "_update_causal_mask_mm"):
-            self._model._orig_update_causual_mask_mm = self._model._update_causal_mask_mm
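+        # swap in _gemma3_mm_update_causal_mask; which attribute gets patched depends on the transformers version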
+        if is_transformers_version("<", "4.52.0"):
             self._model._update_causal_mask_mm = types.MethodType(_gemma3_mm_update_causal_mask, self._model)
-        elif hasattr(self._model, "model") and hasattr(self._model.model, "_update_causal_mask_mm"):
-            self._model.model._orig_update_causual_mask_mm = self._model.model._update_causal_mask_mm
-            self._model.model._update_causal_mask_mm = types.MethodType(
-                _gemma3_mm_update_causal_mask, self._model.model
-            )
+        elif (
+            is_transformers_version("<", "4.53.0")
+            and hasattr(self._model, "model")
+            and hasattr(self._model.model, "_update_causal_mask")
+        ):
+            self._model.model._orig_update_causal_mask = self._model.model._update_causal_mask
+            self._model.model._update_causal_mask = types.MethodType(_gemma3_mm_update_causal_mask, self._model.model)

     def __exit__(self, exc_type, exc_value, traceback):
         super().__exit__(exc_type, exc_value, traceback)
-        self._model.forward = self._model.__orig_forward

-        if hasattr(self._model, "_orig_update_causual_mask_mm"):
-            self._model._update_causal_mask_mm = self._model._orig_update_causal_mask_mm
-            del self._model._orig_update_causal_mask_mm
-        elif hasattr(self._model, "model") and hasattr(self._model.model, "_orig_update_causual_mask_mm"):
-            self._model.model._update_causal_mask_mm = self._model.model._orig_update_causual_mask_mm
-            del self._model.model._orig_update_causual_mask_mm
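+        # undo the forward and causal-mask patches according to the transformers version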
+        if is_transformers_version("<", "4.53.0"):
+            self._model.forward = self._model.__orig_forward
+
+        if is_transformers_version("<", "4.52.0"):
+            del self._model._update_causal_mask_mm
+        elif (
+            is_transformers_version("<", "4.53.0")
+            and hasattr(self._model, "model")
+            and hasattr(self._model.model, "_orig_update_causal_mask")
+        ):
+            self._model.model._update_causal_mask = self._model.model._orig_update_causal_mask
+            del self._model.model._orig_update_causal_mask

49424964
 class Idefics3ImageEmbeddingsModelPatcher(ModelPatcher):