6565    MULTI_MODAL_TEXT_GENERATION_MODELS ,
6666    OV_XML_FILE_NAME ,
6767    _get_input_info ,
68+     _get_dynamic_shapes_info ,
69+     _normalize_dummy_inputs ,
6870    _get_open_clip_submodels_fn_and_export_configs ,
6971    allow_skip_tracing_check ,
7072    clear_class_registry ,
@@ -422,7 +424,84 @@ def export_pytorch(
422424            # To handle it, additional wrapper on patcher forward applied. 
423425            # model.config.torchscript = True can not be used for patching, because it overrides return_dict to False 
424426            patcher  =  config .patch_model_for_export (model , model_kwargs = model_kwargs )
425-             patched_forward  =  patcher .patched_forward 
427+             # Wrap the original forward so export sees normalized inputs/outputs
428+             import  inspect 
429+             from  optimum .exporters .onnx .model_patcher  import  override_arguments 
430+ 
431+             if  is_transformers_version (">=" , "4.48" ):
432+                 from  transformers .cache_utils  import  DynamicCache , EncoderDecoderCache 
433+ 
434+             @functools .wraps (patcher .orig_forward ) 
435+             def  patched_forward (* args , ** kwargs ):
436+                 signature  =  inspect .signature (patcher .orig_forward )
437+                 args , kwargs  =  override_arguments (args , kwargs , signature , model_kwargs = patcher .model_kwargs )
438+ 
439+                 if  is_transformers_version (">=" , "4.48" ):
440+                     if  "past_key_values"  in  signature .parameters :
441+                         pkv_index  =  list (signature .parameters .keys ()).index ("past_key_values" )
442+     
443+                         if  (
444+                             pkv_index  <  len (args )  # pkv is in args 
445+                             and  isinstance (args [pkv_index ], (list , tuple ))
446+                             and  isinstance (args [pkv_index ][0 ], (list , tuple ))
447+                         ):
448+                             if  len (args [pkv_index ][0 ]) ==  2 :
449+                                 args [pkv_index ] =  DynamicCache .from_legacy_cache (args [pkv_index ])
450+                             elif  len (args [pkv_index ][0 ]) ==  4 :
451+                                 args [pkv_index ] =  EncoderDecoderCache .from_legacy_cache (args [pkv_index ])
452+                             else :
453+                                 raise  ValueError (
454+                                     f"past_key_values should have either 2 or 4 elements, but it has { len (args [pkv_index ][0 ])}  
455+                                 )
456+                         elif  (
457+                             "past_key_values"  in  kwargs   # pkv is in kwargs 
458+                             and  isinstance (kwargs ["past_key_values" ], (list , tuple ))
459+                             and  isinstance (kwargs ["past_key_values" ][0 ], (list , tuple ))
460+                         ):
461+                             if  len (kwargs ["past_key_values" ][0 ]) ==  2 :
462+                                 kwargs ["past_key_values" ] =  DynamicCache .from_legacy_cache (kwargs ["past_key_values" ])
463+                             elif  len (kwargs ["past_key_values" ][0 ]) ==  4 :
464+                                 kwargs ["past_key_values" ] =  EncoderDecoderCache .from_legacy_cache (
465+                                     kwargs ["past_key_values" ]
466+                                 )
467+                             else :
468+                                 raise  ValueError (
469+                                     f"past_key_values should have either 2 or 4 elements, but it has { len (kwargs ['past_key_values' ][0 ])}  
470+                                 )
471+ 
472+                 outputs  =  patcher .orig_forward (* args , ** kwargs )
473+ 
474+                 # This code block handles different cases of the filterd_outputs input to align it with the expected 
475+                 # format of outputs. It is common for the output type of a model to vary, such as tensor, list, 
476+                 # tuple, etc. For Transformers models, the output is encapsulated in a ModelOutput object that 
477+                 # contains the output names of the model. In the case of Timm classification models, the output 
478+                 # is of type tensor. By default, it is assumed that the output names mentioned in the ONNX config 
479+                 # match the outputs in order. 
480+                 filterd_outputs  =  {}
481+                 if  isinstance (outputs , dict ):
482+                     for  name , value  in  outputs .items ():
483+                         filterd_outputs [name ] =  value 
484+                 elif  isinstance (outputs , (list , tuple )):
485+                     outputs_list  =  list (config .outputs .keys ())
486+                     filterd_outputs  =  dict (zip (outputs_list , outputs ))
487+                 else :
488+                     if  len (config .outputs ) >  1 :
489+                         num_outputs  =  len (config .outputs )
490+                         outputs_str  =  ", " .join (config .outputs .keys ())
491+                         raise  ValueError (
492+                             f"config.outputs should have only one outputs, but it has { num_outputs } { outputs_str }  
493+                         )
494+                     else :
495+                         name  =  list (config .outputs .keys ())[0 ]
496+                         filterd_outputs [name ] =  outputs 
497+                     name  =  list (config .outputs .keys ())[0 ]
498+                     filterd_outputs [name ] =  outputs 
499+ 
500+                 if  is_transformers_version (">=" , "4.48" ):
501+                     if  isinstance (filterd_outputs .get ("past_key_values" ), (DynamicCache , EncoderDecoderCache )):
502+                         filterd_outputs ["past_key_values" ] =  outputs ["past_key_values" ].to_legacy_cache ()
503+ 
504+                 return  filterd_outputs 
426505
427506            @functools .wraps (patched_forward ) 
428507            def  ts_patched_forward (* args , ** kwargs ):
@@ -443,21 +522,43 @@ def ts_patched_forward(*args, **kwargs):
443522                ts_decoder_kwargs ["trace_kwargs" ] =  {"check_trace" : False }
444523
445524            with  patcher :
446-                 if  patch_16bit_model :
447-                     from  openvino .frontend .pytorch .patch_model  import  __make_16bit_traceable 
448- 
449-                     __make_16bit_traceable (model )
525+                 use_export  =  True 
450526                check_dummy_inputs_are_allowed (model , dummy_inputs )
451527                input_info  =  _get_input_info (model , config , dummy_inputs )
452-                 ts_decoder  =  TorchScriptPythonDecoder (model , example_input = dummy_inputs , ** ts_decoder_kwargs )
453-                 ov_model  =  convert_model (
454-                     ts_decoder ,
455-                     example_input = dummy_inputs ,
456-                     input = [(item .shape , item .type ) for  item  in  input_info ],
457-                 )
528+                 if  use_export :
529+                     if  hasattr (torch .ops , "_prepare_4d_causal_attention_mask_for_sdpa" ):
530+                         # patch_everywhere breaks torch.ops namespace 
531+                         del  torch .ops ._prepare_4d_causal_attention_mask_for_sdpa 
532+                     dynamic_shapes  =  _get_dynamic_shapes_info (model , config , dummy_inputs )
533+                     _export_kwargs  =  {"args" : tuple (), "kwargs" : _normalize_dummy_inputs (dummy_inputs , model .dtype )}
534+                     _export_kwargs ["dynamic_shapes" ] =  dynamic_shapes 
535+ 
536+                     try :
537+                         from  nncf .torch .dynamic_graph .patch_pytorch  import  disable_patching 
538+                         # nncf patching breaks export 
539+                         with  disable_patching ():
540+                             ep  =  torch .export .export_for_training (model , ** _export_kwargs )
541+                     except  ImportError :
542+                         ep  =  torch .export .export_for_training (model , ** _export_kwargs )
543+ 
544+                     ov_model  =  convert_model (ep )
545+                 else :
546+                     if  patch_16bit_model :
547+                         from  openvino .frontend .pytorch .patch_model  import  __make_16bit_traceable 
548+ 
549+                         __make_16bit_traceable (model )
550+ 
551+                     ts_decoder  =  TorchScriptPythonDecoder (model , example_input = dummy_inputs , ** ts_decoder_kwargs )
552+                     ov_model  =  convert_model (
553+                         ts_decoder ,
554+                         example_input = dummy_inputs ,
555+                         input = [(item .shape , item .type ) for  item  in  input_info ],
556+                     )
458557
459558        except  Exception  as  ex :
460-             logger .warning (f"Export model to OpenVINO directly failed with: \n { ex } \n Model will be exported to ONNX" )
559+             logger .warning (f"Export model to OpenVINO directly failed with: \n " , exc_info = ex )
560+             raise  ex 
561+             logger .warning ("\n Model will be exported to ONNX" )
461562
462563            if  stateful :
463564                # cannot raise because stateful is enabled by default and it would break backward compatibility for models that couldn't convert to OV directly 
0 commit comments