Skip to content

Commit 5ac5e5d

Browse files
committed
[POC] Use torch.export for converting
1 parent 3b2cffb commit 5ac5e5d

File tree

2 files changed

+207
-0
lines changed

2 files changed

+207
-0
lines changed

optimum/exporters/openvino/convert.py

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,8 @@
6767
MULTI_MODAL_TEXT_GENERATION_MODELS,
6868
OV_XML_FILE_NAME,
6969
_get_input_info,
70+
_get_dynamic_shapes_info,
71+
_normalize_dummy_inputs,
7072
_get_open_clip_submodels_fn_and_export_configs,
7173
allow_skip_tracing_check,
7274
clear_class_registry,
@@ -425,6 +427,7 @@ def export_pytorch(
425427
patched_forward = patcher.patched_forward
426428
dummy_input_keys = list(dummy_inputs.keys())
427429

430+
<<<<<<< HEAD
428431
@functools.wraps(patched_forward)
429432
def ts_patched_forward(*args, **kwargs):
430433
ordered_example_inputs = [
@@ -442,14 +445,158 @@ def ts_patched_forward(*args, **kwargs):
442445
kwargs[input_name] = input_dict
443446
outputs = patched_forward(**kwargs)
444447
return tuple([value if not isinstance(value, list) else tuple(value) for value in outputs.values()])
448+
=======
449+
try:
450+
# TorchScript is used behind the OpenVINO conversion. Optimum supports only return_dict=True models for patching,
451+
# while TorchScript does not support dictionaries with values of mixed types (e.g. Tensor and None) in model input/output.
452+
# To handle this, an additional wrapper is applied on top of the patcher forward.
453+
# model.config.torchscript = True can not be used for patching, because it overrides return_dict to False
454+
patcher = config.patch_model_for_export(model, model_kwargs=model_kwargs)
455+
#patched_forward = patcher.orig_forward
456+
import inspect
457+
from optimum.exporters.onnx.model_patcher import override_arguments
458+
459+
if is_transformers_version(">=", "4.48"):
460+
from transformers.cache_utils import DynamicCache, EncoderDecoderCache
461+
462+
@functools.wraps(patcher.orig_forward)
463+
def patched_forward(*args, **kwargs):
464+
signature = inspect.signature(patcher.orig_forward)
465+
args, kwargs = override_arguments(args, kwargs, signature, model_kwargs=patcher.model_kwargs)
466+
467+
if is_transformers_version(">=", "4.48"):
468+
if "past_key_values" in signature.parameters:
469+
pkv_index = list(signature.parameters.keys()).index("past_key_values")
470+
471+
if (
472+
pkv_index < len(args) # pkv is in args
473+
and isinstance(args[pkv_index], (list, tuple))
474+
and isinstance(args[pkv_index][0], (list, tuple))
475+
):
476+
if len(args[pkv_index][0]) == 2:
477+
args[pkv_index] = DynamicCache.from_legacy_cache(args[pkv_index])
478+
elif len(args[pkv_index][0]) == 4:
479+
args[pkv_index] = EncoderDecoderCache.from_legacy_cache(args[pkv_index])
480+
else:
481+
raise ValueError(
482+
f"past_key_values should have either 2 or 4 elements, but it has {len(args[pkv_index][0])} elements"
483+
)
484+
elif (
485+
"past_key_values" in kwargs # pkv is in kwargs
486+
and isinstance(kwargs["past_key_values"], (list, tuple))
487+
and isinstance(kwargs["past_key_values"][0], (list, tuple))
488+
):
489+
if len(kwargs["past_key_values"][0]) == 2:
490+
kwargs["past_key_values"] = DynamicCache.from_legacy_cache(kwargs["past_key_values"])
491+
elif len(kwargs["past_key_values"][0]) == 4:
492+
kwargs["past_key_values"] = EncoderDecoderCache.from_legacy_cache(
493+
kwargs["past_key_values"]
494+
)
495+
else:
496+
raise ValueError(
497+
f"past_key_values should have either 2 or 4 elements, but it has {len(kwargs['past_key_values'][0])} elements"
498+
)
499+
500+
outputs = patcher.orig_forward(*args, **kwargs)
501+
502+
# This code block handles different cases of the filterd_outputs input to align it with the expected
503+
# format of outputs. It is common for the output type of a model to vary, such as tensor, list,
504+
# tuple, etc. For Transformers models, the output is encapsulated in a ModelOutput object that
505+
# contains the output names of the model. In the case of Timm classification models, the output
506+
# is of type tensor. By default, it is assumed that the output names mentioned in the ONNX config
507+
# match the outputs in order.
508+
filterd_outputs = {}
509+
if isinstance(outputs, dict):
510+
for name, value in outputs.items():
511+
filterd_outputs[name] = value
512+
elif isinstance(outputs, (list, tuple)):
513+
outputs_list = list(config.outputs.keys())
514+
filterd_outputs = dict(zip(outputs_list, outputs))
515+
else:
516+
if len(config.outputs) > 1:
517+
num_outputs = len(config.outputs)
518+
outputs_str = ", ".join(config.outputs.keys())
519+
raise ValueError(
520+
f"config.outputs should have only one outputs, but it has {num_outputs} keys: {outputs_str}"
521+
)
522+
else:
523+
name = list(config.outputs.keys())[0]
524+
filterd_outputs[name] = outputs
525+
name = list(config.outputs.keys())[0]
526+
filterd_outputs[name] = outputs
527+
528+
if is_transformers_version(">=", "4.48"):
529+
if isinstance(filterd_outputs.get("past_key_values"), (DynamicCache, EncoderDecoderCache)):
530+
filterd_outputs["past_key_values"] = outputs["past_key_values"].to_legacy_cache()
531+
532+
return filterd_outputs
533+
>>>>>>> cfde44f ([POC] Use torch.export for converting)
445534

446535
patcher.patched_forward = ts_patched_forward
447536

537+
<<<<<<< HEAD
448538
ts_decoder_kwargs = {}
449539
model_config = getattr(model, "config", {})
450540
model_type = getattr(model_config, "model_type", "").replace("_", "-")
451541
if allow_skip_tracing_check(library_name, model_type):
452542
ts_decoder_kwargs["trace_kwargs"] = {"check_trace": False}
543+
=======
544+
patcher.patched_forward = ts_patched_forward
545+
546+
ts_decoder_kwargs = {}
547+
model_config = getattr(model, "config", {})
548+
model_type = getattr(model_config, "model_type", "").replace("_", "-")
549+
if allow_skip_tracing_check(library_name, model_type):
550+
ts_decoder_kwargs["trace_kwargs"] = {"check_trace": False}
551+
552+
with patcher:
553+
use_export = True
554+
check_dummy_inputs_are_allowed(model, dummy_inputs)
555+
input_info = _get_input_info(model, config, dummy_inputs)
556+
if use_export:
557+
if hasattr(torch.ops, "_prepare_4d_causal_attention_mask_for_sdpa"):
558+
# patch_everywhere breaks torch.ops namespace
559+
del torch.ops._prepare_4d_causal_attention_mask_for_sdpa
560+
dynamic_shapes = _get_dynamic_shapes_info(model, config, dummy_inputs)
561+
_export_kwargs = {"args": tuple(), "kwargs": _normalize_dummy_inputs(dummy_inputs, model.dtype)}
562+
_export_kwargs["dynamic_shapes"] = dynamic_shapes
563+
564+
try:
565+
from nncf.torch.dynamic_graph.patch_pytorch import disable_patching
566+
# nncf patching breaks export
567+
with disable_patching():
568+
ep = torch.export.export_for_training(model, **_export_kwargs)
569+
except ImportError:
570+
ep = torch.export.export_for_training(model, **_export_kwargs)
571+
572+
ov_model = convert_model(ep)
573+
else:
574+
if patch_16bit_model:
575+
from openvino.frontend.pytorch.patch_model import __make_16bit_traceable
576+
577+
__make_16bit_traceable(model)
578+
579+
ts_decoder = TorchScriptPythonDecoder(model, example_input=dummy_inputs, **ts_decoder_kwargs)
580+
ov_model = convert_model(
581+
ts_decoder,
582+
example_input=dummy_inputs,
583+
input=[(item.shape, item.type) for item in input_info],
584+
)
585+
586+
except Exception as ex:
587+
logger.warning(f"Export model to OpenVINO directly failed with: \n", exc_info=ex)
588+
raise ex
589+
logger.warning("\nModel will be exported to ONNX")
590+
591+
if stateful:
592+
# cannot raise because stateful is enabled by default and it would break backward compatibility for models that couldn't convert to OV directly
593+
# TODO: Implement stateful for ONNX path as well, not doing it right now because of lack of validation
594+
logger.warning(
595+
"[ WARNING ] Making stateful models is not supported when exporting to ONNX as an intermediate step. "
596+
"A stateless model will be exported instead. It may result in sub-optimal inference performance."
597+
"Provide a model that can be converted to OpenVINO without fallback to ONNX conversion path."
598+
)
599+
>>>>>>> cfde44f ([POC] Use torch.export for converting)
453600

454601
with patcher:
455602
if patch_16bit_model:

optimum/exporters/openvino/utils.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import inspect
1616
import logging
17+
import re
1718
from collections import namedtuple
1819
from pathlib import Path
1920
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
@@ -121,6 +122,65 @@ def _get_input_info(
121122
return input_info
122123

123124

125+
def _get_dynamic_shapes_info(
126+
model: Union["PreTrainedModel", "ModelMixin"], config: OnnxConfig, dummy_inputs: Dict[str, Any]
127+
) -> List[InputInfo]:
128+
import torch
129+
130+
sig = inspect.signature(model.forward) if hasattr(model, "forward") else inspect.signature(model.call)
131+
inputs = config.ordered_inputs(model)
132+
input_info = {}
133+
signature = set(sig.parameters)
134+
135+
name_to_symbol = {}
136+
137+
for name, named_dims in inputs.items():
138+
info = {}
139+
for idx, dim_name in named_dims.items():
140+
if dim_name in name_to_symbol:
141+
symbol = name_to_symbol[dim_name]
142+
else:
143+
symbol = torch.export.Dim.DYNAMIC
144+
name_to_symbol[dim_name] = symbol
145+
info[idx] = symbol
146+
if name in signature:
147+
input_info[name] = info
148+
else:
149+
pattern = r"^([a-zA-Z_]+)\.(\d+)\.(key|value)$"
150+
match = re.match(pattern, name)
151+
152+
if match:
153+
prefix, number, key_or_value = match.groups()
154+
number = int(number)
155+
assert prefix in signature
156+
if prefix not in input_info:
157+
input_info[prefix] = []
158+
if key_or_value == "key":
159+
assert len(input_info[prefix]) == number
160+
input_info[prefix].append((info,))
161+
else:
162+
input_info[prefix][number] += (info,)
163+
return input_info
164+
165+
166+
def _normalize_element(elem: Any, dtype: Any) -> Any:
    """
    Recursively cast every floating-point tensor found in `elem` to `dtype`.

    Lists, tuples (including namedtuples) and dicts are walked recursively. Lists and tuples are
    rebuilt with their original type, dicts are rebuilt as plain `dict`. Non-floating-point tensors
    and any other object are returned unchanged.

    Args:
        elem: A tensor, a (possibly nested) list/tuple/dict of tensors, or any other value.
        dtype: Target dtype passed to `Tensor.to` for floating-point tensors.

    Returns:
        A structure mirroring `elem` in which floating-point tensors have been converted to `dtype`.
    """
    import torch

    if isinstance(elem, torch.Tensor):
        return elem.to(dtype) if elem.dtype.is_floating_point else elem
    if isinstance(elem, (list, tuple)):
        items = [_normalize_element(item, dtype) for item in elem]
        if isinstance(elem, tuple) and hasattr(elem, "_fields"):
            # namedtuple constructors take one positional argument per field, not a single iterable
            return type(elem)(*items)
        return type(elem)(items)
    if isinstance(elem, dict):
        return {key: _normalize_element(value, dtype) for key, value in elem.items()}
    return elem
175+
176+
177+
def _normalize_dummy_inputs(dummy_inputs: Dict[str, Any], dtype: Any) -> Dict[str, Any]:
    """
    Return a new dict in which every value of `dummy_inputs` went through `_normalize_element`.

    Args:
        dummy_inputs: Mapping from input name to the dummy value generated for it.
        dtype: Dtype forwarded to `_normalize_element` for each value.

    Returns:
        A fresh dict with the same keys, in the same order; `dummy_inputs` itself is not modified.
    """
    return {input_name: _normalize_element(input_value, dtype) for input_name, input_value in dummy_inputs.items()}
182+
183+
124184
def remove_none_from_dummy_inputs(dummy_inputs: Dict[str, Any]):
125185
"""
126186
Removes None values from the dictionary.

0 commit comments

Comments
 (0)