Skip to content

Commit a5b0d86

Browse files
authored
Merge branch 'main' into titaiwang/raise_refattr_function_vc
2 parents 5ce524f + 2b2618e commit a5b0d86

File tree

19 files changed

+449
-550
lines changed

19 files changed

+449
-550
lines changed

VERSION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
0.6.2
1+
0.6.3

noxfile.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
"packaging",
4242
"protobuf",
4343
)
44-
ONNX_IR = "onnx_ir==0.1.15"
44+
ONNX_IR = "onnx_ir==0.1.16"
4545
ONNX_IR_MAIN = "git+https://github.com/onnx/ir-py.git@main#egg=onnx_ir"
4646

4747

onnxscript/_internal/autocast.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
import onnx
1010

1111
from onnxscript import ir, tensor
12-
from onnxscript.ir import _schemas
1312

1413
if TYPE_CHECKING:
1514
from onnxscript._internal import converter
@@ -112,7 +111,7 @@ def cast_pyvalue_to_os_tensor(pyvalue, dtype=None):
112111
def cast_inputs(
113112
get_type_info: Callable[[Any], Any],
114113
cast: Callable[[Any, Any], Any],
115-
op_signature: _schemas.OpSignature | None,
114+
op_signature: ir.schemas.OpSignature | None,
116115
args,
117116
) -> tuple[Any, ...]:
118117
"""Uses schema specification to support a limited form of auto-casting.
@@ -164,7 +163,7 @@ def cast_inputs(
164163
return tuple(cast_args)
165164

166165

167-
def dynamic_cast_inputs(op_signature: _schemas.OpSignature, args):
166+
def dynamic_cast_inputs(op_signature: ir.schemas.OpSignature, args):
168167
"""Used for autocast during eager-mode execution."""
169168

170169
def get_type_info(x):
@@ -175,7 +174,7 @@ def get_type_info(x):
175174

176175
def static_cast_inputs(
177176
converter_: converter.Converter,
178-
op_signature: Optional[_schemas.OpSignature],
177+
op_signature: Optional[ir.schemas.OpSignature],
179178
args: Sequence[Optional[ir.Value]],
180179
) -> tuple[str, ...]:
181180
"""Used for autocast during script-translation.

