@@ -31,8 +31,7 @@ def _convert_shapenode_to_int64(ctx, node, input_number):
    """cast int32 shape into int64 shape."""
    name = node.input[input_number]
-    cast_node = ctx.insert_new_node_on_input(node, "Cast", name)
-    cast_node.set_attr("to", onnx_pb.TensorProto.INT64)
+    cast_node = ctx.insert_new_node_on_input(node, "Cast", name, to=onnx_pb.TensorProto.INT64)
    ctx.set_dtype(cast_node.output[0], onnx_pb.TensorProto.INT64)
    ctx.copy_shape(name, cast_node.output[0])
@@ -46,14 +45,14 @@ def _wrap_concat_with_cast(ctx, node):
    output_name = node.output[0]
    # cast each inputs to float
    for i, inp in enumerate(node.inputs):
-        input_cast = ctx.insert_new_node_on_input(node, "Cast", node.input[i])
-        input_cast.set_attr("to", onnx_pb.TensorProto.FLOAT)
+        input_cast = ctx.insert_new_node_on_input(node, "Cast", node.input[i],
+                                                  to=onnx_pb.TensorProto.FLOAT)
        ctx.set_dtype(input_cast.output[0], onnx_pb.TensorProto.FLOAT)
    next_nodes = ctx.find_output_consumers(node.output[0])
    # cast output back to dtype unless the next op is a cast
    if next_nodes[0].type != "Cast":
-        output_cast = ctx.insert_new_node_on_output("Cast", output_name, name=node.child_name())
-        output_cast.set_attr("to", dtype)
+        output_cast = ctx.insert_new_node_on_output("Cast", output_name, name=node.child_name(),
+                                                    to=dtype)
        ctx.set_dtype(output_cast.output[0], dtype)
        ctx.copy_shape(output_name, output_cast.output[0])
@@ -157,15 +156,14 @@ def version_5(cls, ctx, node, **kwargs):
            return

        # onnx < opset 8 does not know reshape for other types than float*, wrap the reshape in casts
-        input_cast = ctx.insert_new_node_on_input(node, "Cast", node.input[0])
-        input_cast.set_attr("to", onnx_pb.TensorProto.FLOAT)
+        input_cast = ctx.insert_new_node_on_input(node, "Cast", node.input[0], to=onnx_pb.TensorProto.FLOAT)
        ctx.copy_shape(node.output[0], input_cast.output[0])

        # if the next node is already a cast we don't need to insert another one
        next_nodes = ctx.find_output_consumers(node.output[0])
        if len(next_nodes) != 1 or next_nodes[0].type != "Cast":
-            output_cast = ctx.insert_new_node_on_output("Cast", node.output[0], name=node.child_name())
-            output_cast.set_attr("to", dtype)
+            output_cast = ctx.insert_new_node_on_output("Cast", node.output[0], name=node.child_name(),
+                                                        to=dtype)
            ctx.set_dtype(output_cast.output[0], dtype)
            ctx.copy_shape(node.output[0], output_cast.output[0])
@@ -739,16 +737,17 @@ def version_1(cls, ctx, node, **kwargs):
        if node.inputs[0].type == "Cast" and len(ctx.find_output_consumers(node.inputs[0].output[0])) == 1:
            # override the previous cast
            cast_node = node.inputs[0]
+            cast_node.set_attr("to", onnx_pb.TensorProto.FLOAT)
        else:
-            cast_node = ctx.insert_new_node_on_input(node, "Cast", node.input[0])
+            cast_node = ctx.insert_new_node_on_input(node, "Cast", node.input[0],
+                                                     to=onnx_pb.TensorProto.FLOAT)
            nodes.insert(0, cast_node)
-        cast_node.set_attr("to", onnx_pb.TensorProto.FLOAT)
        ctx.set_dtype(cast_node.output[0], onnx_pb.TensorProto.FLOAT)
        ctx.copy_shape(node.input[0], cast_node.output[0])
        # undo the cast afer slice
        name = utils.make_name(node.name)
-        cast_node = ctx.insert_new_node_on_output("Cast", nodes[-1].output[0], name)
-        cast_node.set_attr("to", input_dtype)
+        cast_node = ctx.insert_new_node_on_output("Cast", nodes[-1].output[0], name,
+                                                  to=input_dtype)
        ctx.set_dtype(cast_node.output[0], input_dtype)
        ctx.copy_shape(node.output[0], cast_node.output[0])
        nodes.append(cast_node)
@@ -1180,8 +1179,7 @@ def version_1(cls, ctx, node, **kwargs):
        if dtype == onnx_pb.TensorProto.INT64:
            return
        op_name = utils.make_name(node.name)
-        output_cast = ctx.insert_new_node_on_output("Cast", node.output[0], name=op_name)
-        output_cast.set_attr("to", dtype)
+        output_cast = ctx.insert_new_node_on_output("Cast", node.output[0], name=op_name, to=dtype)
        ctx.set_dtype(output_cast.output[0], dtype)
        ctx.copy_shape(node.output[0], output_cast.output[0])
@@ -1555,8 +1553,7 @@ def version_8(cls, ctx, node, **kwargs):

        seq_len_dtype = ctx.get_dtype(node.input[1])
        if seq_len_dtype != onnx_pb.TensorProto.INT64:
-            cast_node = ctx.insert_new_node_on_input(node, "Cast", node.input[1])
-            cast_node.set_attr("to", onnx_pb.TensorProto.INT64)
+            cast_node = ctx.insert_new_node_on_input(node, "Cast", node.input[1], to=onnx_pb.TensorProto.INT64)
            ctx.set_dtype(cast_node.output[0], onnx_pb.TensorProto.INT64)
            ctx.copy_shape(node.input[1], cast_node.output[0])
@@ -1762,8 +1759,8 @@ def version_11(cls, ctx, node, **kwargs):
        # cast to int64 if needed
        if dtypes[1] != onnx_pb.TensorProto.UINT64:
            cast_node = ctx.insert_new_node_on_output("Cast", node.output[1],
-                                                      name=utils.make_name(node.name) + "_cast")
-            cast_node.set_attr("to", dtypes[1])
+                                                      name=utils.make_name(node.name) + "_cast",
+                                                      to=dtypes[1])
            ctx.set_dtype(cast_node.output[0], dtypes[1])
            ctx.copy_shape(node.output[1], cast_node.output[0])
        # FIXME: the indices in onnx are not the same as in tensorflow.
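
The pattern repeated throughout this commit is the same: instead of inserting a Cast node and then attaching its "to" attribute with a separate set_attr call, the attribute is passed as a keyword argument when the node is created. A minimal standalone sketch of the same idea using the stock onnx.helper API (not tf2onnx's graph wrapper; the tensor and node names below are illustrative only):

    # Sketch only: onnx.helper.make_node turns extra keyword arguments into
    # ONNX node attributes, so Cast's "to" attribute is set at construction
    # time rather than with a separate set_attr()-style call afterwards.
    from onnx import TensorProto, helper

    cast_node = helper.make_node(
        "Cast",
        inputs=["int32_shape"],      # illustrative tensor name
        outputs=["int64_shape"],     # illustrative tensor name
        name="shape_cast_to_int64",  # illustrative node name
        to=TensorProto.INT64,
    )
    print(cast_node)                 # "to" is already present on the node

Supplying the attribute up front keeps node creation to a single call and avoids a window in which the Cast node exists without a valid "to" attribute.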