Skip to content

Commit 08e9dd2

Browse files
titaiwangms and pytorchmergebot
authored and committed
[ONNX] Support symbolic arguments in onnx exporter (pytorch#157734)
Previous to this PR, torch.onnx.export(..., dynamo=True, verify=True, report=True) does not support symbolic arguments. Such examples are like the following:

```python
class M(torch.nn.Module):
    def forward(self, a, x):
        return a + torch.tensor(1) + x

op = torch.onnx.export(M(), (1, torch.ones(2)), dynamic_shapes=(torch.export.Dim.DYNAMIC, {0: torch.export.Dim.DYNAMIC}), dynamo=True, report=True)
```

Symbolic arguments are like constant arguments in that they don't have tensor_meta either. Besides, torch.export.export supports model inputs having constants, which is different from the legacy issue pytorch#99534, where we tried to get the FX directly from dynamo export. Thus, `_remove_non_tensor` is deleted from args processing.

NOTE: If the ConstantArgument shows up in exported_program, it was kept to align the length of inputs to nn.Module, but it's irrelevant to the model graph, which is why the input is omitted in the ONNX model.

The test `test_constant_argument_user_input_is_omitted_in_onnx_graph` needs pytorch#157719.

Pull Request resolved: pytorch#157734
Approved by: https://github.com/justinchuby
1 parent 163f0d8 commit 08e9dd2

File tree

4 files changed

+120
-64
lines changed

4 files changed

+120
-64
lines changed

test/onnx/exporter/test_api.py

Lines changed: 63 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44
from __future__ import annotations
55

66
import io
7+
import logging
78
import os
89

910
import numpy as np
@@ -73,6 +74,68 @@ def test_args_normalization_with_no_kwargs(self):
7374
(torch.randn(1, 1, 2), torch.randn(1, 1, 2)),
7475
)
7576

