122 changes: 111 additions & 11 deletions optimum/exporters/openvino/convert.py
@@ -67,7 +67,10 @@
MULTI_MODAL_TEXT_GENERATION_MODELS,
OV_XML_FILE_NAME,
_get_input_info,
_get_dynamic_shapes_info,
_normalize_dummy_inputs,
_get_open_clip_submodels_fn_and_export_configs,
get_model_dtype,
allow_skip_tracing_check,
clear_class_registry,
remove_none_from_dummy_inputs,
@@ -425,9 +428,86 @@ def export_pytorch(
# To handle it, an additional wrapper is applied to the patcher's forward.
# model.config.torchscript = True cannot be used for patching, because it overrides return_dict to False
patcher = config.patch_model_for_export(model, model_kwargs=model_kwargs)
import inspect
from optimum.exporters.onnx.model_patcher import override_arguments

dummy_input_keys = list(dummy_inputs.keys())

if is_transformers_version(">=", "4.48"):
from transformers.cache_utils import DynamicCache, EncoderDecoderCache

@functools.wraps(patcher.orig_forward)
def patched_forward(*args, **kwargs):
signature = inspect.signature(patcher.orig_forward)
args, kwargs = override_arguments(args, kwargs, signature, model_kwargs=patcher.model_kwargs)

if is_transformers_version(">=", "4.48"):
if "past_key_values" in signature.parameters:
pkv_index = list(signature.parameters.keys()).index("past_key_values")

if (
pkv_index < len(args) # pkv is in args
and isinstance(args[pkv_index], (list, tuple))
and isinstance(args[pkv_index][0], (list, tuple))
):
if len(args[pkv_index][0]) == 2:
args[pkv_index] = DynamicCache.from_legacy_cache(args[pkv_index])
elif len(args[pkv_index][0]) == 4:
args[pkv_index] = EncoderDecoderCache.from_legacy_cache(args[pkv_index])
else:
    raise ValueError(
        f"each past_key_values entry should have either 2 or 4 elements, but it has {len(args[pkv_index][0])} elements"
    )
elif (
"past_key_values" in kwargs # pkv is in kwargs
and isinstance(kwargs["past_key_values"], (list, tuple))
and isinstance(kwargs["past_key_values"][0], (list, tuple))
):
if len(kwargs["past_key_values"][0]) == 2:
kwargs["past_key_values"] = DynamicCache.from_legacy_cache(kwargs["past_key_values"])
elif len(kwargs["past_key_values"][0]) == 4:
kwargs["past_key_values"] = EncoderDecoderCache.from_legacy_cache(
kwargs["past_key_values"]
)
else:
    raise ValueError(
        f"each past_key_values entry should have either 2 or 4 elements, but it has {len(kwargs['past_key_values'][0])} elements"
    )

outputs = patcher.orig_forward(*args, **kwargs)

# This code block normalizes the model output into a dict of named outputs. Output types vary
# across models (tensor, list, tuple, etc.): Transformers models return a ModelOutput object that
# already carries the output names, while e.g. Timm classification models return a bare tensor.
# By default, the output names declared in the ONNX config are assumed to match the outputs in order.
filtered_outputs = {}
if isinstance(outputs, dict):
    for name, value in outputs.items():
        filtered_outputs[name] = value
elif isinstance(outputs, (list, tuple)):
    outputs_list = list(config.outputs.keys())
    filtered_outputs = dict(zip(outputs_list, outputs))
else:
    if len(config.outputs) > 1:
        num_outputs = len(config.outputs)
        outputs_str = ", ".join(config.outputs.keys())
        raise ValueError(
            f"config.outputs should have only one output, but it has {num_outputs} keys: {outputs_str}"
        )
    name = list(config.outputs.keys())[0]
    filtered_outputs[name] = outputs

if is_transformers_version(">=", "4.48"):
    if isinstance(filtered_outputs.get("past_key_values"), (DynamicCache, EncoderDecoderCache)):
        filtered_outputs["past_key_values"] = filtered_outputs["past_key_values"].to_legacy_cache()

return filtered_outputs

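For reference, a minimal sketch of the legacy-tuple round trip this wrapper performs around the original forward (toy shapes; assumes a transformers version that still provides the legacy-cache helpers):

    import torch
    from transformers.cache_utils import DynamicCache

    # Legacy format: one (key, value) pair per layer, each [batch, heads, seq_len, head_dim].
    legacy = tuple((torch.zeros(1, 2, 3, 4), torch.zeros(1, 2, 3, 4)) for _ in range(2))

    cache = DynamicCache.from_legacy_cache(legacy)  # tuples -> Cache object for forward()
    roundtrip = cache.to_legacy_cache()             # Cache object -> tuples for tracing
    print(len(roundtrip), roundtrip[0][0].shape)    # 2 torch.Size([1, 2, 3, 4])
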
@functools.wraps(patched_forward)
def ts_patched_forward(*args, **kwargs):
ordered_example_inputs = [
@@ -455,18 +535,38 @@ def ts_patched_forward(*args, **kwargs):
ts_decoder_kwargs["trace_kwargs"] = {"check_trace": False}

with patcher:
use_export = True  # toggle between the torch.export path and the TorchScript fallback below
check_dummy_inputs_are_allowed(model, dummy_inputs)
input_info = _get_input_info(model, config, dummy_inputs)
if use_export:
if hasattr(torch.ops, "_prepare_4d_causal_attention_mask_for_sdpa"):
# patch_everywhere breaks torch.ops namespace
del torch.ops._prepare_4d_causal_attention_mask_for_sdpa
dynamic_shapes = _get_dynamic_shapes_info(model, config, dummy_inputs)
_export_kwargs = {
    "args": tuple(),
    "kwargs": _normalize_dummy_inputs(dummy_inputs, get_model_dtype(model)),
    "dynamic_shapes": dynamic_shapes,
}

try:
from nncf.torch.dynamic_graph.patch_pytorch import disable_patching
# nncf patching breaks export
with disable_patching():
ep = torch.export.export_for_training(model, **_export_kwargs)
except ImportError:
ep = torch.export.export_for_training(model, **_export_kwargs)

ov_model = convert_model(ep)
else:
if patch_16bit_model:
from openvino.frontend.pytorch.patch_model import __make_16bit_traceable

__make_16bit_traceable(model)

ts_decoder = TorchScriptPythonDecoder(model, example_input=dummy_inputs, **ts_decoder_kwargs)
ov_model = convert_model(
ts_decoder,
example_input=dummy_inputs,
input=[(item.shape, item.type) for item in input_info],
)

ov_model.validate_nodes_and_infer_types() # TODO: remove as unnecessary validation?

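Taken together, the new branch amounts to the following standalone flow. This is a minimal sketch with a toy module, assuming torch >= 2.5 and an OpenVINO release that accepts an ExportedProgram; nncf is optional:

    import torch
    import openvino as ov

    class Toy(torch.nn.Module):
        def forward(self, input_ids):
            return input_ids.float() * 2

    model = Toy()
    dummy_inputs = {"input_ids": torch.zeros(2, 8, dtype=torch.long)}
    dynamic_shapes = {"input_ids": {0: torch.export.Dim.DYNAMIC, 1: torch.export.Dim.DYNAMIC}}

    try:
        # nncf's global patching breaks torch.export, so disable it around the call
        from nncf.torch.dynamic_graph.patch_pytorch import disable_patching

        with disable_patching():
            ep = torch.export.export_for_training(model, args=(), kwargs=dummy_inputs, dynamic_shapes=dynamic_shapes)
    except ImportError:
        ep = torch.export.export_for_training(model, args=(), kwargs=dummy_inputs, dynamic_shapes=dynamic_shapes)

    ov_model = ov.convert_model(ep)  # convert the ExportedProgram directly
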
66 changes: 66 additions & 0 deletions optimum/exporters/openvino/utils.py
@@ -14,6 +14,7 @@

import inspect
import logging
import re
from collections import namedtuple
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
@@ -121,6 +122,71 @@ def _get_input_info(
return input_info


def _get_dynamic_shapes_info(
    model: Union["PreTrainedModel", "ModelMixin"], config: OnnxConfig, dummy_inputs: Dict[str, Any]
) -> Dict[str, Any]:
import torch

sig = inspect.signature(model.forward) if hasattr(model, "forward") else inspect.signature(model.call)
inputs = config.ordered_inputs(model)
input_info = {}
signature = set(sig.parameters)

name_to_symbol = {}

for name, named_dims in inputs.items():
info = {}
for idx, dim_name in named_dims.items():
if dim_name in name_to_symbol:
symbol = name_to_symbol[dim_name]
else:
symbol = torch.export.Dim.DYNAMIC
name_to_symbol[dim_name] = symbol
info[idx] = symbol
if name in signature:
input_info[name] = info
else:
pattern = r"^([a-zA-Z_]+)\.(\d+)\.(key|value)$"
match = re.match(pattern, name)

if match:
prefix, number, key_or_value = match.groups()
number = int(number)
assert prefix in signature
if prefix not in input_info:
input_info[prefix] = []
if key_or_value == "key":
assert len(input_info[prefix]) == number
input_info[prefix].append((info,))
else:
input_info[prefix][number] += (info,)
return input_info

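For illustration, for a hypothetical two-layer decoder whose ONNX config lists input_ids plus legacy-named past_key_values.N.key/value entries, the returned structure would look roughly like this (a hand-written sketch, not captured output):

    from torch.export import Dim

    expected = {
        "input_ids": {0: Dim.DYNAMIC, 1: Dim.DYNAMIC},
        "past_key_values": [
            ({0: Dim.DYNAMIC, 2: Dim.DYNAMIC}, {0: Dim.DYNAMIC, 2: Dim.DYNAMIC}),  # layer 0: (key, value)
            ({0: Dim.DYNAMIC, 2: Dim.DYNAMIC}, {0: Dim.DYNAMIC, 2: Dim.DYNAMIC}),  # layer 1: (key, value)
        ],
    }
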

def _normalize_element(elem: Any, dtype: Any) -> Any:
import torch
if isinstance(elem, torch.Tensor):
return elem.to(dtype) if elem.dtype.is_floating_point else elem
if isinstance(elem, (list, tuple)):
return type(elem)(_normalize_element(e, dtype) for e in elem)
if isinstance(elem, dict):
return {k: _normalize_element(v, dtype) for k, v in elem.items()}
return elem


def _normalize_dummy_inputs(dummy_inputs: Dict[str, Any], dtype: Any) -> Dict[str, Any]:
new_dummy = {}
for name, value in dummy_inputs.items():
new_dummy[name] = _normalize_element(value, dtype)
return new_dummy


def get_model_dtype(model):
    """Return the dtype of the model's first parameter, falling back to `model.dtype` or torch.float32."""
    import torch

    for param in model.parameters():
        return param.dtype
    return getattr(model, "dtype", torch.float32)

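A short usage sketch of the two helpers above, with toy tensors and a Linear module standing in for a real model:

    import torch

    model = torch.nn.Linear(4, 4).to(torch.float16)
    dummy_inputs = {
        "input_ids": torch.zeros(2, 8, dtype=torch.long),  # integer tensors pass through unchanged
        "attention_mask": torch.ones(2, 8),                # floating tensors are cast to the model dtype
        "past_key_values": [(torch.zeros(2, 2, 0, 4), torch.zeros(2, 2, 0, 4))],  # containers recurse
    }

    dtype = get_model_dtype(model)  # torch.float16
    normalized = _normalize_dummy_inputs(dummy_inputs, dtype)
    print(normalized["attention_mask"].dtype)         # torch.float16
    print(normalized["past_key_values"][0][0].dtype)  # torch.float16
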

def remove_none_from_dummy_inputs(dummy_inputs: Dict[str, Any]):
"""
Removes None values from the dictionary.
Expand Down