diff --git a/optimum/exporters/openvino/convert.py b/optimum/exporters/openvino/convert.py
index 6df37c2d89..cc66706f7f 100644
--- a/optimum/exporters/openvino/convert.py
+++ b/optimum/exporters/openvino/convert.py
@@ -67,7 +67,10 @@
     MULTI_MODAL_TEXT_GENERATION_MODELS,
     OV_XML_FILE_NAME,
+    _get_dynamic_shapes_info,
     _get_input_info,
     _get_open_clip_submodels_fn_and_export_configs,
+    _normalize_dummy_inputs,
     allow_skip_tracing_check,
     clear_class_registry,
+    get_model_dtype,
     remove_none_from_dummy_inputs,
@@ -425,9 +428,86 @@ def export_pytorch(
         # To handle it, additional wrapper on patcher forward applied.
         # model.config.torchscript = True can not be used for patching, because it overrides return_dict to False
         patcher = config.patch_model_for_export(model, model_kwargs=model_kwargs)
-        patched_forward = patcher.patched_forward
+        import inspect
+
+        from optimum.exporters.onnx.model_patcher import override_arguments
+
         dummy_input_keys = list(dummy_inputs.keys())
+
+        if is_transformers_version(">=", "4.48"):
+            from transformers.cache_utils import DynamicCache, EncoderDecoderCache
+
+        @functools.wraps(patcher.orig_forward)
+        def patched_forward(*args, **kwargs):
+            signature = inspect.signature(patcher.orig_forward)
+            args, kwargs = override_arguments(args, kwargs, signature, model_kwargs=patcher.model_kwargs)
+
+            if is_transformers_version(">=", "4.48") and "past_key_values" in signature.parameters:
+                pkv_index = list(signature.parameters.keys()).index("past_key_values")
+
+                if (
+                    pkv_index < len(args)  # pkv is in args
+                    and isinstance(args[pkv_index], (list, tuple))
+                    and isinstance(args[pkv_index][0], (list, tuple))
+                ):
+                    if len(args[pkv_index][0]) == 2:
+                        args[pkv_index] = DynamicCache.from_legacy_cache(args[pkv_index])
+                    elif len(args[pkv_index][0]) == 4:
+                        args[pkv_index] = EncoderDecoderCache.from_legacy_cache(args[pkv_index])
+                    else:
+                        raise ValueError(
+                            f"past_key_values should have either 2 or 4 elements, but it has {len(args[pkv_index][0])} elements"
+                        )
+                elif (
+                    "past_key_values" in kwargs  # pkv is in kwargs
+                    and isinstance(kwargs["past_key_values"], (list, tuple))
+                    and isinstance(kwargs["past_key_values"][0], (list, tuple))
+                ):
+                    if len(kwargs["past_key_values"][0]) == 2:
+                        kwargs["past_key_values"] = DynamicCache.from_legacy_cache(kwargs["past_key_values"])
+                    elif len(kwargs["past_key_values"][0]) == 4:
+                        kwargs["past_key_values"] = EncoderDecoderCache.from_legacy_cache(
+                            kwargs["past_key_values"]
+                        )
+                    else:
+                        raise ValueError(
+                            f"past_key_values should have either 2 or 4 elements, but it has {len(kwargs['past_key_values'][0])} elements"
+                        )
+
+            outputs = patcher.orig_forward(*args, **kwargs)
+
+            # Normalize the model outputs into a dict keyed by the output names from the export
+            # config. Output types vary across models: Transformers models return a ModelOutput
+            # that already carries output names, Timm classification models return a bare tensor,
+            # and some models return a list or tuple. For unnamed outputs it is assumed that the
+            # names in the export config match the outputs in order.
+            filtered_outputs = {}
+            if isinstance(outputs, dict):
+                for name, value in outputs.items():
+                    filtered_outputs[name] = value
+            elif isinstance(outputs, (list, tuple)):
+                outputs_list = list(config.outputs.keys())
+                filtered_outputs = dict(zip(outputs_list, outputs))
+            else:
+                if len(config.outputs) > 1:
+                    num_outputs = len(config.outputs)
+                    outputs_str = ", ".join(config.outputs.keys())
+                    raise ValueError(
+                        f"config.outputs should have only one output, but it has {num_outputs} keys: {outputs_str}"
+                    )
+                else:
+                    name = list(config.outputs.keys())[0]
+                    filtered_outputs[name] = outputs
+
+            if is_transformers_version(">=", "4.48"):
+                if isinstance(filtered_outputs.get("past_key_values"), (DynamicCache, EncoderDecoderCache)):
+                    filtered_outputs["past_key_values"] = filtered_outputs["past_key_values"].to_legacy_cache()
+
+            return filtered_outputs
+
         @functools.wraps(patched_forward)
         def ts_patched_forward(*args, **kwargs):
             ordered_example_inputs = [
@@ -455,18 +535,38 @@ def ts_patched_forward(*args, **kwargs):
             ts_decoder_kwargs["trace_kwargs"] = {"check_trace": False}
 
         with patcher:
-            if patch_16bit_model:
-                from openvino.frontend.pytorch.patch_model import __make_16bit_traceable
-
-                __make_16bit_traceable(model)
+            use_export = True  # TODO: expose as an option instead of hardcoding the torch.export path
             check_dummy_inputs_are_allowed(model, dummy_inputs)
-            input_info = _get_input_info(model, config, dummy_inputs)
-            ts_decoder = TorchScriptPythonDecoder(model, example_input=dummy_inputs, **ts_decoder_kwargs)
-            ov_model = convert_model(
-                ts_decoder,
-                example_input=dummy_inputs,
-                input=[(item.shape, item.type) for item in input_info],
-            )
+            if use_export:
+                if hasattr(torch.ops, "_prepare_4d_causal_attention_mask_for_sdpa"):
+                    # patch_everywhere breaks the torch.ops namespace
+                    del torch.ops._prepare_4d_causal_attention_mask_for_sdpa
+                dynamic_shapes = _get_dynamic_shapes_info(model, config, dummy_inputs)
+                _export_kwargs = {
+                    "args": (),
+                    "kwargs": _normalize_dummy_inputs(dummy_inputs, get_model_dtype(model)),
+                    "dynamic_shapes": dynamic_shapes,
+                }
+
+                try:
+                    from nncf.torch.dynamic_graph.patch_pytorch import disable_patching
+
+                    # nncf patching breaks torch.export
+                    with disable_patching():
+                        ep = torch.export.export_for_training(model, **_export_kwargs)
+                except ImportError:
+                    ep = torch.export.export_for_training(model, **_export_kwargs)
+
+                ov_model = convert_model(ep)
+            else:
+                if patch_16bit_model:
+                    from openvino.frontend.pytorch.patch_model import __make_16bit_traceable
+
+                    __make_16bit_traceable(model)
+
+                input_info = _get_input_info(model, config, dummy_inputs)
+                ts_decoder = TorchScriptPythonDecoder(model, example_input=dummy_inputs, **ts_decoder_kwargs)
+                ov_model = convert_model(
+                    ts_decoder,
+                    example_input=dummy_inputs,
+                    input=[(item.shape, item.type) for item in input_info],
+                )
 
         ov_model.validate_nodes_and_infer_types()  # TODO: remove as unnecessary validation?
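Note: the patched forward above bridges two past_key_values layouts: the legacy tuple-of-tuples format used by the dummy inputs, and the Cache objects that transformers >= 4.48 expects. A minimal standalone sketch of that round-trip, assuming only transformers with DynamicCache available (tensor shapes here are made up for illustration, not taken from the patch):

import torch
from transformers.cache_utils import DynamicCache

# Legacy layout: one (key, value) pair per layer, each shaped [batch, heads, seq, head_dim].
legacy_pkv = tuple((torch.zeros(1, 4, 8, 32), torch.zeros(1, 4, 8, 32)) for _ in range(2))

cache = DynamicCache.from_legacy_cache(legacy_pkv)  # what patched_forward feeds the model
assert isinstance(cache, DynamicCache)

roundtrip = cache.to_legacy_cache()  # what patched_forward converts outputs back to
assert len(roundtrip) == 2 and len(roundtrip[0]) == 2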
diff --git a/optimum/exporters/openvino/utils.py b/optimum/exporters/openvino/utils.py
index d1318fc109..afaa4741a7 100644
--- a/optimum/exporters/openvino/utils.py
+++ b/optimum/exporters/openvino/utils.py
@@ -14,6 +14,7 @@
 
 import inspect
 import logging
+import re
 from collections import namedtuple
 from pathlib import Path
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
@@ -121,6 +122,71 @@ def _get_input_info(
     return input_info
 
 
+def _get_dynamic_shapes_info(
+    model: Union["PreTrainedModel", "ModelMixin"], config: OnnxConfig, dummy_inputs: Dict[str, Any]
+) -> Dict[str, Any]:
+    """Builds a `dynamic_shapes` specification for `torch.export` from the export config inputs."""
+    import torch
+
+    sig = inspect.signature(model.forward) if hasattr(model, "forward") else inspect.signature(model.call)
+    inputs = config.ordered_inputs(model)
+    input_info = {}
+    signature = set(sig.parameters)
+
+    # Reuse the same symbol for dimensions that share a name across inputs.
+    name_to_symbol = {}
+
+    for name, named_dims in inputs.items():
+        info = {}
+        for idx, dim_name in named_dims.items():
+            if dim_name in name_to_symbol:
+                symbol = name_to_symbol[dim_name]
+            else:
+                symbol = torch.export.Dim.DYNAMIC
+                name_to_symbol[dim_name] = symbol
+            info[idx] = symbol
+        if name in signature:
+            input_info[name] = info
+        else:
+            # Inputs named like "past_key_values.0.key" are regrouped into a nested
+            # structure under the matching forward argument (e.g. "past_key_values").
+            pattern = r"^([a-zA-Z_]+)\.(\d+)\.(key|value)$"
+            match = re.match(pattern, name)
+
+            if match:
+                prefix, number, key_or_value = match.groups()
+                number = int(number)
+                assert prefix in signature
+                if prefix not in input_info:
+                    input_info[prefix] = []
+                if key_or_value == "key":
+                    # Keys are expected to arrive in order, one per layer.
+                    assert len(input_info[prefix]) == number
+                    input_info[prefix].append((info,))
+                else:
+                    input_info[prefix][number] += (info,)
+    return input_info
+
+
+def _normalize_element(elem: Any, dtype: Any) -> Any:
+    import torch
+
+    # Cast floating-point tensors to the model dtype; recurse into containers.
+    if isinstance(elem, torch.Tensor):
+        return elem.to(dtype) if elem.dtype.is_floating_point else elem
+    if isinstance(elem, (list, tuple)):
+        return type(elem)(_normalize_element(e, dtype) for e in elem)
+    if isinstance(elem, dict):
+        return {k: _normalize_element(v, dtype) for k, v in elem.items()}
+    return elem
+
+
+def _normalize_dummy_inputs(dummy_inputs: Dict[str, Any], dtype: Any) -> Dict[str, Any]:
+    return {name: _normalize_element(value, dtype) for name, value in dummy_inputs.items()}
+
+
+def get_model_dtype(model):
+    import torch
+
+    # Return the dtype of the first parameter, falling back to a model-level
+    # attribute and finally to float32 for parameterless models.
+    for param in model.parameters():
+        return param.dtype
+    return getattr(model, "dtype", torch.float32)
+
+
 def remove_none_from_dummy_inputs(dummy_inputs: Dict[str, Any]):
     """
     Removes None values from the dictionary.
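Note: _get_dynamic_shapes_info returns the nested structure that torch.export expects for its dynamic_shapes argument: per keyword input, a dict mapping axis index to a dimension marker. A minimal sketch of how the export path in export_pytorch consumes it, using a toy module instead of a real transformer (Dim.DYNAMIC needs a recent torch, 2.6 or newer; openvino.convert_model can then ingest the resulting ExportedProgram as done in the patch):

import torch


class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(16, 16)

    def forward(self, input_ids):
        return self.linear(input_ids)


model = TinyModel()
dummy_inputs = {"input_ids": torch.randn(2, 8, 16)}
# Same structure _get_dynamic_shapes_info builds: input name -> {axis index: symbol};
# axes not listed (here the last one) stay static.
dynamic_shapes = {"input_ids": {0: torch.export.Dim.DYNAMIC, 1: torch.export.Dim.DYNAMIC}}

ep = torch.export.export_for_training(model, args=(), kwargs=dummy_inputs, dynamic_shapes=dynamic_shapes)
print(ep)  # ExportedProgram with batch and sequence axes kept symbolic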