Skip to content

Commit ecc4bb1

Browse files
Fix and test stateless encoder decoders (#1423)
* fix
* test other seq2seq stateless models
1 parent: c81d05b · commit: ecc4bb1

File tree

3 files changed

+92
-77
lines changed

3 files changed

+92
-77
lines changed

optimum/exporters/openvino/model_configs.py

Lines changed: 4 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -83,8 +83,6 @@
8383
BaichuanModelPatcher,
8484
BlenderbotModelPatcher,
8585
BlenderbotSmallModelPatcher,
86-
BlenderbotSmallStatefulSeq2SeqDecoderPatcher,
87-
BlenderbotStatefulSeq2SeqDecoderPatcher,
8886
BloomModelPatcher,
8987
ChatGLMModelPatcher,
9088
CodeGenModelPatcher,
@@ -117,7 +115,6 @@
117115
MairaImageEmbeddingModelPatcher,
118116
MambaPatcher,
119117
MarianModelPatcher,
120-
MarianStatefulSeq2SeqDecoderPatcher,
121118
MiniCPM3Patcher,
122119
MiniCPMModelPatcher,
123120
MiniCPMVImageEmbeddingsModelPatcher,
@@ -126,9 +123,9 @@
126123
MixtralModelPatcher,
127124
MPTModelPatcher,
128125
OVDecoderModelPatcher,
126+
OVSeq2SeqModelPatcher,
129127
OVSpeechT5ModelPatcher,
130128
PegasusModelPatcher,
131-
PegasusStatefulSeq2SeqDecoderPatcher,
132129
PersimmonModelPatcher,
133130
Phi3ModelPatcher,
134131
Phi3VisionImageEmbeddingsPatcher,
@@ -144,7 +141,6 @@
144141
Qwen3MoeModelPatcher,
145142
QwenModelPatcher,
146143
SanaTextEncoderModelPatcher,
147-
StatefulSeq2SeqDecoderPatcher,
148144
XverseModelPatcher,
149145
)
150146

@@ -3738,9 +3734,7 @@ class WhisperOpenVINOConfig(WhisperOnnxConfig):
37383734
def patch_model_for_export(
37393735
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
37403736
) -> ModelPatcher:
3741-
if getattr(self, "stateful", False) and self._behavior == ConfigBehavior.DECODER:
3742-
return StatefulSeq2SeqDecoderPatcher(self, model, model_kwargs)
3743-
return super().patch_model_for_export(model, model_kwargs)
3737+
return OVSeq2SeqModelPatcher(self, model, model_kwargs=model_kwargs)
37443738

37453739
@property
37463740
def inputs(self):
@@ -3764,9 +3758,7 @@ class T5OpenVINOConfig(T5OnnxConfig):
37643758
def patch_model_for_export(
37653759
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
37663760
) -> ModelPatcher:
3767-
if getattr(self, "stateful", False) and self._behavior == ConfigBehavior.DECODER:
3768-
return StatefulSeq2SeqDecoderPatcher(self, model, model_kwargs)
3769-
return super().patch_model_for_export(model, model_kwargs)
3761+
return OVSeq2SeqModelPatcher(self, model, model_kwargs)
37703762

37713763
@property
37723764
def inputs(self):
@@ -3812,9 +3804,7 @@ class BartOpenVINOConfig(BartOnnxConfig):
38123804
def patch_model_for_export(
38133805
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
38143806
) -> ModelPatcher:
3815-
if getattr(self, "stateful", False) and self._behavior == ConfigBehavior.DECODER:
3816-
return StatefulSeq2SeqDecoderPatcher(self, model, model_kwargs)
3817-
return super().patch_model_for_export(model, model_kwargs)
3807+
return OVSeq2SeqModelPatcher(self, model, model_kwargs)
38183808

38193809
@property
38203810
def inputs(self):
@@ -4056,8 +4046,6 @@ class BlenderbotOpenVINOConfig(BlenderbotOnnxConfig):
40564046
def patch_model_for_export(
40574047
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
40584048
) -> "ModelPatcher":
4059-
if getattr(self, "stateful", False) and self._behavior == ConfigBehavior.DECODER:
4060-
return BlenderbotStatefulSeq2SeqDecoderPatcher(self, model, model_kwargs)
40614049
return BlenderbotModelPatcher(self, model, model_kwargs=model_kwargs)
40624050

