     MULTI_MODAL_TEXT_GENERATION_MODELS,
     OV_XML_FILE_NAME,
     _get_input_info,
+    _get_dynamic_shapes_info,
+    _normalize_dummy_inputs,
     _get_open_clip_submodels_fn_and_export_configs,
     allow_skip_tracing_check,
     clear_class_registry,
@@ -422,7 +424,84 @@ def export_pytorch(
             # To handle it, additional wrapper on patcher forward applied.
             # model.config.torchscript = True can not be used for patching, because it overrides return_dict to False
             patcher = config.patch_model_for_export(model, model_kwargs=model_kwargs)
-            patched_forward = patcher.patched_forward
+            # patched_forward = patcher.orig_forward
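+            # The wrapper below re-implements patcher.patched_forward around patcher.orig_forward,
+            # adding conversion of past_key_values between the legacy tuple format and Cache objects.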
+            import inspect
+            from optimum.exporters.onnx.model_patcher import override_arguments
+
+            if is_transformers_version(">=", "4.48"):
+                from transformers.cache_utils import DynamicCache, EncoderDecoderCache
+
+            @functools.wraps(patcher.orig_forward)
+            def patched_forward(*args, **kwargs):
+                signature = inspect.signature(patcher.orig_forward)
+                args, kwargs = override_arguments(args, kwargs, signature, model_kwargs=patcher.model_kwargs)
+
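+                # Dummy inputs carry past_key_values in the legacy tuple format; transformers >= 4.48
+                # works with Cache objects, so convert before calling the original forward
+                # (2 tensors per decoder layer -> DynamicCache, 4 -> EncoderDecoderCache).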
+                if is_transformers_version(">=", "4.48"):
+                    if "past_key_values" in signature.parameters:
+                        pkv_index = list(signature.parameters.keys()).index("past_key_values")
+
+                        if (
+                            pkv_index < len(args)  # pkv is in args
+                            and isinstance(args[pkv_index], (list, tuple))
+                            and isinstance(args[pkv_index][0], (list, tuple))
+                        ):
+                            if len(args[pkv_index][0]) == 2:
+                                args[pkv_index] = DynamicCache.from_legacy_cache(args[pkv_index])
+                            elif len(args[pkv_index][0]) == 4:
+                                args[pkv_index] = EncoderDecoderCache.from_legacy_cache(args[pkv_index])
+                            else:
+                                raise ValueError(
+                                    f"past_key_values should have either 2 or 4 elements, but it has {len(args[pkv_index][0])} elements"
+                                )
+                        elif (
+                            "past_key_values" in kwargs  # pkv is in kwargs
+                            and isinstance(kwargs["past_key_values"], (list, tuple))
+                            and isinstance(kwargs["past_key_values"][0], (list, tuple))
+                        ):
+                            if len(kwargs["past_key_values"][0]) == 2:
+                                kwargs["past_key_values"] = DynamicCache.from_legacy_cache(kwargs["past_key_values"])
+                            elif len(kwargs["past_key_values"][0]) == 4:
+                                kwargs["past_key_values"] = EncoderDecoderCache.from_legacy_cache(
+                                    kwargs["past_key_values"]
+                                )
+                            else:
+                                raise ValueError(
+                                    f"past_key_values should have either 2 or 4 elements, but it has {len(kwargs['past_key_values'][0])} elements"
+                                )
+
+                outputs = patcher.orig_forward(*args, **kwargs)
+
+                # This block normalizes the raw model output into the filterd_outputs dict with named entries.
+                # Output types vary between models: Transformers models return a ModelOutput object that
+                # carries the output names, Timm classification models return a plain tensor, and other
+                # models may return a list or tuple. By default, the output names declared in the ONNX
+                # config are assumed to match the outputs in order.
+                filterd_outputs = {}
+                if isinstance(outputs, dict):
+                    for name, value in outputs.items():
+                        filterd_outputs[name] = value
+                elif isinstance(outputs, (list, tuple)):
+                    outputs_list = list(config.outputs.keys())
+                    filterd_outputs = dict(zip(outputs_list, outputs))
+                else:
+                    if len(config.outputs) > 1:
+                        num_outputs = len(config.outputs)
+                        outputs_str = ", ".join(config.outputs.keys())
+                        raise ValueError(
+                            f"config.outputs should have only one output, but it has {num_outputs} keys: {outputs_str}"
+                        )
+                    else:
+                        name = list(config.outputs.keys())[0]
+                        filterd_outputs[name] = outputs
+                    name = list(config.outputs.keys())[0]
+                    filterd_outputs[name] = outputs
+
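+                # Convert Cache objects back to the legacy tuple format so the exported graph returns
+                # plain tensors for past_key_values.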
+                if is_transformers_version(">=", "4.48"):
+                    if isinstance(filterd_outputs.get("past_key_values"), (DynamicCache, EncoderDecoderCache)):
+                        filterd_outputs["past_key_values"] = outputs["past_key_values"].to_legacy_cache()
+
+                return filterd_outputs

             @functools.wraps(patched_forward)
             def ts_patched_forward(*args, **kwargs):
@@ -443,21 +522,43 @@ def ts_patched_forward(*args, **kwargs):
                 ts_decoder_kwargs["trace_kwargs"] = {"check_trace": False}

             with patcher:
-                if patch_16bit_model:
-                    from openvino.frontend.pytorch.patch_model import __make_16bit_traceable
-
-                    __make_16bit_traceable(model)
+                use_export = True
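+                # Hard-coded switch: True takes the torch.export path below, False keeps the original
+                # TorchScriptPythonDecoder tracing path.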
                 check_dummy_inputs_are_allowed(model, dummy_inputs)
                 input_info = _get_input_info(model, config, dummy_inputs)
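+                # input_info is only consumed by the TorchScript branch below (as the input= argument of convert_model).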
-                ts_decoder = TorchScriptPythonDecoder(model, example_input=dummy_inputs, **ts_decoder_kwargs)
-                ov_model = convert_model(
-                    ts_decoder,
-                    example_input=dummy_inputs,
-                    input=[(item.shape, item.type) for item in input_info],
-                )
+                if use_export:
+                    if hasattr(torch.ops, "_prepare_4d_causal_attention_mask_for_sdpa"):
+                        # patch_everywhere breaks torch.ops namespace
+                        del torch.ops._prepare_4d_causal_attention_mask_for_sdpa
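+                    # dynamic_shapes marks dimensions such as batch size and sequence length as symbolic,
+                    # so the exported program is not specialized to the dummy-input sizes.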
+                    dynamic_shapes = _get_dynamic_shapes_info(model, config, dummy_inputs)
+                    _export_kwargs = {"args": tuple(), "kwargs": _normalize_dummy_inputs(dummy_inputs, model.dtype)}
+                    _export_kwargs["dynamic_shapes"] = dynamic_shapes
+
+                    try:
+                        from nncf.torch.dynamic_graph.patch_pytorch import disable_patching
+                        # nncf patching breaks export
+                        with disable_patching():
+                            ep = torch.export.export_for_training(model, **_export_kwargs)
+                    except ImportError:
+                        ep = torch.export.export_for_training(model, **_export_kwargs)
+
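+                    # convert_model consumes the torch.export ExportedProgram directly; unlike the TorchScript
+                    # path, input shapes and types are taken from the exported graph itself.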
+                    ov_model = convert_model(ep)
+                else:
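+                    # Fallback: the original TorchScript-based tracing path.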
+                    if patch_16bit_model:
+                        from openvino.frontend.pytorch.patch_model import __make_16bit_traceable
+
+                        __make_16bit_traceable(model)
+
+                    ts_decoder = TorchScriptPythonDecoder(model, example_input=dummy_inputs, **ts_decoder_kwargs)
+                    ov_model = convert_model(
+                        ts_decoder,
+                        example_input=dummy_inputs,
+                        input=[(item.shape, item.type) for item in input_info],
+                    )

         except Exception as ex:
-            logger.warning(f"Export model to OpenVINO directly failed with: \n{ex}.\nModel will be exported to ONNX")
+            logger.warning("Export model to OpenVINO directly failed with: \n", exc_info=ex)
+            raise ex
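+            # Note: the re-raise above makes a failed direct OpenVINO conversion fatal, so the ONNX
+            # fallback warning below is never reached.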
+            logger.warning("\nModel will be exported to ONNX")

             if stateful:
                 # cannot raise because stateful is enabled by default and it would break backward compatibility for models that couldn't convert to OV directly