77+
def test_symbolic_argument_user_input_is_supported_by_report_and_call(self):
78+
class constant_plus_tensor_inputs(torch.nn.Module):
79+
def forward(self, a, x):
80+
return a + torch.tensor(1) + x
81+
82+
# Capture log output
83+
log_capture = io.StringIO()
84+
log_handler = logging.StreamHandler(log_capture)
85+
log_handler.setLevel(logging.ERROR)
86+
# Get the logger used in _core.py
87+
logger = logging.getLogger("torch.onnx._internal.exporter._core")
88+
original_level = logger.level
89+
logger.addHandler(log_handler)
90+
logger.setLevel(logging.ERROR)
91+
92+
try:
93+
with common_utils.TemporaryDirectoryName() as temp_dir:
94+
self.assert_export(
95+
constant_plus_tensor_inputs(),
96+
(
97+
1,
98+
torch.ones(2),
99+
),
100+
dynamic_shapes=(
101+
torch.export.Dim.DYNAMIC,
102+
{0: torch.export.Dim.DYNAMIC},
103+
),
104+
report=True,
105+
artifacts_dir=temp_dir,
106+
)
107+
# Check if the expected error was logged
108+
log_output = log_capture.getvalue()
109+
self.assertNotIn("Failed to save report due to an error", log_output)
110+
self.assertNotIn("KeyError: 'tensor_meta'", log_output)
111+
# Note: We don't call assert_onnx_program here because it will fail
112+
# due to the input name mismatch issue mentioned in your error
113+
114+
finally:
115+
# Clean up logging
116+
logger.removeHandler(log_handler)
117+
logger.setLevel(original_level)
118+
119+
def test_constant_argument_user_input_is_omitted_in_onnx_graph(self):
120+
class constant_plus_tensor_inputs(torch.nn.Module):
121+
def forward(self, a, x):
122+
return a + torch.tensor(1) + x
123+
124+
onnx_program = torch.onnx.export(
125+
constant_plus_tensor_inputs(),
126+
(
127+
1,
128+
torch.ones(2),
129+
),
130+
dynamic_shapes=(
131+
None,
132+
{0: torch.export.Dim.DYNAMIC},
133+
),
134+
dynamo=True,
135+
)
136+
137+
self.assertEqual(len(onnx_program.model.graph.inputs), 1)
138+
76139
def test_dynamic_axes_enable_dynamic_shapes_with_fully_specified_axes(self):
77140
self.assert_export(
78141
SampleModelForDynamicShapes(),

torch/onnx/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -520,7 +520,7 @@ def _to_dynamic_shape(x):
520520
dynamic_shape[i] = torch.export.Dim.AUTO
521521
return dynamic_shape
522522
else:
523-
return None
523+
return torch.export.Dim.AUTO
524524

525525
# model_args could be nested
526526
dynamic_shapes = _pytree.tree_map(
@@ -529,7 +529,6 @@ def _to_dynamic_shape(x):
529529
)
530530
else:
531531
dynamic_shapes = None
532-
533532
return _compat.export_compat(
534533
model, # type: ignore[arg-type]
535534
model_args,

torch/onnx/_internal/exporter/_analysis.py

Lines changed: 37 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -159,22 +159,48 @@ def _get_io_specs(exported_program: torch.export.ExportedProgram) -> tuple[dict,
159159
for spec in exported_program.graph_signature.output_specs
160160
if spec.kind == graph_signature.OutputKind.USER_OUTPUT
161161
]
162-
inputs: dict[str, torch._export.serde.schema.TensorMeta] = {}
163-
outputs: dict[str, torch._export.serde.schema.TensorMeta] = {}
162+
inputs: dict[str, torch._export.serde.schema.TensorMeta | str] = {}
163+
outputs: dict[str, torch._export.serde.schema.TensorMeta | str] = {}
164164
for spec in user_inputs:
165-
if isinstance(spec.arg, graph_signature.ConstantArgument):
166-
continue
167-
name = spec.arg.name
168-
# FIXME: tensor_meta is None sometimes when the exported program still knows the shape/type
169-
inputs[name] = nodes[name].meta["tensor_meta"]
165+
inputs = _log_spec_into_io_specs(spec, nodes, inputs)
170166
for spec in user_outputs:
171-
if isinstance(spec.arg, graph_signature.ConstantArgument):
172-
continue
173-
name = spec.arg.name
174-
outputs[name] = nodes[name].meta["tensor_meta"]
167+
outputs = _log_spec_into_io_specs(spec, nodes, outputs)
175168
return inputs, outputs
176169

177170

171+
def _log_spec_into_io_specs(
172+
spec: graph_signature.InputSpec,
173+
nodes: dict[str, torch.fx.Node],
174+
inputs_or_outputs: dict[str, torch._export.serde.schema.TensorMeta | str],
175+
) -> dict[str, torch._export.serde.schema.TensorMeta | str]:
176+
# If dynamic is set to a constant input, it becomes a
177+
# symbolic argument, which is not a tensor.
178+
if isinstance(spec.arg, graph_signature.ConstantArgument):
179+
# Constant input does not have tensor_meta.
180+
return inputs_or_outputs
181+
# Symbolic arguments are not tensors, so it does not have tensor_meta,
182+
# but we need to provide a string representation for them to inform users.
183+
name = spec.arg.name
184+
if isinstance(
185+
spec.arg,
186+
(
187+
graph_signature.SymIntArgument,
188+
graph_signature.SymFloatArgument,
189+
graph_signature.SymBoolArgument,
190+
),
191+
):
192+
argument_to_str: dict[type[graph_signature.ArgumentSpec], str] = {
193+
graph_signature.SymIntArgument: "SymInt",
194+
graph_signature.SymFloatArgument: "SymFloat",
195+
graph_signature.SymBoolArgument: "SymBool",
196+
}
197+
inputs_or_outputs[name] = argument_to_str[type(spec.arg)]
198+
return inputs_or_outputs
199+
# FIXME: tensor_meta is None sometimes when the exported program still knows the shape/type
200+
inputs_or_outputs[name] = nodes[name].meta["tensor_meta"]
201+
return inputs_or_outputs
202+
203+
178204
def _count_fx_targets(
179205
exported_program: torch.export.ExportedProgram,
180206
) -> defaultdict[str, int]:

torch/onnx/_internal/exporter/_onnx_program.py

Lines changed: 19 additions & 51 deletions
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,8 @@
1515
import warnings
1616
from typing import Any, Callable, TYPE_CHECKING
1717

18+
import numpy as np
19+
1820
import torch
1921
from torch.onnx._internal._lazy_import import onnx, onnxscript_apis, onnxscript_ir as ir
2022
from torch.onnx._internal.exporter import _dynamic_shapes, _ir_passes
@@ -117,31 +119,40 @@ def _create_value_mapping(graph: ir.Graph) -> dict[str, ir.Value]:
117119
return values
118120

119121

120-
def _to_ort_value(tensor: torch.Tensor) -> ort.OrtValue:
122+
def _to_ort_value(input: torch.Tensor | int | float | str | bool) -> ort.OrtValue:
121123
"""Convert a PyTorch tensor to an ONNX Runtime OrtValue."""
122124
import onnxruntime as ort
123125

124126
from torch.onnx._internal.exporter import _core
125127

126-
if tensor.dtype == torch.bfloat16 or tensor.dtype in _NP_UNSUPPORTED_DTYPES_8BIT:
128+
if isinstance(input, (int, float, str, bool)):
129+
# Convert scalar values to OrtValue
130+
dtype_mapping = {
131+
int: np.int64,
132+
float: np.float32,
133+
}
134+
dtype = dtype_mapping.get(type(input), None)
135+
return ort.OrtValue.ortvalue_from_numpy(np.array(input, dtype=dtype))
136+
137+
if input.dtype == torch.bfloat16 or input.dtype in _NP_UNSUPPORTED_DTYPES_8BIT:
127138
if hasattr(ort.OrtValue, "ortvalue_from_numpy_with_onnx_type"):
128139
# This requires ONNX Runtime 1.21 or newer
129-
if tensor.dtype == torch.bfloat16:
140+
if input.dtype == torch.bfloat16:
130141
uint_type = torch.uint16
131142
else:
132143
uint_type = torch.uint8
133-
onnx_type = _core.torch_dtype_to_onnx_dtype(tensor.dtype)
144+
onnx_type = _core.torch_dtype_to_onnx_dtype(input.dtype)
134145
# Make tensor contiguous to ensure view() works
135-
tensor = tensor.contiguous()
146+
input = input.contiguous()
136147
return ort.OrtValue.ortvalue_from_numpy_with_onnx_type(
137-
tensor.view(uint_type).numpy(force=True), onnx_element_type=onnx_type
148+
input.view(uint_type).numpy(force=True), onnx_element_type=onnx_type
138149
)
139150
raise RuntimeError(
140-
f"Failed to convert tensor of type '{tensor.dtype}' to OrtValue. "
151+
f"Failed to convert tensor of type '{input.dtype}' to OrtValue. "
141152
"Please ensure that ONNX Runtime is built with DLPack support or is the latest version"
142153
)
143154
# TODO(#151064): Use dlpack when ORT properly supports it
144-
return ort.OrtValue.ortvalue_from_numpy(tensor.numpy(force=True))
155+
return ort.OrtValue.ortvalue_from_numpy(input.numpy(force=True))
145156

146157

147158
def _from_ort_value(value: ort.OrtValue) -> torch.Tensor:
@@ -208,7 +219,6 @@ def __call__(self, *args, **kwargs) -> Sequence[torch.Tensor]:
208219

209220
assert self._inference_session is not None
210221

211-
# We don't expect non-tensor as inputs
212222
ort_input = {
213223
k.name: _to_ort_value(v)
214224
for k, v in zip(self.model.graph.inputs, flatten_args)
@@ -414,7 +424,6 @@ def _process_args(args, kwargs) -> tuple[torch.Tensor, ...]:
414424
"""Process input arguments for the ONNX model."""
415425
args = _flatten_inputs(args, kwargs)
416426
args = _remove_none_from_inputs(args)
417-
args = _remove_non_tensor(args)
418427
args = _convert_complex_to_real_representation(args)
419428
return args
420429

@@ -428,47 +437,6 @@ def _remove_none_from_inputs(model_args):
428437
return tuple(arg for arg in model_args if arg is not None)
429438

430439

431-
def _remove_non_tensor(model_args):
432-
"""Remove the non-tensor input arguments.
433-
434-
Dynamo does not support non-tensor input arguments (https://github.com/pytorch/pytorch/issues/99534).
435-
436-
Specifically, it does put the input into graph with an empty node, but consumed by no ones.
437-
The concrete value is embedded into the graph as a constant arg of a target node. Meta
438-
suggests in this case that one should rewrite the model code to make it tensor if the
439-
input value is supposed to change at runtime. We might need to further investigate
440-
the feasibility of that suggestion.
441-
442-
For example,
443-
444-
def func(x, b=1.0):
445-
y = x + b
446-
z = y.relu()
447-
return (y, z)
448-
449-
x = torch.randn(1, 1, 2, dtype=torch.float32)
450-
gm_fun, _ = dynamo.export(func, x, b=8.0, aten_graph=True, tracing_mode="real")
451-
452-
# class GraphModule(torch.nn.Module):
453-
# def forward(self, x, b):
454-
# arg0: f32[1, 1, 2], arg1, = fx_pytree.tree_flatten_spec(([x, b], {}), self._in_spec)
455-
# # File: path/to/pytorch/test_constant_input.py:5, code: y = x + b
456-
# add_tensor: f32[1, 1, 2] = torch.ops.aten.add.Tensor(arg0, 8.0); arg0 = None
457-
458-
# # File: path/to/pytorch/test_constant_input.py:6, code: z = y.relu()
459-
# relu_default: f32[1, 1, 2] = torch.ops.aten.relu.default(add_tensor)
460-
# return pytree.tree_unflatten([add_tensor, relu_default], self._out_spec)
461-
462-
Empty torch.fx.Node input leading to a mismatched number of input with PyTorch, as
463-
it's ignored in ONNX graph. Thus, we delete the useless input here.
464-
465-
"""
466-
467-
return tuple(
468-
arg for arg in model_args if not isinstance(arg, (int, float, bool, str))
469-
)
470-
471-
472440
def _convert_complex_to_real_representation(model_args):
473441
"""Convert complex dtype tensors to real representation tensors.
474442

0 commit comments

Comments
 (0)