@@ -417,83 +417,52 @@ def export_pytorch(
417417
418418 dummy_inputs = config .rename_ambiguous_inputs (dummy_inputs )
419419 dummy_inputs , dict_inputs = remove_none_from_dummy_inputs (dummy_inputs )
420-
421- try :
422- # TorchScript used behind OpenVINO conversion. Optimum supports only return_dict=True models for patching,
423- # while TorchScript do not support dictionary with values of mixed types (e.g. Tensor and None) in model input/output
424- # To handle it, additional wrapper on patcher forward applied.
425- # model.config.torchscript = True can not be used for patching, because it overrides return_dict to False
426- patcher = config .patch_model_for_export (model , model_kwargs = model_kwargs )
427- patched_forward = patcher .patched_forward
428- dummy_input_keys = list (dummy_inputs .keys ())
429-
430- @functools .wraps (patched_forward )
431- def ts_patched_forward (* args , ** kwargs ):
432- ordered_example_inputs = [
433- param for param in inspect .signature (patcher .orig_forward ).parameters if param in dummy_input_keys
434- ]
435- kwargs .update (zip (ordered_example_inputs , args ))
436- for i in range (len (dict_inputs )):
437- input_name , keys = dict_inputs [i ]
438- tuple_input = kwargs [input_name ]
439- input_dict = dict (zip (keys , tuple_input ))
440- kwargs [input_name ] = input_dict
441- outputs = patched_forward (** kwargs )
442- return tuple ([value if not isinstance (value , list ) else tuple (value ) for value in outputs .values ()])
443-
444- patcher .patched_forward = ts_patched_forward
445-
446- ts_decoder_kwargs = {}
447- model_config = getattr (model , "config" , {})
448- model_type = getattr (model_config , "model_type" , "" ).replace ("_" , "-" )
449- if allow_skip_tracing_check (library_name , model_type ):
450- ts_decoder_kwargs ["trace_kwargs" ] = {"check_trace" : False }
451-
452- with patcher :
453- if patch_16bit_model :
454- from openvino .frontend .pytorch .patch_model import __make_16bit_traceable
455-
456- __make_16bit_traceable (model )
457- check_dummy_inputs_are_allowed (model , dummy_inputs )
458- input_info = _get_input_info (model , config , dummy_inputs )
459- ts_decoder = TorchScriptPythonDecoder (model , example_input = dummy_inputs , ** ts_decoder_kwargs )
460- ov_model = convert_model (
461- ts_decoder ,
462- example_input = dummy_inputs ,
463- input = [(item .shape , item .type ) for item in input_info ],
464- )
465- except Exception as ex :
466- logger .warning (f"Export model to OpenVINO directly failed with: \n { ex } .\n Model will be exported to ONNX" )
467-
468- if stateful :
469- # cannot raise because stateful is enabled by default and it would break backward compatibility for models that couldn't convert to OV directly
470- # TODO: Implement stateful for ONNX path as well, not doing it right now because of lack of validation
471- logger .warning (
472- "[ WARNING ] Making stateful models is not supported when exporting to ONNX as an intermediate step. "
473- "A stateless model will be exported instead. It may result in sub-optimal inference performance."
474- "Provide a model that can be converted to OpenVINO without fallback to ONNX conversion path."
475- )
476-
420+ # TorchScript used behind OpenVINO conversion. Optimum supports only return_dict=True models for patching,
421+ # while TorchScript do not support dictionary with values of mixed types (e.g. Tensor and None) in model input/output
422+ # To handle it, additional wrapper on patcher forward applied.
423+ # model.config.torchscript = True can not be used for patching, because it overrides return_dict to False
424+ patcher = config .patch_model_for_export (model , model_kwargs = model_kwargs )
425+ patched_forward = patcher .patched_forward
426+ dummy_input_keys = list (dummy_inputs .keys ())
427+
428+ @functools .wraps (patched_forward )
429+ def ts_patched_forward (* args , ** kwargs ):
430+ ordered_example_inputs = [
431+ param
432+ for param in inspect .signature (
433+ patcher .orig_forward if library_name != "sentence_transformers" else patcher .patched_forward
434+ ).parameters
435+ if param in dummy_input_keys
436+ ]
437+ kwargs .update (zip (ordered_example_inputs , args ))
438+ for i in range (len (dict_inputs )):
439+ input_name , keys = dict_inputs [i ]
440+ tuple_input = kwargs [input_name ]
441+ input_dict = dict (zip (keys , tuple_input ))
442+ kwargs [input_name ] = input_dict
443+ outputs = patched_forward (** kwargs )
444+ return tuple ([value if not isinstance (value , list ) else tuple (value ) for value in outputs .values ()])
445+
446+ patcher .patched_forward = ts_patched_forward
447+
448+ ts_decoder_kwargs = {}
449+ model_config = getattr (model , "config" , {})
450+ model_type = getattr (model_config , "model_type" , "" ).replace ("_" , "-" )
451+ if allow_skip_tracing_check (library_name , model_type ):
452+ ts_decoder_kwargs ["trace_kwargs" ] = {"check_trace" : False }
453+
454+ with patcher :
477455 if patch_16bit_model :
478- from openvino .frontend .pytorch .patch_model import unpatch_model
479-
480- unpatch_model (model , "_openvino_module_extension_patch_orig_forward" )
481- for m in model .modules ():
482- if any (p .dtype in [torch .float16 , torch .bfloat16 ] for p in m .parameters (False )) or any (
483- b .dtype in [torch .float16 , torch .bfloat16 ] for b in m .buffers (False )
484- ):
485- m .float ()
486-
487- return export_pytorch_via_onnx (
488- model ,
489- config ,
490- opset ,
491- output ,
492- device ,
493- input_shapes ,
494- model_kwargs ,
495- ov_config = ov_config ,
496- library_name = library_name ,
456+ from openvino .frontend .pytorch .patch_model import __make_16bit_traceable
457+
458+ __make_16bit_traceable (model )
459+ check_dummy_inputs_are_allowed (model , dummy_inputs )
460+ input_info = _get_input_info (model , config , dummy_inputs )
461+ ts_decoder = TorchScriptPythonDecoder (model , example_input = dummy_inputs , ** ts_decoder_kwargs )
462+ ov_model = convert_model (
463+ ts_decoder ,
464+ example_input = dummy_inputs ,
465+ input = [(item .shape , item .type ) for item in input_info ],
497466 )
498467
499468 ov_model .validate_nodes_and_infer_types () # TODO: remove as unnecessary validation?
0 commit comments