Skip to content

Commit a3166e2

Browse files
committed
Fix tests
Signed-off-by: Justin Chu <[email protected]>
1 parent 7ded7d1 commit a3166e2

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

onnxscript/rewriter/rules/common/_basic_rules_test.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -421,14 +421,14 @@ def _convert_shape(shape, name):
421421
if isinstance(shape, np.ndarray):
422422
shape = tape.initializer(ir.Tensor(shape, name=name))
423423
elif isinstance(shape, (list, tuple)):
424-
shape = ir.val(name, ir.Shape(shape), ir.TensorType(ir.DataType.INT64))
424+
shape = ir.val(name, ir.DataType.INT64, ir.Shape(shape))
425425
tape.graph_like.inputs.append(shape)
426426
else:
427427
raise TypeError(f"Unsupported type {type(shape)} for shape.")
428428
return shape
429429

430-
x = ir.val("X", ir.Shape(input_shape), ir.TensorType(ir.DataType.FLOAT))
431-
y = ir.val("Y", type=ir.TensorType(ir.DataType.FLOAT))
430+
x = ir.val("X", ir.DataType.FLOAT, ir.Shape(input_shape))
431+
y = ir.val("Y", ir.DataType.FLOAT)
432432
tape = ir.tape.Tape(ir.Graph([x], [y], nodes=[], opset_imports={"": 20}))
433433

434434
# Build the graph.
@@ -554,8 +554,8 @@ def test_unsupported_reshape_reshape(self, shape2, error_msg):
554554
class Flatten2ReshapeTest(unittest.TestCase):
555555
@staticmethod
556556
def create_model(input_shape, axis=1):
557-
x = ir.val("X", ir.Shape(input_shape), ir.TensorType(ir.DataType.FLOAT))
558-
y = ir.val("Y", type=ir.TensorType(ir.DataType.FLOAT))
557+
x = ir.val("X", ir.DataType.FLOAT, ir.Shape(input_shape))
558+
y = ir.val("Y", ir.DataType.FLOAT)
559559
tape = ir.tape.Tape(ir.Graph([x], [y], nodes=[], opset_imports={"": 20}))
560560

561561
# Build the graph.

0 commit comments

Comments
 (0)