40634051
@property
@@ -4084,8 +4072,6 @@ class BlenderbotSmallOpenVINOConfig(BlenderbotSmallOnnxConfig):
40844072
def patch_model_for_export(
40854073
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
40864074
) -> "ModelPatcher":
4087-
if getattr(self, "stateful", False) and self._behavior == ConfigBehavior.DECODER:
4088-
return BlenderbotSmallStatefulSeq2SeqDecoderPatcher(self, model, model_kwargs)
40894075
return BlenderbotSmallModelPatcher(self, model, model_kwargs=model_kwargs)
40904076

40914077
@property
@@ -4112,8 +4098,6 @@ class PegasusOpenVINOConfig(PegasusOnnxConfig):
41124098
def patch_model_for_export(
41134099
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
41144100
) -> "ModelPatcher":
4115-
if getattr(self, "stateful", False) and self._behavior == ConfigBehavior.DECODER:
4116-
return PegasusStatefulSeq2SeqDecoderPatcher(self, model, model_kwargs)
41174101
return PegasusModelPatcher(self, model, model_kwargs=model_kwargs)
41184102

41194103
@property
@@ -4140,8 +4124,6 @@ class MarianOpenVINOConfig(MarianOnnxConfig):
41404124
def patch_model_for_export(
41414125
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
41424126
) -> "ModelPatcher":
4143-
if getattr(self, "stateful", False) and self._behavior == ConfigBehavior.DECODER:
4144-
return MarianStatefulSeq2SeqDecoderPatcher(self, model, model_kwargs)
41454127
return MarianModelPatcher(self, model, model_kwargs=model_kwargs)
41464128

41474129
@property

optimum/exporters/openvino/model_patcher.py

Lines changed: 68 additions & 55 deletions
Original file line number | Diff line number | Diff line change
@@ -27,7 +27,7 @@
2727
from transformers.models.speecht5.modeling_speecht5 import SpeechT5EncoderWithSpeechPrenet
2828
from transformers.utils import is_tf_available
2929

