
Commit 8f55cdb

Faster insertion of operator cast
1 parent 0110037 commit 8f55cdb

File tree

4 files changed: +30 -35 lines changed

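All four files apply the same refactor: instead of inserting a bare Cast node and then setting its "to" attribute in a second call, the target dtype is passed as a keyword argument so the node is created with its attributes in one step. A condensed before/after sketch of the pattern (ctx, node, and the dtype constants stand in for the surrounding tf2onnx handler code):

    # before: insert the node, then mutate it with a separate set_attr call
    cast_node = ctx.insert_new_node_on_input(node, "Cast", node.input[0])
    cast_node.set_attr("to", onnx_pb.TensorProto.FLOAT)

    # after: pass the "to" attribute at insertion time
    cast_node = ctx.insert_new_node_on_input(node, "Cast", node.input[0],
                                             to=onnx_pb.TensorProto.FLOAT)

The same keyword-argument form is used for insert_new_node_on_output in the output-side ("cast back") cases below.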

tf2onnx/onnx_opset/math.py

Lines changed: 3 additions & 3 deletions

@@ -545,9 +545,9 @@ def version_11(cls, ctx, node, **kwargs):
                              shapes=shapes, dtypes=dtypes, domain=constants.ONNX_DOMAIN, attr={'direction': direction})

         if node.maybe_cast_input([supported, supported], type_map):
-            cast_back_node = ctx.insert_new_node_on_output("Cast", node.output[0],
-                                                           name=utils.make_name(node.name) + "_castback")
-            cast_back_node.set_attr("to", dtypes[0])
+            cast_back_node = ctx.insert_new_node_on_output(
+                "Cast", node.output[0], name=utils.make_name(node.name) + "_castback",
+                to=dtypes[0])
             ctx.set_dtype(cast_back_node.output[0], dtypes[0])
             ctx.copy_shape(node.name, cast_back_node.output[0])

tf2onnx/onnx_opset/nn.py

Lines changed: 6 additions & 8 deletions

@@ -637,14 +637,13 @@ def version_1(cls, ctx, node, **kwargs):
         origin_dtype = ctx.get_dtype(node.output[0])
         if origin_dtype not in [onnx_pb.TensorProto.FLOAT16, onnx_pb.TensorProto.FLOAT,
                                 onnx_pb.TensorProto.DOUBLE]:
-            cast_node = ctx.insert_new_node_on_input(node, "Cast", node.input[0])
-            cast_node.set_attr("to", onnx_pb.TensorProto.FLOAT)
+            cast_node = ctx.insert_new_node_on_input(node, "Cast", node.input[0], to=onnx_pb.TensorProto.FLOAT)
             ctx.set_dtype(cast_node.output[0], onnx_pb.TensorProto.FLOAT)
             ctx.copy_shape(node.name, cast_node.output[0])

             cast_back_node = ctx.insert_new_node_on_output("Cast", node.output[0],
-                                                           name=utils.make_name(node.name) + "_castback")
-            cast_back_node.set_attr("to", origin_dtype)
+                                                           name=utils.make_name(node.name) + "_castback",
+                                                           to=origin_dtype)
             ctx.set_dtype(cast_back_node.output[0], origin_dtype)
             ctx.copy_shape(node.name, cast_back_node.output[0])

@@ -667,14 +666,13 @@ def version_11(cls, ctx, node, **kwargs):
         origin_dtype = ctx.get_dtype(node.output[0])
         if origin_dtype not in [TensorProto.FLOAT, TensorProto.DOUBLE,
                                 TensorProto.INT32, TensorProto.INT64]:
-            cast_node = ctx.insert_new_node_on_input(node, "Cast", node.input[0])
-            cast_node.set_attr("to", TensorProto.FLOAT)
+            cast_node = ctx.insert_new_node_on_input(node, "Cast", node.input[0], to=TensorProto.FLOAT)
             ctx.set_dtype(cast_node.output[0], TensorProto.FLOAT)
             ctx.copy_shape(node.name, cast_node.output[0])

             cast_back_node = ctx.insert_new_node_on_output("Cast", node.output[0],
-                                                           name=utils.make_name(node.name) + "_castback")
-            cast_back_node.set_attr("to", origin_dtype)
+                                                           name=utils.make_name(node.name) + "_castback",
+                                                           to=origin_dtype)
             ctx.set_dtype(cast_back_node.output[0], origin_dtype)
             ctx.copy_shape(node.name, cast_back_node.output[0])

tf2onnx/onnx_opset/tensor.py

Lines changed: 17 additions & 20 deletions

@@ -31,8 +31,7 @@ def _convert_shapenode_to_int64(ctx, node, input_number):
     """cast int32 shape into int64 shape."""
     name = node.input[input_number]

-    cast_node = ctx.insert_new_node_on_input(node, "Cast", name)
-    cast_node.set_attr("to", onnx_pb.TensorProto.INT64)
+    cast_node = ctx.insert_new_node_on_input(node, "Cast", name, to=onnx_pb.TensorProto.INT64)
     ctx.set_dtype(cast_node.output[0], onnx_pb.TensorProto.INT64)
     ctx.copy_shape(name, cast_node.output[0])

@@ -46,14 +45,14 @@ def _wrap_concat_with_cast(ctx, node):
     output_name = node.output[0]
     # cast each inputs to float
     for i, inp in enumerate(node.inputs):
-        input_cast = ctx.insert_new_node_on_input(node, "Cast", node.input[i])
-        input_cast.set_attr("to", onnx_pb.TensorProto.FLOAT)
+        input_cast = ctx.insert_new_node_on_input(node, "Cast", node.input[i],
+                                                  to=onnx_pb.TensorProto.FLOAT)
         ctx.set_dtype(input_cast.output[0], onnx_pb.TensorProto.FLOAT)
     next_nodes = ctx.find_output_consumers(node.output[0])
     # cast output back to dtype unless the next op is a cast
     if next_nodes[0].type != "Cast":
-        output_cast = ctx.insert_new_node_on_output("Cast", output_name, name=node.child_name())
-        output_cast.set_attr("to", dtype)
+        output_cast = ctx.insert_new_node_on_output("Cast", output_name, name=node.child_name(),
+                                                    to=dtype)
         ctx.set_dtype(output_cast.output[0], dtype)
         ctx.copy_shape(output_name, output_cast.output[0])

@@ -157,15 +156,14 @@ def version_5(cls, ctx, node, **kwargs):
             return

         # onnx < opset 8 does not know reshape for other types than float*, wrap the reshape in casts
-        input_cast = ctx.insert_new_node_on_input(node, "Cast", node.input[0])
-        input_cast.set_attr("to", onnx_pb.TensorProto.FLOAT)
+        input_cast = ctx.insert_new_node_on_input(node, "Cast", node.input[0], to=onnx_pb.TensorProto.FLOAT)
         ctx.copy_shape(node.output[0], input_cast.output[0])

         # if the next node is already a cast we don't need to insert another one
         next_nodes = ctx.find_output_consumers(node.output[0])
         if len(next_nodes) != 1 or next_nodes[0].type != "Cast":
-            output_cast = ctx.insert_new_node_on_output("Cast", node.output[0], name=node.child_name())
-            output_cast.set_attr("to", dtype)
+            output_cast = ctx.insert_new_node_on_output("Cast", node.output[0], name=node.child_name(),
+                                                        to=dtype)
             ctx.set_dtype(output_cast.output[0], dtype)
             ctx.copy_shape(node.output[0], output_cast.output[0])

@@ -739,16 +737,17 @@ def version_1(cls, ctx, node, **kwargs):
         if node.inputs[0].type == "Cast" and len(ctx.find_output_consumers(node.inputs[0].output[0])) == 1:
             # override the previous cast
             cast_node = node.inputs[0]
+            cast_node.set_attr("to", onnx_pb.TensorProto.FLOAT)
         else:
-            cast_node = ctx.insert_new_node_on_input(node, "Cast", node.input[0])
+            cast_node = ctx.insert_new_node_on_input(node, "Cast", node.input[0],
+                                                     to=onnx_pb.TensorProto.FLOAT)
             nodes.insert(0, cast_node)
-        cast_node.set_attr("to", onnx_pb.TensorProto.FLOAT)
         ctx.set_dtype(cast_node.output[0], onnx_pb.TensorProto.FLOAT)
         ctx.copy_shape(node.input[0], cast_node.output[0])
         # undo the cast afer slice
         name = utils.make_name(node.name)
-        cast_node = ctx.insert_new_node_on_output("Cast", nodes[-1].output[0], name)
-        cast_node.set_attr("to", input_dtype)
+        cast_node = ctx.insert_new_node_on_output("Cast", nodes[-1].output[0], name,
+                                                  to=input_dtype)
         ctx.set_dtype(cast_node.output[0], input_dtype)
         ctx.copy_shape(node.output[0], cast_node.output[0])
         nodes.append(cast_node)

@@ -1180,8 +1179,7 @@ def version_1(cls, ctx, node, **kwargs):
         if dtype == onnx_pb.TensorProto.INT64:
             return
         op_name = utils.make_name(node.name)
-        output_cast = ctx.insert_new_node_on_output("Cast", node.output[0], name=op_name)
-        output_cast.set_attr("to", dtype)
+        output_cast = ctx.insert_new_node_on_output("Cast", node.output[0], name=op_name, to=dtype)
         ctx.set_dtype(output_cast.output[0], dtype)
         ctx.copy_shape(node.output[0], output_cast.output[0])

@@ -1555,8 +1553,7 @@ def version_8(cls, ctx, node, **kwargs):

         seq_len_dtype = ctx.get_dtype(node.input[1])
         if seq_len_dtype != onnx_pb.TensorProto.INT64:
-            cast_node = ctx.insert_new_node_on_input(node, "Cast", node.input[1])
-            cast_node.set_attr("to", onnx_pb.TensorProto.INT64)
+            cast_node = ctx.insert_new_node_on_input(node, "Cast", node.input[1], to=onnx_pb.TensorProto.INT64)
             ctx.set_dtype(cast_node.output[0], onnx_pb.TensorProto.INT64)
             ctx.copy_shape(node.input[1], cast_node.output[0])

@@ -1762,8 +1759,8 @@ def version_11(cls, ctx, node, **kwargs):
         # cast to int64 if needed
         if dtypes[1] != onnx_pb.TensorProto.UINT64:
             cast_node = ctx.insert_new_node_on_output("Cast", node.output[1],
-                                                      name=utils.make_name(node.name) + "_cast")
-            cast_node.set_attr("to", dtypes[1])
+                                                      name=utils.make_name(node.name) + "_cast",
+                                                      to=dtypes[1])
             ctx.set_dtype(cast_node.output[0], dtypes[1])
             ctx.copy_shape(node.output[1], cast_node.output[0])
         # FIXME: the indices in onnx are not the same as in tensorflow.
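
The Slice handler hunk above (@@ -739,16 +737,17 @@) also moves the attribute handling into the branches: when the producer is already a Cast and this node is its only consumer, that Cast is simply retargeted; otherwise a new Cast is inserted with its dtype set at creation. A condensed sketch of that branch, with the reasoning spelled out in comments (names as in the diff):

    # Reuse the upstream Cast only if this node is its sole consumer; retargeting a
    # shared Cast would change the dtype seen by its other consumers.
    if node.inputs[0].type == "Cast" and len(ctx.find_output_consumers(node.inputs[0].output[0])) == 1:
        cast_node = node.inputs[0]
        cast_node.set_attr("to", onnx_pb.TensorProto.FLOAT)
    else:
        # No reusable Cast: insert one on the input, with "to" supplied up front.
        cast_node = ctx.insert_new_node_on_input(node, "Cast", node.input[0],
                                                 to=onnx_pb.TensorProto.FLOAT)
        nodes.insert(0, cast_node)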

tf2onnx/tfonnx.py

Lines changed: 4 additions & 4 deletions

@@ -160,8 +160,8 @@ def rewrite_incomplete_type_support(g, ops, impacted_ops):
                 input_node.set_attr("to", onnx_pb.TensorProto.FLOAT)
                 g.set_dtype(input_name, onnx_pb.TensorProto.FLOAT)
             else:
-                cast_node = g.insert_new_node_on_input(op, "Cast", input_name)
-                cast_node.set_attr("to", onnx_pb.TensorProto.FLOAT)
+                cast_node = g.insert_new_node_on_input(op, "Cast", input_name,
+                                                       to=onnx_pb.TensorProto.FLOAT)
                 g.set_dtype(cast_node.output[0], onnx_pb.TensorProto.FLOAT)
                 g.copy_shape(input_name, cast_node.output[0])
                 cast_inserted.append(cast_node)

@@ -171,8 +171,8 @@ def rewrite_incomplete_type_support(g, ops, impacted_ops):
             name = utils.make_name(op.name)
             logger.debug("insert cast back for node %s on output %s [dtype=%s]", op.name, output_name,
                          output_dtype)
-            output_cast = g.insert_new_node_on_output("Cast", output_name, name=name)
-            output_cast.set_attr("to", output_dtype)
+            output_cast = g.insert_new_node_on_output("Cast", output_name, name=name,
+                                                      to=output_dtype)
             g.set_dtype(output_cast.output[0], output_dtype)
             g.copy_shape(output_name, output_cast.output[0])
             cast_inserted.append(output_cast)
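
For readers outside tf2onnx, the same one-step-versus-two-step distinction can be seen with plain onnx.helper; this standalone snippet is only an illustration of the idea, not code from this repository:

    from onnx import TensorProto, helper

    # one step: the "to" attribute is supplied as part of node creation
    cast_a = helper.make_node("Cast", inputs=["x"], outputs=["x_cast"], to=TensorProto.FLOAT)

    # two steps: create the node first, then append the attribute afterwards
    cast_b = helper.make_node("Cast", inputs=["x"], outputs=["x_cast2"])
    cast_b.attribute.extend([helper.make_attribute("to", TensorProto.FLOAT)])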
