 from transformers.models.speecht5.modeling_speecht5 import SpeechT5EncoderWithSpeechPrenet
 from transformers.utils import is_tf_available

-from optimum.exporters.onnx.base import OnnxConfig
+from optimum.exporters.onnx.base import ConfigBehavior, OnnxConfig
 from optimum.exporters.onnx.model_patcher import (
     UNSUPPORTED_OPS_PATCHING_SPEC,
     DecoderModelPatcher,
@@ -327,7 +327,11 @@ def eager_mask_without_vmap(*args, **kwargs) -> Optional[torch.Tensor]:
     mask = sdpa_mask_without_vmap(*args, allow_is_causal_skip=False, **kwargs)
     # we use torch.finfo(torch.float16).min instead of torch.finfo(dtype).min to avoid an overflow but not
     # sure this is the right way to handle this, we are basically pretending that -65,504 is -inf
-    mask = torch.where(mask, torch.tensor(0.0, device=mask.device, dtype=dtype), torch.finfo(torch.float16).min)
+    mask = torch.where(
+        mask,
+        torch.tensor(0.0, device=mask.device, dtype=dtype),
+        torch.tensor(torch.finfo(torch.float16).min, device=mask.device, dtype=dtype),
+    )
     return mask


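For context, here is a minimal sketch (not part of the diff; the input mask and dtype are made up for illustration) of what the rewritten torch.where call produces: both branches are materialized as tensors with an explicit device and dtype, turning the boolean mask from sdpa_mask_without_vmap into an additive float mask in which masked positions get the float16 minimum as a stand-in for -inf.

import torch

bool_mask = torch.tensor([[True, False], [True, True]])  # True = position may be attended
dtype = torch.float32
float_mask = torch.where(
    bool_mask,
    torch.tensor(0.0, device=bool_mask.device, dtype=dtype),
    torch.tensor(torch.finfo(torch.float16).min, device=bool_mask.device, dtype=dtype),
)
# float_mask == [[0.0, -65504.0], [0.0, 0.0]]: allowed positions add 0, masked ones add -65504
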
@@ -4711,52 +4715,77 @@ def __exit__(self, exc_type, exc_value, traceback):
             layer.attn._attn = layer.attn._orig_attn


-class StatefulSeq2SeqDecoderPatcher(Seq2SeqModelPatcher):
+class OVSeq2SeqModelPatcher(Seq2SeqModelPatcher):
     def __init__(
         self,
         config: "OnnxConfig",
         model: Union["PreTrainedModel", "TFPreTrainedModel"],
         model_kwargs: Optional[Dict[str, Any]] = None,
     ):
-        model.__orig_forward = model.forward
+        if getattr(config, "stateful", False) and config._behavior == ConfigBehavior.DECODER:
+            model.__orig_forward = model.forward

-        @functools.wraps(model.__orig_forward)
-        def patched_forward(*args, **kwargs):
-            from transformers.cache_utils import EncoderDecoderCache
+            @functools.wraps(model.__orig_forward)
+            def patched_forward(*args, **kwargs):
+                from transformers.cache_utils import EncoderDecoderCache
+
+                signature = inspect.signature(self.orig_forward)
+                args, kwargs = override_arguments(args, kwargs, signature, model_kwargs=self.model_kwargs)
+
+                return_legacy_cache = False
+                pkv_in_args = False
+                legacy_pkv = None
+                if "past_key_values" in kwargs:
+                    legacy_pkv = kwargs.pop("past_key_values", None)
+                sign_names = list(signature.parameters.keys())
+                pkv_argument_index = sign_names.index("past_key_values")
+                if legacy_pkv is None and len(args) > pkv_argument_index:
+                    legacy_pkv = args[pkv_argument_index]
+                    pkv_in_args = True
+                if legacy_pkv is not None:
+                    if isinstance(legacy_pkv, EncoderDecoderCache):
+                        legacy_pkv = legacy_pkv.to_legacy_cache()
+                    only_self_cache = [cache_item[:2] for cache_item in legacy_pkv]
+                    pkv = EncoderDecoderCache.from_legacy_cache(only_self_cache)
+                    return_legacy_cache = True
+                    if not pkv_in_args:
+                        kwargs["past_key_values"] = pkv
+                    else:
+                        args[pkv_argument_index] = pkv

-            signature = inspect.signature(self.orig_forward)
-            args, kwargs = override_arguments(args, kwargs, signature, model_kwargs=self.model_kwargs)
+                outputs = model.__orig_forward(*args, **kwargs)
+                if return_legacy_cache:
+                    outputs.past_key_values = outputs.past_key_values.to_legacy_cache()

-            return_legacy_cache = False
-            pkv_in_args = False
-            legacy_pkv = None
-            if "past_key_values" in kwargs:
-                legacy_pkv = kwargs.pop("past_key_values", None)
-            sign_names = list(signature.parameters.keys())
-            pkv_argument_index = sign_names.index("past_key_values")
-            if legacy_pkv is None and len(args) > pkv_argument_index:
-                legacy_pkv = args[pkv_argument_index]
-                pkv_in_args = True
-            if legacy_pkv is not None:
-                if isinstance(legacy_pkv, EncoderDecoderCache):
-                    legacy_pkv = legacy_pkv.to_legacy_cache()
-                only_self_cache = [cache_item[:2] for cache_item in legacy_pkv]
-                pkv = EncoderDecoderCache.from_legacy_cache(only_self_cache)
-                return_legacy_cache = True
-                if not pkv_in_args:
-                    kwargs["past_key_values"] = pkv
-                else:
-                    args[pkv_argument_index] = pkv
+                return outputs

-            outputs = model.__orig_forward(*args, **kwargs)
-            if return_legacy_cache:
-                outputs.past_key_values = outputs.past_key_values.to_legacy_cache()
+            model.forward = patched_forward

-            return outputs
+        super().__init__(config, model, model_kwargs)
+
+    def __enter__(self):
+        super().__enter__()

-        model.forward = patched_forward
+        if is_transformers_version(">=", "4.53.0"):
+            # for OpenVINO, we use torch.finfo(torch.float16).min instead of torch.finfo(dtype).min
+            # Although I'm not sure this is the right way to handle this, we are basically pretending that -65,504 is -inf
+            ALL_MASK_ATTENTION_FUNCTIONS.register("eager", eager_mask_without_vmap)

-        super().__init__(config, model, model_kwargs)
+            # for non-stateful decoder models, we use eager mask without vmap for sdpa as well
+            # to avoid a nan output issue in OpenVINO that only happens in case of non-stateful models
+            if not getattr(self.real_config, "stateful", False):
+                logger.warning(
+                    "Exporting a non-stateful decoder model currently results in a nan output in OpenVINO. "
+                    "There might be a performance impact due to the use of eager mask (floats) instead of sdpa mask (bools). "
+                )
+                ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa", eager_mask_without_vmap)
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+
+        if is_transformers_version(">=", "4.53.0"):
+            ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa", sdpa_mask)
+            ALL_MASK_ATTENTION_FUNCTIONS.register("eager", eager_mask)


 class SanaTextEncoderModelPatcher(ModelPatcher):
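As a side note, here is a small illustration (not part of the diff; the layer count and tensor shapes are made up, and it assumes a transformers version whose EncoderDecoderCache.from_legacy_cache accepts self-attention-only entries, which the patched forward above relies on) of the cache handling in patched_forward: each legacy per-layer tuple of (self_key, self_value, cross_key, cross_value) is sliced down to its self-attention pair before being wrapped back into an EncoderDecoderCache, so the stateful decoder carries only the self-attention cache.

import torch
from transformers.cache_utils import EncoderDecoderCache

num_layers, batch, heads, seq_len, head_dim = 2, 1, 4, 3, 8
legacy_pkv = tuple(
    tuple(torch.zeros(batch, heads, seq_len, head_dim) for _ in range(4))  # self k/v + cross k/v
    for _ in range(num_layers)
)

# keep only the self-attention key/value per layer, as patched_forward does
only_self_cache = [cache_item[:2] for cache_item in legacy_pkv]
pkv = EncoderDecoderCache.from_legacy_cache(only_self_cache)
# pkv.self_attention_cache is populated for both layers; pkv.cross_attention_cache stays empty
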
@@ -5376,7 +5405,7 @@ def modulewise_unpatch(model, module_cls):
         modulewise_unpatch(module, module_cls)


-class BlenderbotModelPatcher(Seq2SeqModelPatcher):
+class BlenderbotModelPatcher(OVSeq2SeqModelPatcher):
     def __enter__(self):
         super().__enter__()
         if is_transformers_version(">=", "4.49.0"):
@@ -5392,7 +5421,7 @@ def __exit__(self, exc_type, exc_value, traceback):
             modulewise_unpatch(self._model, BlenderbotAttention)


-class BlenderbotSmallModelPatcher(Seq2SeqModelPatcher):
+class BlenderbotSmallModelPatcher(OVSeq2SeqModelPatcher):
     def __enter__(self):
         super().__enter__()
         if is_transformers_version(">=", "4.49.0"):
@@ -5408,15 +5437,7 @@ def __exit__(self, exc_type, exc_value, traceback):
             modulewise_unpatch(self._model, BlenderbotSmallAttention)


-class BlenderbotStatefulSeq2SeqDecoderPatcher(StatefulSeq2SeqDecoderPatcher, BlenderbotModelPatcher):
-    pass
-
-
-class BlenderbotSmallStatefulSeq2SeqDecoderPatcher(StatefulSeq2SeqDecoderPatcher, BlenderbotSmallModelPatcher):
-    pass
-
-
-class PegasusModelPatcher(Seq2SeqModelPatcher):
+class PegasusModelPatcher(OVSeq2SeqModelPatcher):
     def __enter__(self):
         super().__enter__()
         if is_transformers_version(">=", "4.49.0"):
@@ -5495,11 +5516,7 @@ def __exit__(self, exc_type, exc_value, traceback):
             modulewise_unpatch(self._model, Qwen2MoeSparseMoeBlock)


-class PegasusStatefulSeq2SeqDecoderPatcher(StatefulSeq2SeqDecoderPatcher, PegasusModelPatcher):
-    pass
-
-
-class MarianModelPatcher(Seq2SeqModelPatcher):
+class MarianModelPatcher(OVSeq2SeqModelPatcher):
     def __enter__(self):
         super().__enter__()
         if is_transformers_version(">=", "4.49.0"):
@@ -5515,10 +5532,6 @@ def __exit__(self, exc_type, exc_value, traceback):
             modulewise_unpatch(self._model, MarianAttention)


-class MarianStatefulSeq2SeqDecoderPatcher(StatefulSeq2SeqDecoderPatcher, MarianModelPatcher):
-    pass
-
-
 # Adapted from https://github.com/huggingface/transformers/blob/v4.51.3/src/transformers/models/speecht5/modeling_speecht5.py#L698
 # this is a patch to avoid PyTorch FE issue
 # with the same tensor names on input and intermediate tensor for speaker_embeddings