30-
from optimum.exporters.onnx.base import OnnxConfig
30+
from optimum.exporters.onnx.base import ConfigBehavior, OnnxConfig
3131
from optimum.exporters.onnx.model_patcher import (
3232
UNSUPPORTED_OPS_PATCHING_SPEC,
3333
DecoderModelPatcher,
@@ -327,7 +327,11 @@ def eager_mask_without_vmap(*args, **kwargs) -> Optional[torch.Tensor]:
327327
mask = sdpa_mask_without_vmap(*args, allow_is_causal_skip=False, **kwargs)
328328
# we use torch.finfo(torch.float16).min instead torch.finfo(dtype).min to avoid an overflow but not
329329
# sure this is the right way to handle this, we are basically pretending that -65,504 is -inf
330-
mask = torch.where(mask, torch.tensor(0.0, device=mask.device, dtype=dtype), torch.finfo(torch.float16).min)
330+
mask = torch.where(
331+
mask,
332+
torch.tensor(0.0, device=mask.device, dtype=dtype),
333+
torch.tensor(torch.finfo(torch.float16).min, device=mask.device, dtype=dtype),
334+
)
331335
return mask
332336

333337

@@ -4711,52 +4715,77 @@ def __exit__(self, exc_type, exc_value, traceback):
47114715
layer.attn._attn = layer.attn._orig_attn
47124716

47134717

4714-
class StatefulSeq2SeqDecoderPatcher(Seq2SeqModelPatcher):
4718+
class OVSeq2SeqModelPatcher(Seq2SeqModelPatcher):
47154719
def __init__(
47164720
self,
47174721
config: "OnnxConfig",
47184722
model: Union["PreTrainedModel", "TFPreTrainedModel"],
47194723
model_kwargs: Optional[Dict[str, Any]] = None,
47204724
):
4721-
model.__orig_forward = model.forward
4725+
if getattr(config, "stateful", False) and config._behavior == ConfigBehavior.DECODER:
4726+
model.__orig_forward = model.forward
47224727

4723-
@functools.wraps(model.__orig_forward)
4724-
def patched_forward(*args, **kwargs):
4725-
from transformers.cache_utils import EncoderDecoderCache
4728+
@functools.wraps(model.__orig_forward)
4729+
def patched_forward(*args, **kwargs):
4730+
from transformers.cache_utils import EncoderDecoderCache
4731+
4732+
signature = inspect.signature(self.orig_forward)
4733+
args, kwargs = override_arguments(args, kwargs, signature, model_kwargs=self.model_kwargs)
4734+
4735+
return_legacy_cache = False
4736+
pkv_in_args = False
4737+
legacy_pkv = None
4738+
if "past_key_values" in kwargs:
4739+
legacy_pkv = kwargs.pop("past_key_values", None)
4740+
sign_names = list(signature.parameters.keys())
4741+
pkv_argument_index = sign_names.index("past_key_values")
4742+
if legacy_pkv is None and len(args) > pkv_argument_index:
4743+
legacy_pkv = args[pkv_argument_index]
4744+
pkv_in_args = True
4745+
if legacy_pkv is not None:
4746+
if isinstance(legacy_pkv, EncoderDecoderCache):
4747+
legacy_pkv = legacy_pkv.to_legacy_cache()
4748+
only_self_cache = [cache_item[:2] for cache_item in legacy_pkv]
4749+
pkv = EncoderDecoderCache.from_legacy_cache(only_self_cache)
4750+
return_legacy_cache = True
4751+
if not pkv_in_args:
4752+
kwargs["past_key_values"] = pkv
4753+
else:
4754+
args[pkv_argument_index] = pkv
47264755

4727-
signature = inspect.signature(self.orig_forward)
4728-
args, kwargs = override_arguments(args, kwargs, signature, model_kwargs=self.model_kwargs)
4756+
outputs = model.__orig_forward(*args, **kwargs)
4757+
if return_legacy_cache:
4758+
outputs.past_key_values = outputs.past_key_values.to_legacy_cache()
47294759

4730-
return_legacy_cache = False
4731-
pkv_in_args = False
4732-
legacy_pkv = None
4733-
if "past_key_values" in kwargs:
4734-
legacy_pkv = kwargs.pop("past_key_values", None)
4735-
sign_names = list(signature.parameters.keys())
4736-
pkv_argument_index = sign_names.index("past_key_values")
4737-
if legacy_pkv is None and len(args) > pkv_argument_index:
4738-
legacy_pkv = args[pkv_argument_index]
4739-
pkv_in_args = True
4740-
if legacy_pkv is not None:
4741-
if isinstance(legacy_pkv, EncoderDecoderCache):
4742-
legacy_pkv = legacy_pkv.to_legacy_cache()
4743-
only_self_cache = [cache_item[:2] for cache_item in legacy_pkv]
4744-
pkv = EncoderDecoderCache.from_legacy_cache(only_self_cache)
4745-
return_legacy_cache = True
4746-
if not pkv_in_args:
4747-
kwargs["past_key_values"] = pkv
4748-
else:
4749-
args[pkv_argument_index] = pkv
4760+
return outputs
47504761

4751-
outputs = model.__orig_forward(*args, **kwargs)
4752-
if return_legacy_cache:
4753-
outputs.past_key_values = outputs.past_key_values.to_legacy_cache()
4762+
model.forward = patched_forward
47544763

4755-
return outputs
4764+
super().__init__(config, model, model_kwargs)
4765+
4766+
def __enter__(self):
4767+
super().__enter__()
47564768

4757-
model.forward = patched_forward
4769+
if is_transformers_version(">=", "4.53.0"):
4770+
# for OpenVINO, we use torch.finfo(torch.float16).min instead of torch.finfo(dtype).min
4771+
# Although I'm not sure this is the right way to handle this, we are basically pretending that -65,504 is -inf
4772+
ALL_MASK_ATTENTION_FUNCTIONS.register("eager", eager_mask_without_vmap)
47584773

4759-
super().__init__(config, model, model_kwargs)
4774+
# for non-stateful decoder models, we use eager mask without vmap for sdpa as well
4775+
# to avoid a nan output issue in OpenVINO that only happens in case of non-stateful models
4776+
if not getattr(self.real_config, "stateful", False):
4777+
logger.warning(
4778+
"Exporting a non-stateful decoder model currently results in a nan output in OpenVINO. "
4779+
"There might be a performance impact due to the use of eager mask (floats) instead of sdpa mask (bools). "
4780+
)
4781+
ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa", eager_mask_without_vmap)
4782+
4783+
def __exit__(self, exc_type, exc_value, traceback):
4784+
super().__exit__(exc_type, exc_value, traceback)
4785+
4786+
if is_transformers_version(">=", "4.53.0"):
4787+
ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa", sdpa_mask)
4788+
ALL_MASK_ATTENTION_FUNCTIONS.register("eager", eager_mask)
47604789

47614790

47624791
class SanaTextEncoderModelPatcher(ModelPatcher):
@@ -5376,7 +5405,7 @@ def modulewise_unpatch(model, module_cls):
53765405
modulewise_unpatch(module, module_cls)
53775406

53785407

5379-
class BlenderbotModelPatcher(Seq2SeqModelPatcher):
5408+
class BlenderbotModelPatcher(OVSeq2SeqModelPatcher):
53805409
def __enter__(self):
53815410
super().__enter__()
53825411
if is_transformers_version(">=", "4.49.0"):
@@ -5392,7 +5421,7 @@ def __exit__(self, exc_type, exc_value, traceback):
53925421
modulewise_unpatch(self._model, BlenderbotAttention)
53935422

53945423

5395-
class BlenderbotSmallModelPatcher(Seq2SeqModelPatcher):
5424+
class BlenderbotSmallModelPatcher(OVSeq2SeqModelPatcher):
53965425
def __enter__(self):
53975426
super().__enter__()
53985427
if is_transformers_version(">=", "4.49.0"):
@@ -5408,15 +5437,7 @@ def __exit__(self, exc_type, exc_value, traceback):
54085437
modulewise_unpatch(self._model, BlenderbotSmallAttention)
54095438

54105439

5411-
class BlenderbotStatefulSeq2SeqDecoderPatcher(StatefulSeq2SeqDecoderPatcher, BlenderbotModelPatcher):
5412-
pass
5413-
5414-
5415-
class BlenderbotSmallStatefulSeq2SeqDecoderPatcher(StatefulSeq2SeqDecoderPatcher, BlenderbotSmallModelPatcher):
5416-
pass
5417-
5418-
5419-
class PegasusModelPatcher(Seq2SeqModelPatcher):
5440+
class PegasusModelPatcher(OVSeq2SeqModelPatcher):
54205441
def __enter__(self):
54215442
super().__enter__()
54225443
if is_transformers_version(">=", "4.49.0"):
@@ -5495,11 +5516,7 @@ def __exit__(self, exc_type, exc_value, traceback):
54955516
modulewise_unpatch(self._model, Qwen2MoeSparseMoeBlock)
54965517

54975518

5498-
class PegasusStatefulSeq2SeqDecoderPatcher(StatefulSeq2SeqDecoderPatcher, PegasusModelPatcher):
5499-
pass
5500-
5501-
5502-
class MarianModelPatcher(Seq2SeqModelPatcher):
5519+
class MarianModelPatcher(OVSeq2SeqModelPatcher):
55035520
def __enter__(self):
55045521
super().__enter__()
55055522
if is_transformers_version(">=", "4.49.0"):
@@ -5515,10 +5532,6 @@ def __exit__(self, exc_type, exc_value, traceback):
55155532
modulewise_unpatch(self._model, MarianAttention)
55165533

55175534

5518-
class MarianStatefulSeq2SeqDecoderPatcher(StatefulSeq2SeqDecoderPatcher, MarianModelPatcher):
5519-
pass
5520-
5521-
55225535
# Adopted from https://github.com/huggingface/transformers/blob/v4.51.3/src/transformers/models/speecht5/modeling_speecht5.py#L698
55235536
# this is a patch to avoid PyTorch FE issue
55245537
# with the same tensor names on input and intermediate tensor for speaker_embeddings

tests/openvino/test_modeling.py

Lines changed: 20 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1960,6 +1960,9 @@ def test_compare_to_transformers(self, model_arch):
19601960
model_id = MODEL_NAMES[model_arch]
19611961
set_seed(SEED)
19621962
ov_model = OVModelForSeq2SeqLM.from_pretrained(model_id, export=True, ov_config=F32_CONFIG)
1963+
ov_stateless_model = OVModelForSeq2SeqLM.from_pretrained(
1964+
model_id, export=True, use_cache=False, stateful=False, ov_config=F32_CONFIG
1965+
)
19631966
expected_stateful = is_transformers_version(">", "4.43") and model_arch in self.SUPPORT_STATEFUL
19641967
self.assertEqual(ov_model.decoder.stateful, expected_stateful)
19651968
self.assertEqual(model_has_state(ov_model.decoder.model), expected_stateful)
@@ -1977,6 +1980,7 @@ def test_compare_to_transformers(self, model_arch):
19771980
decoder_start_token_id = transformers_model.config.decoder_start_token_id if model_arch != "mbart" else 2
19781981
decoder_inputs = {"decoder_input_ids": torch.ones((1, 1), dtype=torch.long) * decoder_start_token_id}
19791982
ov_outputs = ov_model(**tokens, **decoder_inputs)
1983+
ov_stateless_outputs = ov_stateless_model(**tokens, **decoder_inputs)
19801984

19811985
self.assertTrue("logits" in ov_outputs)
19821986
self.assertIsInstance(ov_outputs.logits, torch.Tensor)
@@ -1985,6 +1989,7 @@ def test_compare_to_transformers(self, model_arch):
19851989
transformers_outputs = transformers_model(**tokens, **decoder_inputs)
19861990
# Compare tensor outputs
19871991
self.assertTrue(torch.allclose(ov_outputs.logits, transformers_outputs.logits, atol=5e-3))
1992+
self.assertTrue(torch.allclose(ov_stateless_outputs.logits, transformers_outputs.logits, atol=5e-3))
19881993
gen_config = GenerationConfig(
19891994
max_new_tokens=10,
19901995
min_new_tokens=10,
@@ -1997,8 +2002,11 @@ def test_compare_to_transformers(self, model_arch):
19972002
generated_tokens = transformers_model.generate(**tokens, generation_config=gen_config)
19982003
set_seed(SEED)
19992004
ov_generated_tokens = ov_model.generate(**tokens, generation_config=gen_config)
2005+
set_seed(SEED)
2006+
ov_stateless_generated_tokens = ov_stateless_model.generate(**tokens, generation_config=gen_config)
20002007

20012008
self.assertTrue(torch.equal(generated_tokens, ov_generated_tokens))
2009+
self.assertTrue(torch.equal(generated_tokens, ov_stateless_generated_tokens))
20022010

20032011
del transformers_model
20042012
del ov_model
@@ -2850,6 +2858,9 @@ def test_compare_to_transformers(self, model_arch):
28502858
model_id = MODEL_NAMES[model_arch]
28512859
transformers_model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id)
28522860
ov_model = OVModelForSpeechSeq2Seq.from_pretrained(model_id, export=True, ov_config=F32_CONFIG)
2861+
ov_model_stateless = OVModelForSpeechSeq2Seq.from_pretrained(
2862+
model_id, export=True, ov_config=F32_CONFIG, stateful=False
2863+
)
28532864
self.assertIsInstance(ov_model.config, PretrainedConfig)
28542865
# whisper cache class support implemented in 4.43
28552866
expected_stateful = is_transformers_version(">", "4.43")
@@ -2874,9 +2885,13 @@ def test_compare_to_transformers(self, model_arch):
28742885
decoder_inputs = {"decoder_input_ids": np.ones((1, 1), dtype=np.int64) * decoder_start_token_id}
28752886

28762887
ov_outputs = ov_model(**features, **decoder_inputs)
2888+
ov_stateless_outputs = ov_model_stateless(**features, **decoder_inputs)
28772889
self.assertIn("logits", ov_outputs)
28782890
# Compare tensor outputs
28792891
self.assertTrue(torch.allclose(torch.Tensor(ov_outputs.logits), transformers_outputs.logits, atol=1e-3))
2892+
self.assertTrue(
2893+
torch.allclose(torch.Tensor(ov_stateless_outputs.logits), transformers_outputs.logits, atol=1e-3)
2894+
)
28802895

28812896
generate_kwrgs = {}
28822897
if is_transformers_version(">=", "4.50"):
@@ -2894,8 +2909,13 @@ def test_compare_to_transformers(self, model_arch):
28942909
generated_tokens = transformers_model.generate(**pt_features, generation_config=gen_config, **generate_kwrgs)
28952910
set_seed(SEED)
28962911
ov_generated_tokens = ov_model.generate(**pt_features, generation_config=gen_config, **generate_kwrgs)
2912+
set_seed(SEED)
2913+
ov_stateless_generated_tokens = ov_model_stateless.generate(
2914+
**pt_features, generation_config=gen_config, **generate_kwrgs
2915+
)
28972916

28982917
self.assertTrue(torch.equal(generated_tokens, ov_generated_tokens))
2918+
self.assertTrue(torch.equal(generated_tokens, ov_stateless_generated_tokens))
28992919

29002920
del transformers_model
29012921
del ov_model

0 commit comments

Comments
 (0)