
Commit cbb9400

[POC] Use torch.export for converting
1 parent fc76020 commit cbb9400

File tree

2 files changed: +173 -12 lines changed

optimum/exporters/openvino/convert.py

Lines changed: 113 additions & 12 deletions
@@ -65,6 +65,8 @@
     MULTI_MODAL_TEXT_GENERATION_MODELS,
     OV_XML_FILE_NAME,
     _get_input_info,
+    _get_dynamic_shapes_info,
+    _normalize_dummy_inputs,
     _get_open_clip_submodels_fn_and_export_configs,
     allow_skip_tracing_check,
     clear_class_registry,
@@ -422,7 +424,84 @@ def export_pytorch(
             # To handle it, additional wrapper on patcher forward applied.
             # model.config.torchscript = True can not be used for patching, because it overrides return_dict to False
             patcher = config.patch_model_for_export(model, model_kwargs=model_kwargs)
-            patched_forward = patcher.patched_forward
+            #patched_forward = patcher.orig_forward
+            import inspect
+            from optimum.exporters.onnx.model_patcher import override_arguments
+
+            if is_transformers_version(">=", "4.48"):
+                from transformers.cache_utils import DynamicCache, EncoderDecoderCache
+
+            @functools.wraps(patcher.orig_forward)
+            def patched_forward(*args, **kwargs):
+                signature = inspect.signature(patcher.orig_forward)
+                args, kwargs = override_arguments(args, kwargs, signature, model_kwargs=patcher.model_kwargs)
+
+                if is_transformers_version(">=", "4.48"):
+                    if "past_key_values" in signature.parameters:
+                        pkv_index = list(signature.parameters.keys()).index("past_key_values")
+
+                        if (
+                            pkv_index < len(args)  # pkv is in args
+                            and isinstance(args[pkv_index], (list, tuple))
+                            and isinstance(args[pkv_index][0], (list, tuple))
+                        ):
+                            if len(args[pkv_index][0]) == 2:
+                                args[pkv_index] = DynamicCache.from_legacy_cache(args[pkv_index])
+                            elif len(args[pkv_index][0]) == 4:
+                                args[pkv_index] = EncoderDecoderCache.from_legacy_cache(args[pkv_index])
+                            else:
+                                raise ValueError(
+                                    f"past_key_values should have either 2 or 4 elements, but it has {len(args[pkv_index][0])} elements"
+                                )
+                        elif (
+                            "past_key_values" in kwargs  # pkv is in kwargs
+                            and isinstance(kwargs["past_key_values"], (list, tuple))
+                            and isinstance(kwargs["past_key_values"][0], (list, tuple))
+                        ):
+                            if len(kwargs["past_key_values"][0]) == 2:
+                                kwargs["past_key_values"] = DynamicCache.from_legacy_cache(kwargs["past_key_values"])
+                            elif len(kwargs["past_key_values"][0]) == 4:
+                                kwargs["past_key_values"] = EncoderDecoderCache.from_legacy_cache(
+                                    kwargs["past_key_values"]
+                                )
+                            else:
+                                raise ValueError(
+                                    f"past_key_values should have either 2 or 4 elements, but it has {len(kwargs['past_key_values'][0])} elements"
+                                )
+
+                outputs = patcher.orig_forward(*args, **kwargs)
+
+                # This code block handles different cases of the filterd_outputs input to align it with the expected
+                # format of outputs. It is common for the output type of a model to vary, such as tensor, list,
+                # tuple, etc. For Transformers models, the output is encapsulated in a ModelOutput object that
+                # contains the output names of the model. In the case of Timm classification models, the output
+                # is of type tensor. By default, it is assumed that the output names mentioned in the ONNX config
+                # match the outputs in order.
+                filterd_outputs = {}
+                if isinstance(outputs, dict):
+                    for name, value in outputs.items():
+                        filterd_outputs[name] = value
+                elif isinstance(outputs, (list, tuple)):
+                    outputs_list = list(config.outputs.keys())
+                    filterd_outputs = dict(zip(outputs_list, outputs))
+                else:
+                    if len(config.outputs) > 1:
+                        num_outputs = len(config.outputs)
+                        outputs_str = ", ".join(config.outputs.keys())
+                        raise ValueError(
+                            f"config.outputs should have only one outputs, but it has {num_outputs} keys: {outputs_str}"
+                        )
+                    else:
+                        name = list(config.outputs.keys())[0]
+                        filterd_outputs[name] = outputs
+                    name = list(config.outputs.keys())[0]
+                    filterd_outputs[name] = outputs
+
+                if is_transformers_version(">=", "4.48"):
+                    if isinstance(filterd_outputs.get("past_key_values"), (DynamicCache, EncoderDecoderCache)):
+                        filterd_outputs["past_key_values"] = outputs["past_key_values"].to_legacy_cache()
+
+                return filterd_outputs

             @functools.wraps(patched_forward)
             def ts_patched_forward(*args, **kwargs):
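
Aside (not part of the commit): a minimal sketch of the legacy-cache round-trip that the patched_forward wrapper above performs for transformers >= 4.48. The layer count and tensor shapes are made up for illustration; the wrapper applies the same conversion to whatever past_key_values the dummy inputs provide.

    import torch
    from transformers.cache_utils import DynamicCache

    # Legacy format: one (key, value) pair per layer, shape (batch, num_heads, seq_len, head_dim).
    legacy_pkv = tuple((torch.zeros(2, 4, 3, 8), torch.zeros(2, 4, 3, 8)) for _ in range(2))

    cache = DynamicCache.from_legacy_cache(legacy_pkv)  # what the wrapper feeds to orig_forward
    roundtrip = cache.to_legacy_cache()                 # what it converts the model output back to
    assert len(roundtrip) == 2 and len(roundtrip[0]) == 2
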
@@ -443,21 +522,43 @@ def ts_patched_forward(*args, **kwargs):
                 ts_decoder_kwargs["trace_kwargs"] = {"check_trace": False}

             with patcher:
-                if patch_16bit_model:
-                    from openvino.frontend.pytorch.patch_model import __make_16bit_traceable
-
-                    __make_16bit_traceable(model)
+                use_export = True
                 check_dummy_inputs_are_allowed(model, dummy_inputs)
                 input_info = _get_input_info(model, config, dummy_inputs)
-                ts_decoder = TorchScriptPythonDecoder(model, example_input=dummy_inputs, **ts_decoder_kwargs)
-                ov_model = convert_model(
-                    ts_decoder,
-                    example_input=dummy_inputs,
-                    input=[(item.shape, item.type) for item in input_info],
-                )
+                if use_export:
+                    if hasattr(torch.ops, "_prepare_4d_causal_attention_mask_for_sdpa"):
+                        # patch_everywhere breaks torch.ops namespace
+                        del torch.ops._prepare_4d_causal_attention_mask_for_sdpa
+                    dynamic_shapes = _get_dynamic_shapes_info(model, config, dummy_inputs)
+                    _export_kwargs = {"args": tuple(), "kwargs": _normalize_dummy_inputs(dummy_inputs, model.dtype)}
+                    _export_kwargs["dynamic_shapes"] = dynamic_shapes
+
+                    try:
+                        from nncf.torch.dynamic_graph.patch_pytorch import disable_patching
+                        # nncf patching breaks export
+                        with disable_patching():
+                            ep = torch.export.export_for_training(model, **_export_kwargs)
+                    except ImportError:
+                        ep = torch.export.export_for_training(model, **_export_kwargs)
+
+                    ov_model = convert_model(ep)
+                else:
+                    if patch_16bit_model:
+                        from openvino.frontend.pytorch.patch_model import __make_16bit_traceable
+
+                        __make_16bit_traceable(model)
+
+                    ts_decoder = TorchScriptPythonDecoder(model, example_input=dummy_inputs, **ts_decoder_kwargs)
+                    ov_model = convert_model(
+                        ts_decoder,
+                        example_input=dummy_inputs,
+                        input=[(item.shape, item.type) for item in input_info],
+                    )

         except Exception as ex:
-            logger.warning(f"Export model to OpenVINO directly failed with: \n{ex}.\nModel will be exported to ONNX")
+            logger.warning(f"Export model to OpenVINO directly failed with: \n", exc_info=ex)
+            raise ex
+            logger.warning("\nModel will be exported to ONNX")

         if stateful:
             # cannot raise because stateful is enabled by default and it would break backward compatibility for models that couldn't convert to OV directly
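
Aside (not part of the commit): a self-contained sketch of the conversion path the new use_export branch takes, with a toy module standing in for the patched Hugging Face model. Dummy inputs are passed as kwargs, dynamic dimensions are marked with torch.export.Dim.DYNAMIC (the role played by _get_dynamic_shapes_info), and the resulting ExportedProgram is handed straight to openvino.convert_model, mirroring ov_model = convert_model(ep) above. The module, shapes, and output path are illustrative.

    import openvino as ov
    import torch


    class TinyDecoder(torch.nn.Module):  # hypothetical stand-in for the exported model
        def __init__(self):
            super().__init__()
            self.emb = torch.nn.Embedding(32, 16)
            self.proj = torch.nn.Linear(16, 32)

        def forward(self, input_ids, attention_mask):
            hidden = self.emb(input_ids) * attention_mask.unsqueeze(-1)
            return self.proj(hidden)


    model = TinyDecoder().eval()
    dummy_kwargs = {
        "input_ids": torch.randint(0, 32, (2, 8)),
        "attention_mask": torch.ones(2, 8),
    }
    dyn = torch.export.Dim.DYNAMIC
    dynamic_shapes = {
        "input_ids": {0: dyn, 1: dyn},        # batch and sequence are dynamic
        "attention_mask": {0: dyn, 1: dyn},
    }

    ep = torch.export.export_for_training(model, args=(), kwargs=dummy_kwargs, dynamic_shapes=dynamic_shapes)
    ov_model = ov.convert_model(ep)              # OpenVINO accepts the ExportedProgram directly, as in the hunk above
    ov.save_model(ov_model, "tiny_decoder.xml")  # illustrative output path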

optimum/exporters/openvino/utils.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import inspect
1616
import logging
17+
import re
1718
from collections import namedtuple
1819
from pathlib import Path
1920
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
@@ -121,6 +122,65 @@ def _get_input_info(
     return input_info


+def _get_dynamic_shapes_info(
+    model: Union["PreTrainedModel", "ModelMixin"], config: OnnxConfig, dummy_inputs: Dict[str, Any]
+) -> List[InputInfo]:
+    import torch
+
+    sig = inspect.signature(model.forward) if hasattr(model, "forward") else inspect.signature(model.call)
+    inputs = config.ordered_inputs(model)
+    input_info = {}
+    signature = set(sig.parameters)
+
+    name_to_symbol = {}
+
+    for name, named_dims in inputs.items():
+        info = {}
+        for idx, dim_name in named_dims.items():
+            if dim_name in name_to_symbol:
+                symbol = name_to_symbol[dim_name]
+            else:
+                symbol = torch.export.Dim.DYNAMIC
+                name_to_symbol[dim_name] = symbol
+            info[idx] = symbol
+        if name in signature:
+            input_info[name] = info
+        else:
+            pattern = r"^([a-zA-Z_]+)\.(\d+)\.(key|value)$"
+            match = re.match(pattern, name)
+
+            if match:
+                prefix, number, key_or_value = match.groups()
+                number = int(number)
+                assert prefix in signature
+                if prefix not in input_info:
+                    input_info[prefix] = []
+                if key_or_value == "key":
+                    assert len(input_info[prefix]) == number
+                    input_info[prefix].append((info,))
+                else:
+                    input_info[prefix][number] += (info,)
+    return input_info
+
+
+def _normalize_element(elem: Any, dtype: Any) -> Any:
+    import torch
+    if isinstance(elem, torch.Tensor):
+        return elem.to(dtype) if elem.dtype.is_floating_point else elem
+    if isinstance(elem, (list, tuple)):
+        return type(elem)(_normalize_element(e, dtype) for e in elem)
+    if isinstance(elem, dict):
+        return {k: _normalize_element(v, dtype) for k, v in elem.items()}
+    return elem
+
+
+def _normalize_dummy_inputs(dummy_inputs: Dict[str, Any], dtype: Any) -> Dict[str, Any]:
+    new_dummy = {}
+    for name, value in dummy_inputs.items():
+        new_dummy[name] = _normalize_element(value, dtype)
+    return new_dummy
+
+
 def remove_none_from_dummy_inputs(dummy_inputs: Dict[str, Any]):
     """
     Removes None values from the dictionary.
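
Aside (not part of the commit): the shape of the dynamic-shapes dict that _get_dynamic_shapes_info builds, written out by hand for a hypothetical decoder whose forward takes input_ids, attention_mask and past_key_values. Per-input entries map axis index to torch.export.Dim.DYNAMIC, and the flat past_key_values.<layer>.key/value names from the export config are regrouped into one (key_spec, value_spec) tuple per layer; the axis indices below are illustrative.

    import torch

    dyn = torch.export.Dim.DYNAMIC  # the helper caches one symbol per named dimension

    dynamic_shapes = {
        "input_ids": {0: dyn, 1: dyn},
        "attention_mask": {0: dyn, 1: dyn},
        "past_key_values": [
            ({0: dyn, 2: dyn}, {0: dyn, 2: dyn}),  # layer 0: (key spec, value spec)
            ({0: dyn, 2: dyn}, {0: dyn, 2: dyn}),  # layer 1
        ],
    }

    # export_pytorch passes this dict to torch.export as dynamic_shapes=..., alongside the
    # dummy inputs normalized by _normalize_dummy_inputs (floating-point tensors cast to model.dtype).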