onnxscript/_internal/converter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -507,7 +507,7 @@ def _translate_attr(
507507
self,
508508
attr_name: str,
509509
expr: ast.AST,
510-
attr_meta: _schemas.AttributeParameter | None = None,
510+
attr_meta: ir.schemas.AttributeParameter | None = None,
511511
) -> ir.Attr | None:
512512
"""Translate an attribute-value specification of the form `attr_name=<expr>`
513513
in a call to an op. expr is an AST. The following cases are supported:

onnxscript/_internal/evaluator.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626

2727
from onnxscript import onnx_opset, tensor
2828
from onnxscript._internal import autocast, param_manipulation, utils, values
29-
from onnxscript.ir import _schemas
3029

3130
UserModeValue: TypeAlias = Union[Optional[np.ndarray], Sequence["UserModeValue"]]
3231

@@ -181,7 +180,7 @@ def __init__(self, ignore_unknown_function_kwargs: bool = False):
181180
self._ignore_unknown_function_kwargs = ignore_unknown_function_kwargs
182181

183182
def _adapt_inputs(
184-
self, op_signature: _schemas.OpSignature, inputs: Sequence[ExtendedModeValue]
183+
self, op_signature: ir.schemas.OpSignature, inputs: Sequence[ExtendedModeValue]
185184
):
186185
"""Transform inputs to the expected format for the evaluator.
187186
@@ -225,7 +224,7 @@ def _adapt_outputs(self, outputs: Sequence[EagerModeValue]):
225224
"""
226225
return outputs[0] if len(outputs) == 1 else outputs
227226

228-
def use_graph_attribute(self, op_signature: _schemas.OpSignature) -> bool:
227+
def use_graph_attribute(self, op_signature: ir.schemas.OpSignature) -> bool:
229228
del op_signature # unused
230229
return True
231230

@@ -292,15 +291,15 @@ def eval_function(
292291
adapted_kwargs: dict[str, ExtendedModeValue] = {}
293292
has_array = False
294293
for arg, param in tagged_args:
295-
if isinstance(param, _schemas.Parameter):
294+
if isinstance(param, ir.schemas.Parameter):
296295
adapted_arg, has_array_ = _adapt_to_eager_mode(arg)
297296
has_array = has_array or has_array_
298297
adapted_args.append(adapted_arg)
299298
else:
300299
adapted_args.append(arg)
301300

302301
for key, (arg, param) in tagged_kwargs.items():
303-
if isinstance(param, _schemas.Parameter):
302+
if isinstance(param, ir.schemas.Parameter):
304303
adapted_arg, has_array_ = _adapt_to_eager_mode(arg)
305304
has_array = has_array or has_array_
306305
adapted_kwargs[key] = adapted_arg
@@ -511,7 +510,7 @@ def _call_ort(
511510

512511

513512
def _op_identifier(
514-
op_schema_or_signature: onnx.defs.OpSchema | _schemas.OpSignature,
513+
op_schema_or_signature: onnx.defs.OpSchema | ir.schemas.OpSignature,
515514
) -> tuple[str, str, int]:
516515
return (
517516
op_schema_or_signature.name,
@@ -564,7 +563,7 @@ def __init__(self) -> None:
564563
super().__init__()
565564
self._python_ops: dict[tuple[str, str, int], Any] = {}
566565

567-
def use_graph_attribute(self, op_signature: _schemas.OpSignature) -> bool:
566+
def use_graph_attribute(self, op_signature: ir.schemas.OpSignature) -> bool:
568567
return _op_identifier(op_signature) not in self._python_ops
569568

570569
def _eval(self, schema, inputs, attributes, closure):

onnxscript/_internal/param_manipulation.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,11 @@
77
import collections
88
from typing import Any, OrderedDict
99

10-
from onnxscript.ir import _schemas
10+
from onnxscript import ir
1111

1212

1313
def separate_input_attributes_from_arguments(
14-
op_signature: _schemas.OpSignature,
14+
op_signature: ir.schemas.OpSignature,
1515
args,
1616
kwargs,
1717
fill_defaults: bool = True,
@@ -48,8 +48,8 @@ def separate_input_attributes_from_arguments(
4848
onnx_attributes = collections.OrderedDict()
4949

5050
for i, param in enumerate(op_signature.params):
51-
is_input = isinstance(param, _schemas.Parameter)
52-
is_variadic = isinstance(param, _schemas.Parameter) and param.variadic
51+
is_input = param.is_param()
52+
is_variadic = is_input and param.variadic
5353

5454
if is_variadic:
5555
# Exhaust all remaining args
@@ -66,7 +66,7 @@ def separate_input_attributes_from_arguments(
6666
onnx_inputs.append(kwargs[param.name])
6767
else:
6868
onnx_attributes[param.name] = kwargs[param.name]
69-
elif isinstance(param, _schemas.AttributeParameter) and param.has_default():
69+
elif isinstance(param, ir.schemas.AttributeParameter) and param.has_default():
7070
# User did not provide the attribute
7171
if fill_defaults:
7272
# Extract the value from the Attr object
@@ -78,14 +78,14 @@ def separate_input_attributes_from_arguments(
7878

7979

8080
def tag_arguments_with_signature(
81-
op_signature: _schemas.OpSignature,
81+
op_signature: ir.schemas.OpSignature,
8282
args,
8383
kwargs,
8484
fill_defaults: bool = True,
8585
allow_extra_kwargs: bool = False,
8686
) -> tuple[
87-
list[tuple[Any, _schemas.Parameter | _schemas.AttributeParameter]],
88-
dict[str, tuple[Any, _schemas.Parameter | _schemas.AttributeParameter]],
87+
list[tuple[Any, ir.schemas.Parameter | ir.schemas.AttributeParameter]],
88+
dict[str, tuple[Any, ir.schemas.Parameter | ir.schemas.AttributeParameter]],
8989
]:
9090
"""Tag Python args and kwargs with matching ONNX Parameter/AttributeParameter.
9191
@@ -115,11 +115,13 @@ def tag_arguments_with_signature(
115115
if extra_kwargs and not allow_extra_kwargs:
116116
raise TypeError(f"Unexpected keyword arguments '{extra_kwargs}'")
117117

118-
tagged_args: list[tuple[Any, _schemas.Parameter | _schemas.AttributeParameter]] = []
119-
tagged_kwargs: dict[str, tuple[Any, _schemas.Parameter | _schemas.AttributeParameter]] = {}
118+
tagged_args: list[tuple[Any, ir.schemas.Parameter | ir.schemas.AttributeParameter]] = []
119+
tagged_kwargs: dict[
120+
str, tuple[Any, ir.schemas.Parameter | ir.schemas.AttributeParameter]
121+
] = {}
120122

121123
for i, param in enumerate(op_signature.params):
122-
is_variadic = isinstance(param, _schemas.Parameter) and param.variadic
124+
is_variadic = param.is_param() and param.variadic
123125

124126
if is_variadic:
125127
# Exhaust all remaining args
@@ -135,7 +137,7 @@ def tag_arguments_with_signature(
135137
if fill_defaults:
136138
default_value = param.default
137139
# Extract value from Attr object if it's an AttributeParameter
138-
if isinstance(param, _schemas.AttributeParameter):
140+
if param.is_attribute():
139141
default_value = param.default.value
140142
tagged_kwargs[param.name] = (default_value, param)
141143
elif param.required:
@@ -145,14 +147,14 @@ def tag_arguments_with_signature(
145147

146148

147149
def turn_to_kwargs_to_avoid_ordering(
148-
op_signature: _schemas.OpSignature,
150+
op_signature: ir.schemas.OpSignature,
149151
inputs: list[Any],
150152
attributes: dict[str, Any],
151153
) -> dict[str, Any]:
152154
"""Return the inputs and attributes to the order of the function signature."""
153155
for idx, param in enumerate(op_signature.params):
154156
if param.name not in attributes:
155-
is_variadic = isinstance(param, _schemas.Parameter) and param.variadic
157+
is_variadic = isinstance(param, ir.schemas.Parameter) and param.variadic
156158
if is_variadic:
157159
attributes[param.name] = inputs[idx:]
158160
elif inputs:

onnxscript/_internal/param_manipulation_test.py

Lines changed: 41 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77

88
import parameterized
99

10+
from onnxscript import ir
1011
from onnxscript._internal import param_manipulation
11-
from onnxscript.ir import _schemas
1212

1313
TEST_INPUT = "TEST_INPUT"
1414

@@ -64,23 +64,23 @@ class TestSeparateInputAttributesFromArguments(unittest.TestCase):
6464
)
6565
def test_it_is_correct_on(self, _, args, kwargs, expected_c):
6666
# Create OpSignature with one input and two attributes
67-
type_constraint = _schemas.TypeConstraintParam.any_tensor("T")
68-
op_signature = _schemas.OpSignature(
67+
type_constraint = ir.schemas.TypeConstraintParam.any_tensor("T")
68+
op_signature = ir.schemas.OpSignature(
6969
domain="",
7070
name="TestOp",
7171
overload="",
7272
params=[
73-
_schemas.Parameter(
73+
ir.schemas.Parameter(
7474
name="a", type_constraint=type_constraint, required=True, variadic=False
7575
),
76-
_schemas.AttributeParameter(
77-
name="b", type=_schemas.ir.AttributeType.INT, required=True, default=None
76+
ir.schemas.AttributeParameter(
77+
name="b", type=ir.AttributeType.INT, required=True, default=None
7878
),
79-
_schemas.AttributeParameter(
79+
ir.schemas.AttributeParameter(
8080
name="c",
81-
type=_schemas.ir.AttributeType.FLOAT,
81+
type=ir.AttributeType.FLOAT,
8282
required=False,
83-
default=_schemas.ir.Attr("c", _schemas.ir.AttributeType.FLOAT, 100.0),
83+
default=ir.Attr("c", ir.AttributeType.FLOAT, 100.0),
8484
),
8585
],
8686
outputs=[],
@@ -113,23 +113,23 @@ def test_it_is_correct_on(self, _, args, kwargs, expected_c):
113113
]
114114
)
115115
def test_it_raises_on_extra_args(self, _, args, kwargs):
116-
type_constraint = _schemas.TypeConstraintParam.any_tensor("T")
117-
op_signature = _schemas.OpSignature(
116+
type_constraint = ir.schemas.TypeConstraintParam.any_tensor("T")
117+
op_signature = ir.schemas.OpSignature(
118118
domain="",
119119
name="TestOp",
120120
overload="",
121121
params=[
122-
_schemas.Parameter(
122+
ir.schemas.Parameter(
123123
name="a", type_constraint=type_constraint, required=True, variadic=False
124124
),
125-
_schemas.AttributeParameter(
126-
name="b", type=_schemas.ir.AttributeType.INT, required=True, default=None
125+
ir.schemas.AttributeParameter(
126+
name="b", type=ir.AttributeType.INT, required=True, default=None
127127
),
128-
_schemas.AttributeParameter(
128+
ir.schemas.AttributeParameter(
129129
name="c",
130-
type=_schemas.ir.AttributeType.FLOAT,
130+
type=ir.AttributeType.FLOAT,
131131
required=False,
132-
default=_schemas.ir.Attr("c", _schemas.ir.AttributeType.FLOAT, 100.0),
132+
default=ir.Attr("c", ir.AttributeType.FLOAT, 100.0),
133133
),
134134
],
135135
outputs=[],
@@ -150,23 +150,23 @@ def test_it_raises_on_extra_kwargs_when_not_allow_extra_kwargs(
150150
self,
151151
fill_defaults: bool,
152152
):
153-
type_constraint = _schemas.TypeConstraintParam.any_tensor("T")
154-
op_signature = _schemas.OpSignature(
153+
type_constraint = ir.schemas.TypeConstraintParam.any_tensor("T")
154+
op_signature = ir.schemas.OpSignature(
155155
domain="",
156156
name="TestOp",
157157
overload="",
158158
params=[
159-
_schemas.Parameter(
159+
ir.schemas.Parameter(
160160
name="a", type_constraint=type_constraint, required=True, variadic=False
161161
),
162-
_schemas.AttributeParameter(
163-
name="b", type=_schemas.ir.AttributeType.INT, required=True, default=None
162+
ir.schemas.AttributeParameter(
163+
name="b", type=ir.AttributeType.INT, required=True, default=None
164164
),
165-
_schemas.AttributeParameter(
165+
ir.schemas.AttributeParameter(
166166
name="c",
167-
type=_schemas.ir.AttributeType.FLOAT,
167+
type=ir.AttributeType.FLOAT,
168168
required=False,
169-
default=_schemas.ir.Attr("c", _schemas.ir.AttributeType.FLOAT, 100.0),
169+
default=ir.Attr("c", ir.AttributeType.FLOAT, 100.0),
170170
),
171171
],
172172
outputs=[],
@@ -190,23 +190,23 @@ def test_it_raises_on_extra_kwargs_when_not_allow_extra_kwargs(
190190
def test_it_does_not_fill_default_when_fill_defaults_is_false(
191191
self, allow_extra_kwargs: bool
192192
):
193-
type_constraint = _schemas.TypeConstraintParam.any_tensor("T")
194-
op_signature = _schemas.OpSignature(
193+
type_constraint = ir.schemas.TypeConstraintParam.any_tensor("T")
194+
op_signature = ir.schemas.OpSignature(
195195
domain="",
196196
name="TestOp",
197197
overload="",
198198
params=[
199-
_schemas.Parameter(
199+
ir.schemas.Parameter(
200200
name="a", type_constraint=type_constraint, required=True, variadic=False
201201
),
202-
_schemas.AttributeParameter(
203-
name="b", type=_schemas.ir.AttributeType.INT, required=True, default=None
202+
ir.schemas.AttributeParameter(
203+
name="b", type=ir.AttributeType.INT, required=True, default=None
204204
),
205-
_schemas.AttributeParameter(
205+
ir.schemas.AttributeParameter(
206206
name="c",
207-
type=_schemas.ir.AttributeType.FLOAT,
207+
type=ir.AttributeType.FLOAT,
208208
required=False,
209-
default=_schemas.ir.Attr("c", _schemas.ir.AttributeType.FLOAT, 100.0),
209+
default=ir.Attr("c", ir.AttributeType.FLOAT, 100.0),
210210
),
211211
],
212212
outputs=[],
@@ -234,23 +234,23 @@ def test_it_does_not_fill_default_when_fill_defaults_is_false(
234234
def test_it_raises_on_insufficient_args(
235235
self, fill_defaults: bool, allow_extra_kwargs: bool
236236
):
237-
type_constraint = _schemas.TypeConstraintParam.any_tensor("T")
238-
op_signature = _schemas.OpSignature(
237+
type_constraint = ir.schemas.TypeConstraintParam.any_tensor("T")
238+
op_signature = ir.schemas.OpSignature(
239239
domain="",
240240
name="TestOp",
241241
overload="",
242242
params=[
243-
_schemas.Parameter(
243+
ir.schemas.Parameter(
244244
name="a", type_constraint=type_constraint, required=True, variadic=False
245245
),
246-
_schemas.AttributeParameter(
247-
name="b", type=_schemas.ir.AttributeType.INT, required=True, default=None
246+
ir.schemas.AttributeParameter(
247+
name="b", type=ir.AttributeType.INT, required=True, default=None
248248
),
249-
_schemas.AttributeParameter(
249+
ir.schemas.AttributeParameter(
250250
name="c",
251-
type=_schemas.ir.AttributeType.FLOAT,
251+
type=ir.AttributeType.FLOAT,
252252
required=False,
253-
default=_schemas.ir.Attr("c", _schemas.ir.AttributeType.FLOAT, 100.0),
253+
default=ir.Attr("c", ir.AttributeType.FLOAT, 100.0),
254254
),
255255
],
256256
outputs=[],

0 commit comments

Comments
 (0)