Commit b07924b

IlyasMoutawwakil authored and mvafin committed
Fix gemma3 and llava patches for transformers 4.52 (#1408)
* fix gemma3
* fix mistral patch
* added test for llava next mistral
* update test repo id
* add version check
1 parent 6803d1e · commit b07924b
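
All of these fixes share one pattern: each patch is applied, and later undone, only for the transformers releases that need it. Below is a minimal sketch of that version-gating idea, assuming is_transformers_version is importable the way the patcher module uses it; this is illustrative only and not code from the commit.

# Illustrative sketch only (not code from this commit) of the version-gating pattern.
# Assumption: is_transformers_version(op, version) is importable from optimum.intel,
# as the patcher module uses it.
import types

from optimum.intel.utils.import_utils import is_transformers_version


def maybe_patch_forward(model, patched_forward):
    # Apply the workaround only on the transformers releases that still need it;
    # on newer releases the model is left untouched.
    if is_transformers_version(">=", "4.41.0") and is_transformers_version("<", "4.53.0"):
        model._orig_forward = model.forward  # saved so a matching __exit__ can restore it
        model.forward = types.MethodType(patched_forward, model)
    return model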

File tree: 3 files changed, +62 -38 lines

optimum/exporters/openvino/model_patcher.py

Lines changed: 57 additions & 35 deletions
@@ -603,6 +603,7 @@ def __exit__(self, exc_type, exc_value, traceback):
             block.self_attention.core_attention.forward = block.self_attention.core_attention._orig_forward


+# what does this patch exactly ?
 def llama_gemma_rotary_emb_forward(self, x, position_ids, seq_len=None):
     # adopted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma/modeling_gemma.py#L104
     _seq_len = torch.max(position_ids) + 1 if seq_len is None else seq_len
@@ -626,27 +627,16 @@ def create_sinusoidal_positions(num_pos: int, dim: int, base: int = 10000, inv_f
     return torch.cat((torch.sin(emb), torch.cos(emb)), dim=1)


-def register_sin_cos_buffer(model):
-    max_positions = model.config.max_position_embeddings
-
-    # cos/sin for rotary position embeddings also having issues with bf16 and efficiency due to calculation on each step
-    # use precomputed
+# cos/sin for rotary position embeddings also having issues with bf16 and efficiency due to calculation on each step, use precomputed
+def create_embed_positions_buffer(rotary_emb, max_position_embeddings: int = None):
+    inv_freq = getattr(rotary_emb, "inv_freq", None)

-    rotary_emb = model.model.layers[0].self_attn.rotary_emb
     dim, base = None, None
-    inv_freq = getattr(rotary_emb, "inv_freq", None)
     if inv_freq is None:
         base = rotary_emb.base
         dim = rotary_emb.dim
-    embed_positions = create_sinusoidal_positions(max_positions, dim, base, inv_freq)

-    for layer in model.model.layers:
-        layer.self_attn.rotary_emb.register_buffer("embed_positions", embed_positions)
-        layer.self_attn.rotary_emb._orig_forward = layer.self_attn.rotary_emb.forward
-
-        layer.self_attn.rotary_emb.forward = types.MethodType(
-            llama_gemma_rotary_emb_forward, layer.self_attn.rotary_emb
-        )
+    return create_sinusoidal_positions(max_position_embeddings, dim, base, inv_freq)


 # copied from https://github.com/huggingface/transformers/commit/57d7594a79a9f5d835abf2d4d384db0e4818e548 to unblock export with transformers 4.42
@@ -768,15 +758,39 @@ def __enter__(self):
             self._model.model._orig_update_causal_mask = self._model.model._update_causal_mask
             self._model.model._update_causal_mask = types.MethodType(_mistral_update_causal_mask, self._model.model)

+        if (
+            hasattr(self._model, "model")
+            and hasattr(self._model.model, "layers")
+            and is_transformers_version(">=", "4.41.0")
+        ):
+            for layer in self._model.model.layers:
+                if hasattr(layer.self_attn, "rotary_emb"):
+                    embed_positions = create_embed_positions_buffer(
+                        rotary_emb=layer.self_attn.rotary_emb,
+                        max_position_embeddings=self._model.config.max_position_embeddings,
+                    )
+                    layer.self_attn.rotary_emb.register_buffer("embed_positions", embed_positions)
+                    layer.self_attn.rotary_emb._orig_forward = layer.self_attn.rotary_emb.forward
+                    layer.self_attn.rotary_emb.forward = types.MethodType(
+                        llama_gemma_rotary_emb_forward, layer.self_attn.rotary_emb
+                    )
+
     def __exit__(self, exc_type, exc_value, traceback):
         super().__exit__(exc_type, exc_value, traceback)

-        if hasattr(self._model.model, "_orig_update_causal_mask"):
+        if is_transformers_version(">=", "4.42.0") and is_transformers_version("<", "4.48.0"):
             self._model.model._update_causal_mask = self._model.model._orig_update_causal_mask
+            del self._model.model._orig_update_causal_mask

-        for layer in self._model.model.layers:
-            if hasattr(layer.self_attn, "rotary_emb") and hasattr(layer.self_attn.rotary_emb, "_orig_forward"):
-                layer.self_attn.rotary_emb.forward = layer.self_attn.rotary_emb._orig_forward
+        if (
+            hasattr(self._model.model, "model")
+            and hasattr(self._model.model.model, "layers")
+            and is_transformers_version(">=", "4.41.0")
+        ):
+            for layer in self._model.model.layers:
+                if hasattr(layer.self_attn, "rotary_emb"):
+                    layer.self_attn.rotary_emb.forward = layer.self_attn.rotary_emb._orig_forward
+                    del layer.self_attn.rotary_emb._orig_forward


 SUPPORT_SDPA = is_torch_version(">", "2.1.0")
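
The __enter__/__exit__ pair above follows a save-patch-restore discipline: stash the original bound method, install the replacement, then put the original back and delete the stash on exit. A minimal, self-contained sketch of that pattern (the toy class and names below are invented for illustration, not taken from the library):

# A minimal, self-contained sketch (not library code) of the save/patch/restore pattern
# used by the __enter__/__exit__ methods above: keep a handle to the original bound
# method, install a replacement, then restore the original and delete the stash on exit.
import types


class ToyModel:
    def forward(self, x):
        return x


def patched_forward(self, x):
    # illustrative replacement; the real patcher reads a precomputed buffer instead
    return x * 2


class ToyPatcher:
    def __init__(self, model):
        self._model = model

    def __enter__(self):
        self._model._orig_forward = self._model.forward
        self._model.forward = types.MethodType(patched_forward, self._model)
        return self._model

    def __exit__(self, exc_type, exc_value, traceback):
        self._model.forward = self._model._orig_forward
        del self._model._orig_forward


model = ToyModel()
with ToyPatcher(model) as patched:
    assert patched.forward(3) == 6  # replacement active inside the context
assert model.forward(3) == 3  # original behaviour restored afterwards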
@@ -4877,7 +4891,6 @@ def __init__(
         # Difference from original:
         # uses Dynamic cache from legacy cache instead of HybridCache
         # calculate causal mask from multimodal
-        model.__orig_forward = model.forward

         def forward(
             self, attention_mask, position_ids, past_key_values, token_type_ids, inputs_embeds, use_cache=True
@@ -4913,31 +4926,40 @@ def forward(
             result["past_key_values"] = upd_pkv.to_legacy_cache()
             return result

-        model.forward = types.MethodType(forward, model)
+        if is_transformers_version("<", "4.53.0"):
+            model.__orig_forward = model.forward
+            model.forward = types.MethodType(forward, model)
+
         super().__init__(config, model, model_kwargs)

     def __enter__(self):
         super().__enter__()

-        if hasattr(self._model, "_update_causal_mask_mm"):
-            self._model._orig_update_causual_mask_mm = self._model._update_causal_mask_mm
+        if is_transformers_version("<", "4.52.0"):
             self._model._update_causal_mask_mm = types.MethodType(_gemma3_mm_update_causal_mask, self._model)
-        elif hasattr(self._model, "model") and hasattr(self._model.model, "_update_causal_mask_mm"):
-            self._model.model._orig_update_causual_mask_mm = self._model.model._update_causal_mask_mm
-            self._model.model._update_causal_mask_mm = types.MethodType(
-                _gemma3_mm_update_causal_mask, self._model.model
-            )
+        elif (
+            is_transformers_version("<", "4.53.0")
+            and hasattr(self._model, "model")
+            and hasattr(self._model.model, "_update_causal_mask")
+        ):
+            self._model.model._orig_update_causual_mask = self._model.model._update_causal_mask
+            self._model.model._update_causal_mask = types.MethodType(_gemma3_mm_update_causal_mask, self._model.model)

     def __exit__(self, exc_type, exc_value, traceback):
         super().__exit__(exc_type, exc_value, traceback)
-        self._model.forward = self._model.__orig_forward

-        if hasattr(self._model, "_orig_update_causual_mask_mm"):
-            self._model._update_causal_mask_mm = self._model._orig_update_causal_mask_mm
-            del self._model._orig_update_causal_mask_mm
-        elif hasattr(self._model, "model") and hasattr(self._model.model, "_orig_update_causual_mask_mm"):
-            self._model.model._update_causal_mask_mm = self._model.model._orig_update_causual_mask_mm
-            del self._model.model._orig_update_causual_mask_mm
+        if is_transformers_version("<", "4.53.0"):
+            self._model.forward = self._model.__orig_forward
+
+        if is_transformers_version("<", "4.52"):
+            del self._update_causal_mask_mm
+        elif (
+            is_transformers_version("<", "4.53.0")
+            and hasattr(self._model, "model")
+            and hasattr(self._model.model, "_orig_update_causual_mask")
+        ):
+            self._model.model._update_causal_mask = self._model.model._orig_update_causual_mask
+            del self._model.model._orig_update_causual_mask


 class Idefics3ImageEmbeddingsModelPatcher(ModelPatcher):
tests/openvino/test_modeling.py

Lines changed: 4 additions & 3 deletions
@@ -2432,7 +2432,7 @@ class OVModelForVisualCausalLMIntegrationTest(unittest.TestCase):
     SUPPORT_AUDIO = []

     if is_transformers_version(">=", "4.40.0"):
-        SUPPORTED_ARCHITECTURES += ["llava_next", "nanollava"]
+        SUPPORTED_ARCHITECTURES += ["llava_next", "llava_next_mistral", "nanollava"]

     if is_transformers_version(">=", "4.42.0"):
         SUPPORTED_ARCHITECTURES += ["llava_next_video"]
@@ -2467,6 +2467,7 @@ def get_transformer_model_class(self, model_arch):
         if is_transformers_version(">=", "4.46") and model_arch in [
             "llava",
             "llava_next",
+            "llava_next_mistral",
             "qwen2_vl",
             "qwen2_5_vl",
             "got_ocr2",
@@ -2486,7 +2487,7 @@ def get_transformer_model_class(self, model_arch):
            from transformers import LlavaForConditionalGeneration

            return LlavaForConditionalGeneration
-        if model_arch == "llava_next":
+        if model_arch.startswith("llava_next"):
            from transformers import LlavaNextForConditionalGeneration

            return LlavaNextForConditionalGeneration
@@ -2667,7 +2668,7 @@ def test_compare_to_transformers(self, model_arch):

        gc.collect()

-    @parameterized.expand(["llava", "llava_next", "llava_next_video"])
+    @parameterized.expand(["llava", "llava_next", "llava_next_video", "llava_next_mistral"])
    @unittest.skipIf(
        is_transformers_version("<", "4.45.0"), reason="New preprocessing available only in transformers >= 4.45"
    )
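
The switch from an equality check to startswith is what lets the new llava_next_mistral case reuse the existing LLaVA-Next code path in the tests. A small hedged sketch of that dispatch follows; the helper below is illustrative, not the test suite's actual method.

# Sketch of the dispatch change: both "llava_next" and the new "llava_next_mistral"
# tiny model map to the same transformers class. This helper is illustrative only.
from transformers import LlavaNextForConditionalGeneration


def get_llava_next_class(model_arch: str):
    if model_arch.startswith("llava_next"):
        return LlavaNextForConditionalGeneration
    raise ValueError(f"unhandled architecture: {model_arch}")


assert get_llava_next_class("llava_next") is LlavaNextForConditionalGeneration
assert get_llava_next_class("llava_next_mistral") is LlavaNextForConditionalGeneration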

tests/openvino/utils_tests.py

Lines changed: 1 addition & 0 deletions
@@ -101,6 +101,7 @@
     "llama4": "katuni4ka/tiny-random-llama-4-8E",
     "llava": "katuni4ka/tiny-random-llava",
     "llava_next": "katuni4ka/tiny-random-llava-next",
+    "llava_next_mistral": "optimum-internal-testing/tiny-random-llava-next-mistral",
     "llava_next_video": "katuni4ka/tiny-random-llava-next-video",
     "m2m_100": "hf-internal-testing/tiny-random-m2m_100",
     "opt": "hf-internal-testing/tiny-random-OPTModel",
