6767    MULTI_MODAL_TEXT_GENERATION_MODELS ,
6868    OV_XML_FILE_NAME ,
6969    _get_input_info ,
70-     _get_dynamic_shapes_info ,
71-     _normalize_dummy_inputs ,
7270    _get_open_clip_submodels_fn_and_export_configs ,
73-     get_model_dtype ,
7471    allow_skip_tracing_check ,
7572    clear_class_registry ,
7673    remove_none_from_dummy_inputs ,
@@ -428,7 +425,6 @@ def export_pytorch(
428425        patched_forward  =  patcher .patched_forward 
429426        dummy_input_keys  =  list (dummy_inputs .keys ())
430427
431- < << << <<  HEAD 
432428        @functools .wraps (patched_forward ) 
433429        def  ts_patched_forward (* args , ** kwargs ):
434430            ordered_example_inputs  =  [
@@ -446,158 +442,14 @@ def ts_patched_forward(*args, **kwargs):
446442                kwargs [input_name ] =  input_dict 
447443            outputs  =  patched_forward (** kwargs )
448444            return  tuple ([value  if  not  isinstance (value , list ) else  tuple (value ) for  value  in  outputs .values ()])
449- == == == = 
450-         try :
451-             # TorchScript used behind OpenVINO conversion. Optimum supports only return_dict=True models for patching, 
452-             # while TorchScript do not support dictionary with values of mixed types (e.g. Tensor and None) in model input/output 
453-             # To handle it, additional wrapper on patcher forward applied. 
454-             # model.config.torchscript = True can not be used for patching, because it overrides return_dict to False 
455-             patcher  =  config .patch_model_for_export (model , model_kwargs = model_kwargs )
456-             #patched_forward = patcher.orig_forward 
457-             import  inspect 
458-             from  optimum .exporters .onnx .model_patcher  import  override_arguments 
459- 
460-             if  is_transformers_version (">=" , "4.48" ):
461-                 from  transformers .cache_utils  import  DynamicCache , EncoderDecoderCache 
462- 
463-             @functools .wraps (patcher .orig_forward ) 
464-             def  patched_forward (* args , ** kwargs ):
465-                 signature  =  inspect .signature (patcher .orig_forward )
466-                 args , kwargs  =  override_arguments (args , kwargs , signature , model_kwargs = patcher .model_kwargs )
467- 
468-                 if  is_transformers_version (">=" , "4.48" ):
469-                     if  "past_key_values"  in  signature .parameters :
470-                         pkv_index  =  list (signature .parameters .keys ()).index ("past_key_values" )
471-     
472-                         if  (
473-                             pkv_index  <  len (args )  # pkv is in args 
474-                             and  isinstance (args [pkv_index ], (list , tuple ))
475-                             and  isinstance (args [pkv_index ][0 ], (list , tuple ))
476-                         ):
477-                             if  len (args [pkv_index ][0 ]) ==  2 :
478-                                 args [pkv_index ] =  DynamicCache .from_legacy_cache (args [pkv_index ])
479-                             elif  len (args [pkv_index ][0 ]) ==  4 :
480-                                 args [pkv_index ] =  EncoderDecoderCache .from_legacy_cache (args [pkv_index ])
481-                             else :
482-                                 raise  ValueError (
483-                                     f"past_key_values should have either 2 or 4 elements, but it has { len (args [pkv_index ][0 ])}  
484-                                 )
485-                         elif  (
486-                             "past_key_values"  in  kwargs   # pkv is in kwargs 
487-                             and  isinstance (kwargs ["past_key_values" ], (list , tuple ))
488-                             and  isinstance (kwargs ["past_key_values" ][0 ], (list , tuple ))
489-                         ):
490-                             if  len (kwargs ["past_key_values" ][0 ]) ==  2 :
491-                                 kwargs ["past_key_values" ] =  DynamicCache .from_legacy_cache (kwargs ["past_key_values" ])
492-                             elif  len (kwargs ["past_key_values" ][0 ]) ==  4 :
493-                                 kwargs ["past_key_values" ] =  EncoderDecoderCache .from_legacy_cache (
494-                                     kwargs ["past_key_values" ]
495-                                 )
496-                             else :
497-                                 raise  ValueError (
498-                                     f"past_key_values should have either 2 or 4 elements, but it has { len (kwargs ['past_key_values' ][0 ])}  
499-                                 )
500- 
501-                 outputs  =  patcher .orig_forward (* args , ** kwargs )
502- 
503-                 # This code block handles different cases of the filterd_outputs input to align it with the expected 
504-                 # format of outputs. It is common for the output type of a model to vary, such as tensor, list, 
505-                 # tuple, etc. For Transformers models, the output is encapsulated in a ModelOutput object that 
506-                 # contains the output names of the model. In the case of Timm classification models, the output 
507-                 # is of type tensor. By default, it is assumed that the output names mentioned in the ONNX config 
508-                 # match the outputs in order. 
509-                 filterd_outputs  =  {}
510-                 if  isinstance (outputs , dict ):
511-                     for  name , value  in  outputs .items ():
512-                         filterd_outputs [name ] =  value 
513-                 elif  isinstance (outputs , (list , tuple )):
514-                     outputs_list  =  list (config .outputs .keys ())
515-                     filterd_outputs  =  dict (zip (outputs_list , outputs ))
516-                 else :
517-                     if  len (config .outputs ) >  1 :
518-                         num_outputs  =  len (config .outputs )
519-                         outputs_str  =  ", " .join (config .outputs .keys ())
520-                         raise  ValueError (
521-                             f"config.outputs should have only one outputs, but it has { num_outputs } { outputs_str }  
522-                         )
523-                     else :
524-                         name  =  list (config .outputs .keys ())[0 ]
525-                         filterd_outputs [name ] =  outputs 
526-                     name  =  list (config .outputs .keys ())[0 ]
527-                     filterd_outputs [name ] =  outputs 
528- 
529-                 if  is_transformers_version (">=" , "4.48" ):
530-                     if  isinstance (filterd_outputs .get ("past_key_values" ), (DynamicCache , EncoderDecoderCache )):
531-                         filterd_outputs ["past_key_values" ] =  outputs ["past_key_values" ].to_legacy_cache ()
532- 
533-                 return  filterd_outputs 
534- >> >> >> >  cfde44f  ([POC ] Use  torch .export  for  converting )
535445
536446        patcher .patched_forward  =  ts_patched_forward 
537447
538- < << << <<  HEAD 
539448        ts_decoder_kwargs  =  {}
540449        model_config  =  getattr (model , "config" , {})
541450        model_type  =  getattr (model_config , "model_type" , "" ).replace ("_" , "-" )
542451        if  allow_skip_tracing_check (library_name , model_type ):
543452            ts_decoder_kwargs ["trace_kwargs" ] =  {"check_trace" : False }
544- == == == = 
545-             patcher .patched_forward  =  ts_patched_forward 
546- 
547-             ts_decoder_kwargs  =  {}
548-             model_config  =  getattr (model , "config" , {})
549-             model_type  =  getattr (model_config , "model_type" , "" ).replace ("_" , "-" )
550-             if  allow_skip_tracing_check (library_name , model_type ):
551-                 ts_decoder_kwargs ["trace_kwargs" ] =  {"check_trace" : False }
552- 
553-             with  patcher :
554-                 use_export  =  True 
555-                 check_dummy_inputs_are_allowed (model , dummy_inputs )
556-                 input_info  =  _get_input_info (model , config , dummy_inputs )
557-                 if  use_export :
558-                     if  hasattr (torch .ops , "_prepare_4d_causal_attention_mask_for_sdpa" ):
559-                         # patch_everywhere breaks torch.ops namespace 
560-                         del  torch .ops ._prepare_4d_causal_attention_mask_for_sdpa 
561-                     dynamic_shapes  =  _get_dynamic_shapes_info (model , config , dummy_inputs )
562-                     _export_kwargs  =  {"args" : tuple (), "kwargs" : _normalize_dummy_inputs (dummy_inputs , get_model_dtype (model ))}
563-                     _export_kwargs ["dynamic_shapes" ] =  dynamic_shapes 
564- 
565-                     try :
566-                         from  nncf .torch .dynamic_graph .patch_pytorch  import  disable_patching 
567-                         # nncf patching breaks export 
568-                         with  disable_patching ():
569-                             ep  =  torch .export .export_for_training (model , ** _export_kwargs )
570-                     except  ImportError :
571-                         ep  =  torch .export .export_for_training (model , ** _export_kwargs )
572- 
573-                     ov_model  =  convert_model (ep )
574-                 else :
575-                     if  patch_16bit_model :
576-                         from  openvino .frontend .pytorch .patch_model  import  __make_16bit_traceable 
577- 
578-                         __make_16bit_traceable (model )
579- 
580-                     ts_decoder  =  TorchScriptPythonDecoder (model , example_input = dummy_inputs , ** ts_decoder_kwargs )
581-                     ov_model  =  convert_model (
582-                         ts_decoder ,
583-                         example_input = dummy_inputs ,
584-                         input = [(item .shape , item .type ) for  item  in  input_info ],
585-                     )
586- 
587-         except  Exception  as  ex :
588-             logger .warning (f"Export model to OpenVINO directly failed with: \n " , exc_info = ex )
589-             raise  ex 
590-             logger .warning ("\n Model will be exported to ONNX" )
591- 
592-             if  stateful :
593-                 # cannot raise because stateful is enabled by default and it would break backward compatibility for models that couldn't convert to OV directly 
594-                 # TODO: Implement stateful for ONNX path as well, not doing it right now because of lack of validation 
595-                 logger .warning (
596-                     "[ WARNING ] Making stateful models is not supported when exporting to ONNX as an intermediate step. " 
597-                     "A stateless model will be exported instead. It may result in sub-optimal inference performance." 
598-                     "Provide a model that can be converted to OpenVINO without fallback to ONNX conversion path." 
599-                 )
600- >> >> >> >  cfde44f  ([POC ] Use  torch .export  for  converting )
601453
602454        with  patcher :
603455            if  patch_16bit_model :
0 commit comments