     MULTI_MODAL_TEXT_GENERATION_MODELS,
     OV_XML_FILE_NAME,
     _get_input_info,
+    _get_dynamic_shapes_info,
+    _normalize_dummy_inputs,
     _get_open_clip_submodels_fn_and_export_configs,
     allow_skip_tracing_check,
     clear_class_registry,
@@ -425,6 +427,7 @@ def export_pytorch(
         patched_forward = patcher.patched_forward
         dummy_input_keys = list(dummy_inputs.keys())
 
+<<<<<<< HEAD
         @functools.wraps(patched_forward)
         def ts_patched_forward(*args, **kwargs):
             ordered_example_inputs = [
@@ -442,14 +445,158 @@ def ts_patched_forward(*args, **kwargs):
                 kwargs[input_name] = input_dict
             outputs = patched_forward(**kwargs)
             return tuple([value if not isinstance(value, list) else tuple(value) for value in outputs.values()])
+=======
+        try:
+            # TorchScript is used behind the OpenVINO conversion. Optimum supports patching only for
+            # return_dict=True models, while TorchScript does not support dictionaries with values of mixed
+            # types (e.g. Tensor and None) in model inputs/outputs. To handle this, an additional wrapper is
+            # applied to the patcher forward.
+            # model.config.torchscript = True cannot be used for patching, because it overrides return_dict to False.
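+            # Illustrative example (not from this PR): a dummy input such as
+            #   {"attention_mask": torch.ones(1, 8), "past_key_values": None}
+            # cannot be passed through TorchScript tracing as-is, because the dict mixes Tensor and None values.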
+            patcher = config.patch_model_for_export(model, model_kwargs=model_kwargs)
+
+            import inspect
+
+            from optimum.exporters.onnx.model_patcher import override_arguments
+
+            if is_transformers_version(">=", "4.48"):
+                from transformers.cache_utils import DynamicCache, EncoderDecoderCache
+
+            @functools.wraps(patcher.orig_forward)
+            def patched_forward(*args, **kwargs):
+                signature = inspect.signature(patcher.orig_forward)
+                args, kwargs = override_arguments(args, kwargs, signature, model_kwargs=patcher.model_kwargs)
+
+                if is_transformers_version(">=", "4.48"):
+                    if "past_key_values" in signature.parameters:
+                        pkv_index = list(signature.parameters.keys()).index("past_key_values")
+
+                        if (
+                            pkv_index < len(args)  # pkv is in args
+                            and isinstance(args[pkv_index], (list, tuple))
+                            and isinstance(args[pkv_index][0], (list, tuple))
+                        ):
+                            if len(args[pkv_index][0]) == 2:
+                                args[pkv_index] = DynamicCache.from_legacy_cache(args[pkv_index])
+                            elif len(args[pkv_index][0]) == 4:
+                                args[pkv_index] = EncoderDecoderCache.from_legacy_cache(args[pkv_index])
+                            else:
+                                raise ValueError(
+                                    f"past_key_values should have either 2 or 4 elements per layer, but it has {len(args[pkv_index][0])} elements"
+                                )
+                        elif (
+                            "past_key_values" in kwargs  # pkv is in kwargs
+                            and isinstance(kwargs["past_key_values"], (list, tuple))
+                            and isinstance(kwargs["past_key_values"][0], (list, tuple))
+                        ):
+                            if len(kwargs["past_key_values"][0]) == 2:
+                                kwargs["past_key_values"] = DynamicCache.from_legacy_cache(kwargs["past_key_values"])
+                            elif len(kwargs["past_key_values"][0]) == 4:
+                                kwargs["past_key_values"] = EncoderDecoderCache.from_legacy_cache(
+                                    kwargs["past_key_values"]
+                                )
+                            else:
+                                raise ValueError(
+                                    f"past_key_values should have either 2 or 4 elements per layer, but it has {len(kwargs['past_key_values'][0])} elements"
+                                )
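+                        # Assumed legacy pkv layout (a sketch): a tuple with one entry per layer, where each
+                        # entry is (key, value) for decoder-only models -> DynamicCache, or
+                        # (self_attn_key, self_attn_value, cross_attn_key, cross_attn_value) for
+                        # encoder-decoder models -> EncoderDecoderCache.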
+
+                outputs = patcher.orig_forward(*args, **kwargs)
+
+                # This block aligns filterd_outputs with the expected output format. The output type of a model
+                # can vary: tensor, list, tuple, etc. For Transformers models, the output is encapsulated in a
+                # ModelOutput object that carries the model's output names. For Timm classification models, the
+                # output is a plain tensor. By default, it is assumed that the output names in the export config
+                # match the outputs in order.
+                filterd_outputs = {}
+                if isinstance(outputs, dict):
+                    for name, value in outputs.items():
+                        filterd_outputs[name] = value
+                elif isinstance(outputs, (list, tuple)):
+                    outputs_list = list(config.outputs.keys())
+                    filterd_outputs = dict(zip(outputs_list, outputs))
+                else:
+                    if len(config.outputs) > 1:
+                        num_outputs = len(config.outputs)
+                        outputs_str = ", ".join(config.outputs.keys())
+                        raise ValueError(
+                            f"config.outputs should have only one output, but it has {num_outputs} keys: {outputs_str}"
+                        )
+                    name = list(config.outputs.keys())[0]
+                    filterd_outputs[name] = outputs
+
+                if is_transformers_version(">=", "4.48"):
+                    if isinstance(filterd_outputs.get("past_key_values"), (DynamicCache, EncoderDecoderCache)):
+                        filterd_outputs["past_key_values"] = outputs["past_key_values"].to_legacy_cache()
+
+                return filterd_outputs
+>>>>>>> cfde44f ([POC] Use torch.export for converting)
 
         patcher.patched_forward = ts_patched_forward
 
+<<<<<<< HEAD
         ts_decoder_kwargs = {}
         model_config = getattr(model, "config", {})
         model_type = getattr(model_config, "model_type", "").replace("_", "-")
         if allow_skip_tracing_check(library_name, model_type):
             ts_decoder_kwargs["trace_kwargs"] = {"check_trace": False}
+=======
+            patcher.patched_forward = ts_patched_forward
+
+            ts_decoder_kwargs = {}
+            model_config = getattr(model, "config", {})
+            model_type = getattr(model_config, "model_type", "").replace("_", "-")
+            if allow_skip_tracing_check(library_name, model_type):
+                ts_decoder_kwargs["trace_kwargs"] = {"check_trace": False}
+
+            with patcher:
+                use_export = True
+                check_dummy_inputs_are_allowed(model, dummy_inputs)
+                input_info = _get_input_info(model, config, dummy_inputs)
+                if use_export:
+                    if hasattr(torch.ops, "_prepare_4d_causal_attention_mask_for_sdpa"):
+                        # patch_everywhere breaks the torch.ops namespace
+                        del torch.ops._prepare_4d_causal_attention_mask_for_sdpa
+                    dynamic_shapes = _get_dynamic_shapes_info(model, config, dummy_inputs)
+                    _export_kwargs = {
+                        "args": tuple(),
+                        "kwargs": _normalize_dummy_inputs(dummy_inputs, model.dtype),
+                        "dynamic_shapes": dynamic_shapes,
+                    }
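+                    # dynamic_shapes maps each kwarg name to {dim_index: Dim}; an illustrative (assumed)
+                    # example for a decoder-only LM:
+                    #   batch, seq = torch.export.Dim("batch"), torch.export.Dim("sequence")
+                    #   dynamic_shapes = {"input_ids": {0: batch, 1: seq}, "attention_mask": {0: batch, 1: seq}}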
+
+                    try:
+                        # NNCF patching breaks torch.export, so disable it during export if NNCF is installed
+                        from nncf.torch.dynamic_graph.patch_pytorch import disable_patching
+
+                        with disable_patching():
+                            ep = torch.export.export_for_training(model, **_export_kwargs)
+                    except ImportError:
+                        ep = torch.export.export_for_training(model, **_export_kwargs)
+
+                    ov_model = convert_model(ep)
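+                    # ep is a torch.export ExportedProgram; recent OpenVINO releases accept it directly in
+                    # openvino.convert_model (an assumption based on OpenVINO's torch.export/torch.fx frontend).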
+                else:
+                    if patch_16bit_model:
+                        from openvino.frontend.pytorch.patch_model import __make_16bit_traceable
+
+                        __make_16bit_traceable(model)
+
+                    ts_decoder = TorchScriptPythonDecoder(model, example_input=dummy_inputs, **ts_decoder_kwargs)
+                    ov_model = convert_model(
+                        ts_decoder,
+                        example_input=dummy_inputs,
+                        input=[(item.shape, item.type) for item in input_info],
+                    )
+
+        except Exception as ex:
+            logger.warning("Export model to OpenVINO directly failed with:\n", exc_info=ex)
+            raise ex
+            # NOTE: the ONNX fallback below is unreachable while the POC re-raises above
+            logger.warning("\nModel will be exported to ONNX")
+
+            if stateful:
+                # cannot raise because stateful is enabled by default, and doing so would break backward
+                # compatibility for models that cannot be converted to OpenVINO directly
+                # TODO: implement stateful for the ONNX path as well; not done yet due to lack of validation
+                logger.warning(
+                    "[ WARNING ] Making stateful models is not supported when exporting to ONNX as an intermediate step. "
+                    "A stateless model will be exported instead. It may result in sub-optimal inference performance. "
+                    "Provide a model that can be converted to OpenVINO without fallback to the ONNX conversion path."
+                )
+>>>>>>> cfde44f ([POC] Use torch.export for converting)
 
         with patcher:
             if patch_16bit_model: