@@ -2803,7 +2803,6 @@ def patched_forward(*args, **kwargs):
 
             signature = inspect.signature(self.orig_forward)
             args, kwargs = override_arguments(args, kwargs, signature, model_kwargs=self.model_kwargs)
-
             return_legacy_cache = False
             pkv_in_args = False
             legacy_pkv = None
@@ -4407,7 +4406,7 @@ def __init__(
         super().__init__(config, model, model_kwargs)
 
 
-class GotOCR2ImageEmbeddingsModelPatcher(ModelPatcher):
+class CommonImageEmbeddingsModelPatcher(ModelPatcher):
     def __init__(
         self,
         config: "OnnxConfig",
@@ -4416,9 +4415,107 @@ def __init__(
     ):
         model.__orig_forward = model.forward
         # Adopted from https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/got_ocr2/modeling_got_ocr2.py#L835
+        # Adopted from https://github.com/huggingface/transformers/blob/v4.49.0-Gemma-3/src/transformers/models/gemma3/modeling_gemma3.py#L1321
         model.forward = model.get_image_features
         super().__init__(config, model, model_kwargs)
 
     def __exit__(self, exc_type, exc_value, traceback):
         super().__exit__(exc_type, exc_value, traceback)
         self._model.forward = self._model.__orig_forward
+
+
+# Adopted from https://github.com/huggingface/transformers/blob/v4.49.0-Gemma-3/src/transformers/models/gemma3/modeling_gemma3.py#L1147
+def _gemma3_mm_update_causal_mask(
+    self, attention_mask, token_type_ids, past_key_values, cache_position, input_tensor, is_training: bool = False
+):
+    if attention_mask is not None and attention_mask.dim() == 4:
+        # In this case we assume that the mask comes already in inverted
+        # form and requires no inversion or slicing.
+        return attention_mask
+
+    min_dtype = torch.finfo(torch.float16).min
+    inputs_lead_dim, sequence_length = input_tensor.shape[:2]
+    target_length = (
+        attention_mask.shape[-1]
+        if isinstance(attention_mask, torch.Tensor)
+        else cache_position[0] + sequence_length + 1
+    )
+
+    causal_mask = torch.full(
+        (sequence_length, target_length), fill_value=min_dtype, dtype=self.dtype, device=cache_position.device
+    )
+
+    # Causal diagonal mask only if training, otherwise attend to the whole prefix. Training-specific attn for prefix is handled below
+    if sequence_length != 1:
+        causal_mask = torch.triu(causal_mask, diagonal=1)
+
+    causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
+    causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1)
+
+    # Apply bidirectional mask on images if token type ids are provided
+    if token_type_ids is not None and sequence_length != 1:
+        token_type_mask = token_type_ids.unsqueeze(1) == token_type_ids.unsqueeze(2)
+        token_type_mask[token_type_ids == 0] = False  # if text token do not change anything
+        token_type_mask = token_type_mask.unsqueeze(1).to(causal_mask.device, dtype=torch.bool)
+        causal_mask = causal_mask.clone()
+        causal_mask[:, :, :, :sequence_length] = causal_mask[:, :, :, :sequence_length].masked_fill(
+            token_type_mask, 0.0
+        )
+
+    if attention_mask is not None:
+        causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
+        mask_length = attention_mask.shape[-1]
+
+        # Then apply padding mask (will mask pad tokens)
+        padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device)
+        padding_mask = padding_mask == 0
+        causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(padding_mask, min_dtype)
+
+    return causal_mask
+
+
+class Gemma3LMModelPatcher(DecoderModelPatcher):
+    def __init__(
+        self,
+        config: "OnnxConfig",
+        model: Union["PreTrainedModel", "TFPreTrainedModel"],
+        model_kwargs: Optional[Dict[str, Any]] = None,
+    ):
+        model.__orig_forward = model.forward
+        model._update_causal_mask_mm = types.MethodType(_gemma3_mm_update_causal_mask, model)
+
+        # Differences from the original forward:
+        # - builds a DynamicCache from the legacy cache format instead of using HybridCache
+        # - computes the causal mask with the multimodal helper patched in above
+        def forward(self, attention_mask, position_ids, past_key_values, token_type_ids, inputs_embeds):
+            from transformers.cache_utils import DynamicCache
+
+            pkv = DynamicCache.from_legacy_cache(past_key_values)
+
+            past_seen_tokens = past_key_values[0][0].shape[-2]
+            cache_position = torch.arange(
+                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+            )
+
+            causal_mask = self._update_causal_mask_mm(
+                attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds
+            )
+
+            result = self.__orig_forward(
+                input_ids=None,
+                attention_mask=causal_mask,
+                position_ids=position_ids,
+                cache_position=cache_position,
+                past_key_values=pkv,
+                inputs_embeds=inputs_embeds,
+            )
+            upd_pkv = result["past_key_values"]
+            result["past_key_values"] = upd_pkv.to_legacy_cache()
+            return result
+
+        model.forward = types.MethodType(forward, model)
+        super().__init__(config, model, model_kwargs)
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        self._model.forward = self._model.__orig_forward
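
The bidirectional-attention trick in _gemma3_mm_update_causal_mask above is easier to see on a toy example. The standalone PyTorch sketch below is illustrative only and not part of the patch; the sequence layout and the convention that image tokens carry a non-zero token type are assumptions made for the example.

# Toy illustration (assumption: token_type_ids == 1 marks image tokens, 0 marks text).
import torch

seq_len = 6
min_dtype = torch.finfo(torch.float16).min
token_type_ids = torch.tensor([[0, 1, 1, 1, 0, 0]])  # text, image, image, image, text, text

# Standard causal mask: min_dtype above the diagonal blocks attention to future tokens.
causal_mask = torch.triu(torch.full((seq_len, seq_len), fill_value=min_dtype), diagonal=1)
causal_mask = causal_mask[None, None, :, :]  # (batch, heads, query, key)

# Positions sharing the same non-zero token type may attend to each other bidirectionally.
token_type_mask = token_type_ids.unsqueeze(1) == token_type_ids.unsqueeze(2)
token_type_mask[token_type_ids == 0] = False  # text rows keep purely causal attention
causal_mask = causal_mask.masked_fill(token_type_mask.unsqueeze(1), 0.0)

# Row 1 (the first image token) now also attends to positions 2 and 3,
# even though they appear later in the sequence.
print((causal_mask[0, 0] == 0).int())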
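The wrapped forward in Gemma3LMModelPatcher converts the exported model's tuple-format cache into a DynamicCache on the way in and back to the legacy tuple format on the way out. Below is a minimal sketch of that round trip, assuming a transformers version that still exposes the legacy cache helpers (the diff targets v4.49); the tensor sizes are arbitrary.

import torch
from transformers.cache_utils import DynamicCache

# Legacy format: one (key, value) pair per layer, each (batch, heads, past_len, head_dim).
legacy_pkv = tuple((torch.zeros(1, 2, 4, 8), torch.zeros(1, 2, 4, 8)) for _ in range(3))

cache = DynamicCache.from_legacy_cache(legacy_pkv)  # what the patched forward hands to the model
assert cache.get_seq_length() == 4                  # same value as past_key_values[0][0].shape[-2]

roundtrip = cache.to_legacy_cache()                 # converted back before the wrapper returns
assert len(roundtrip) == 3 and roundtrip[0][0].shape == (1, 2, 4, 8)

This keeps the exported graph's inputs and outputs as plain tensors while the underlying Gemma3 language model receives the cache object it expects.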