@@ -421,14 +421,14 @@ def _convert_shape(shape, name):
421421 if isinstance (shape , np .ndarray ):
422422 shape = tape .initializer (ir .Tensor (shape , name = name ))
423423 elif isinstance (shape , (list , tuple )):
424- shape = ir .val (name , ir .Shape ( shape ) , ir .TensorType ( ir . DataType . INT64 ))
424+ shape = ir .val (name , ir .DataType . INT64 , ir .Shape ( shape ))
425425 tape .graph_like .inputs .append (shape )
426426 else :
427427 raise TypeError (f"Unsupported type { type (shape )} for shape." )
428428 return shape
429429
430- x = ir .val ("X" , ir .Shape ( input_shape ) , ir .TensorType ( ir . DataType . FLOAT ))
431- y = ir .val ("Y" , type = ir .TensorType ( ir . DataType .FLOAT ) )
430+ x = ir .val ("X" , ir .DataType . FLOAT , ir .Shape ( input_shape ))
431+ y = ir .val ("Y" , ir .DataType .FLOAT )
432432 tape = ir .tape .Tape (ir .Graph ([x ], [y ], nodes = [], opset_imports = {"" : 20 }))
433433
434434 # Build the graph.
@@ -554,8 +554,8 @@ def test_unsupported_reshape_reshape(self, shape2, error_msg):
554554class Flatten2ReshapeTest (unittest .TestCase ):
555555 @staticmethod
556556 def create_model (input_shape , axis = 1 ):
557- x = ir .val ("X" , ir .Shape ( input_shape ) , ir .TensorType ( ir . DataType . FLOAT ))
558- y = ir .val ("Y" , type = ir .TensorType ( ir . DataType .FLOAT ) )
557+ x = ir .val ("X" , ir .DataType . FLOAT , ir .Shape ( input_shape ))
558+ y = ir .val ("Y" , ir .DataType .FLOAT )
559559 tape = ir .tape .Tape (ir .Graph ([x ], [y ], nodes = [], opset_imports = {"" : 20 }))
560560
561561 # Build the graph.
0 commit comments