77
88import parameterized
99
10+ from onnxscript import ir
1011from onnxscript ._internal import param_manipulation
11- from onnxscript .ir import _schemas
1212
1313TEST_INPUT = "TEST_INPUT"
1414
@@ -64,23 +64,23 @@ class TestSeparateInputAttributesFromArguments(unittest.TestCase):
6464 )
6565 def test_it_is_correct_on (self , _ , args , kwargs , expected_c ):
6666 # Create OpSignature with one input and two attributes
67- type_constraint = _schemas .TypeConstraintParam .any_tensor ("T" )
68- op_signature = _schemas .OpSignature (
67+ type_constraint = ir . schemas .TypeConstraintParam .any_tensor ("T" )
68+ op_signature = ir . schemas .OpSignature (
6969 domain = "" ,
7070 name = "TestOp" ,
7171 overload = "" ,
7272 params = [
73- _schemas .Parameter (
73+ ir . schemas .Parameter (
7474 name = "a" , type_constraint = type_constraint , required = True , variadic = False
7575 ),
76- _schemas .AttributeParameter (
77- name = "b" , type = _schemas . ir .AttributeType .INT , required = True , default = None
76+ ir . schemas .AttributeParameter (
77+ name = "b" , type = ir .AttributeType .INT , required = True , default = None
7878 ),
79- _schemas .AttributeParameter (
79+ ir . schemas .AttributeParameter (
8080 name = "c" ,
81- type = _schemas . ir .AttributeType .FLOAT ,
81+ type = ir .AttributeType .FLOAT ,
8282 required = False ,
83- default = _schemas . ir .Attr ("c" , _schemas . ir .AttributeType .FLOAT , 100.0 ),
83+ default = ir .Attr ("c" , ir .AttributeType .FLOAT , 100.0 ),
8484 ),
8585 ],
8686 outputs = [],
@@ -113,23 +113,23 @@ def test_it_is_correct_on(self, _, args, kwargs, expected_c):
113113 ]
114114 )
115115 def test_it_raises_on_extra_args (self , _ , args , kwargs ):
116- type_constraint = _schemas .TypeConstraintParam .any_tensor ("T" )
117- op_signature = _schemas .OpSignature (
116+ type_constraint = ir . schemas .TypeConstraintParam .any_tensor ("T" )
117+ op_signature = ir . schemas .OpSignature (
118118 domain = "" ,
119119 name = "TestOp" ,
120120 overload = "" ,
121121 params = [
122- _schemas .Parameter (
122+ ir . schemas .Parameter (
123123 name = "a" , type_constraint = type_constraint , required = True , variadic = False
124124 ),
125- _schemas .AttributeParameter (
126- name = "b" , type = _schemas . ir .AttributeType .INT , required = True , default = None
125+ ir . schemas .AttributeParameter (
126+ name = "b" , type = ir .AttributeType .INT , required = True , default = None
127127 ),
128- _schemas .AttributeParameter (
128+ ir . schemas .AttributeParameter (
129129 name = "c" ,
130- type = _schemas . ir .AttributeType .FLOAT ,
130+ type = ir .AttributeType .FLOAT ,
131131 required = False ,
132- default = _schemas . ir .Attr ("c" , _schemas . ir .AttributeType .FLOAT , 100.0 ),
132+ default = ir .Attr ("c" , ir .AttributeType .FLOAT , 100.0 ),
133133 ),
134134 ],
135135 outputs = [],
@@ -150,23 +150,23 @@ def test_it_raises_on_extra_kwargs_when_not_allow_extra_kwargs(
150150 self ,
151151 fill_defaults : bool ,
152152 ):
153- type_constraint = _schemas .TypeConstraintParam .any_tensor ("T" )
154- op_signature = _schemas .OpSignature (
153+ type_constraint = ir . schemas .TypeConstraintParam .any_tensor ("T" )
154+ op_signature = ir . schemas .OpSignature (
155155 domain = "" ,
156156 name = "TestOp" ,
157157 overload = "" ,
158158 params = [
159- _schemas .Parameter (
159+ ir . schemas .Parameter (
160160 name = "a" , type_constraint = type_constraint , required = True , variadic = False
161161 ),
162- _schemas .AttributeParameter (
163- name = "b" , type = _schemas . ir .AttributeType .INT , required = True , default = None
162+ ir . schemas .AttributeParameter (
163+ name = "b" , type = ir .AttributeType .INT , required = True , default = None
164164 ),
165- _schemas .AttributeParameter (
165+ ir . schemas .AttributeParameter (
166166 name = "c" ,
167- type = _schemas . ir .AttributeType .FLOAT ,
167+ type = ir .AttributeType .FLOAT ,
168168 required = False ,
169- default = _schemas . ir .Attr ("c" , _schemas . ir .AttributeType .FLOAT , 100.0 ),
169+ default = ir .Attr ("c" , ir .AttributeType .FLOAT , 100.0 ),
170170 ),
171171 ],
172172 outputs = [],
@@ -190,23 +190,23 @@ def test_it_raises_on_extra_kwargs_when_not_allow_extra_kwargs(
190190 def test_it_does_not_fill_default_when_fill_defaults_is_false (
191191 self , allow_extra_kwargs : bool
192192 ):
193- type_constraint = _schemas .TypeConstraintParam .any_tensor ("T" )
194- op_signature = _schemas .OpSignature (
193+ type_constraint = ir . schemas .TypeConstraintParam .any_tensor ("T" )
194+ op_signature = ir . schemas .OpSignature (
195195 domain = "" ,
196196 name = "TestOp" ,
197197 overload = "" ,
198198 params = [
199- _schemas .Parameter (
199+ ir . schemas .Parameter (
200200 name = "a" , type_constraint = type_constraint , required = True , variadic = False
201201 ),
202- _schemas .AttributeParameter (
203- name = "b" , type = _schemas . ir .AttributeType .INT , required = True , default = None
202+ ir . schemas .AttributeParameter (
203+ name = "b" , type = ir .AttributeType .INT , required = True , default = None
204204 ),
205- _schemas .AttributeParameter (
205+ ir . schemas .AttributeParameter (
206206 name = "c" ,
207- type = _schemas . ir .AttributeType .FLOAT ,
207+ type = ir .AttributeType .FLOAT ,
208208 required = False ,
209- default = _schemas . ir .Attr ("c" , _schemas . ir .AttributeType .FLOAT , 100.0 ),
209+ default = ir .Attr ("c" , ir .AttributeType .FLOAT , 100.0 ),
210210 ),
211211 ],
212212 outputs = [],
@@ -234,23 +234,23 @@ def test_it_does_not_fill_default_when_fill_defaults_is_false(
234234 def test_it_raises_on_insufficient_args (
235235 self , fill_defaults : bool , allow_extra_kwargs : bool
236236 ):
237- type_constraint = _schemas .TypeConstraintParam .any_tensor ("T" )
238- op_signature = _schemas .OpSignature (
237+ type_constraint = ir . schemas .TypeConstraintParam .any_tensor ("T" )
238+ op_signature = ir . schemas .OpSignature (
239239 domain = "" ,
240240 name = "TestOp" ,
241241 overload = "" ,
242242 params = [
243- _schemas .Parameter (
243+ ir . schemas .Parameter (
244244 name = "a" , type_constraint = type_constraint , required = True , variadic = False
245245 ),
246- _schemas .AttributeParameter (
247- name = "b" , type = _schemas . ir .AttributeType .INT , required = True , default = None
246+ ir . schemas .AttributeParameter (
247+ name = "b" , type = ir .AttributeType .INT , required = True , default = None
248248 ),
249- _schemas .AttributeParameter (
249+ ir . schemas .AttributeParameter (
250250 name = "c" ,
251- type = _schemas . ir .AttributeType .FLOAT ,
251+ type = ir .AttributeType .FLOAT ,
252252 required = False ,
253- default = _schemas . ir .Attr ("c" , _schemas . ir .AttributeType .FLOAT , 100.0 ),
253+ default = ir .Attr ("c" , ir .AttributeType .FLOAT , 100.0 ),
254254 ),
255255 ],
256256 outputs = [],
0 commit comments