Skip to content

Commit 8f06364

Browse files
authored
Migrate onnxscript converter to use onnx ir (#2706)
Migrate onnxscript converter to use onnx ir. --------- Signed-off-by: Ganesan Ramalingam <[email protected]>
1 parent d9802df commit 8f06364

File tree

9 files changed

+538
-778
lines changed

9 files changed

+538
-778
lines changed

onnxscript/_internal/autocast.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -189,30 +189,27 @@ def get_type_info(x):
189189
def static_cast_inputs(
190190
converter_: converter.Converter,
191191
op_schema: Optional[OpSchema],
192-
args: Sequence[Optional[converter.Variable]],
192+
args: Sequence[Optional[ir.Value]],
193193
) -> tuple[str, ...]:
194194
"""Used for autocast during script-translation.
195195
This is meant to transform expressions like "Add(X, 1)" to "Add(X, CastLike(1, X))"
196196
Polymorphic constants (like 0 and 1) are cast to the type of other operands as needed.
197197
"""
198198

199-
def get_type_info(x: Optional[converter.Variable]) -> Optional[converter.Variable]:
199+
def get_type_info(x: Optional[ir.Value]) -> Optional[ir.Value]:
200200
"""Returns x back if x can serve as the target-type for a cast (as the second
201201
argument of CastLike) and None otherwise. In the expression "Add(X, 1), 1 is
202202
castable, while X can serve as the target-type.
203203
"""
204-
return None if x is None or x.is_castable else x
204+
return None if x is None or converter_.is_castable(x.name) else x
205205

206-
def cast_like(
207-
x: Optional[converter.Variable], y: Optional[converter.Variable]
208-
) -> Optional[str]:
206+
def cast_like(x: Optional[ir.Value], y: Optional[ir.Value]) -> Optional[str]:
209207
if x is None:
210208
return None
211-
if x.is_castable and y is not None:
209+
if converter_.is_castable(x.name) and y is not None:
212210
# Polymorphic constant x is cast to the type of y:
213211
x_cast = converter_.generate_unique_name(f"{x.name}_cast")
214-
converter_.emit([x_cast], "CastLike", [x.name, y.name])
215-
return x_cast
216-
return x.name
212+
return converter_.emit1([x_cast], "CastLike", [x, y])
213+
return x
217214

218215
return cast_inputs(get_type_info, cast_like, op_schema, args)

onnxscript/_internal/utils.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import numpy as np
99
import onnx
10+
import onnx_ir as ir
1011

1112
from onnxscript import tensor
1213

@@ -87,6 +88,41 @@ def value_to_type_proto(val):
8788
raise ValueError(f"Value of type {type(val)} is invalid as an ONNX input/output.")
8889

8990

91+
def value_to_type(val):
92+
"""Return an ir.Value representation of a python-value."""
93+
if isinstance(val, (np.ndarray, tensor.Tensor)):
94+
elem_type = onnx.helper.np_dtype_to_tensor_dtype(val.dtype) # noqa: TID251
95+
shape = val.shape
96+
return (ir.TensorType(elem_type), shape)
97+
elif isinstance(val, int):
98+
elem_type = onnx.TensorProto.INT32
99+
shape = []
100+
return (ir.TensorType(elem_type), shape)
101+
elif isinstance(val, (float, np.float32)):
102+
elem_type = onnx.TensorProto.FLOAT
103+
shape = []
104+
return (ir.TensorType(elem_type), shape)
105+
elif isinstance(val, list):
106+
if len(val) > 0:
107+
type, shape = value_to_type(val[0])
108+
return ir.SequenceType(type), shape
109+
# Edge-case. Cannot determine a suitable ONNX type for an empty list.
110+
# Should be using a typed-value instead.
111+
# Treated as a sequence of tensors of float-type.
112+
return ir.SequenceType(ir.TensorType(onnx.TensorProto.FLOAT)), None
113+
if isinstance(val, numbers.Number):
114+
nparray = np.array(val)
115+
elem_type = onnx.helper.np_dtype_to_tensor_dtype(nparray.dtype) # noqa: TID251
116+
return ir.TensorType(elem_type), []
117+
raise ValueError(f"Value of type {type(val)} is invalid as an ONNX input/output.")
118+
119+
120+
def value_to_ir_value(name: str, val) -> ir.Value:
121+
"""Return an ir.Value representation of a python-value."""
122+
type, shape = value_to_type(val)
123+
return ir.Value(name=name, type=type, shape=shape)
124+
125+
90126
def values_to_value_infos(name_values):
91127
"""Create a list of ValueInfoProto from a list of (name, value) pairs,
92128
skipping any None values.
@@ -96,3 +132,10 @@ def values_to_value_infos(name_values):
96132
for (name, val) in name_values
97133
if val is not None
98134
]
135+
136+
137+
def values_to_ir_values(name_values):
138+
"""Create a list of ir.Value from a list of (name, value) pairs,
139+
skipping any None values.
140+
"""
141+
return [value_to_ir_value(name, val) for (name, val) in name_values if val is not None]

0 commit comments

Comments
 (0)