Commit 4a9d8d1

Merge pull request #744 from jignparm/jignparm/bert_model_cleanup
Remove unnecessary ops when converting BERT model
2 parents 55d8ea8 + 9134eb7

2 files changed, +52 -20 lines changed
tf2onnx/onnx_opset/tensor.py

Lines changed: 18 additions & 17 deletions
@@ -722,23 +722,24 @@ def version_1(cls, ctx, node, **kwargs):
 
         # onnx slice as of opset 7 does only take float tensors ... cast if needed
         input_dtype = ctx.get_dtype(node.input[0])
-        if input_dtype != onnx_pb.TensorProto.FLOAT:
-            if node.inputs[0].type == "Cast" and len(ctx.find_output_consumers(node.inputs[0].output[0])) == 1:
-                # override the previous cast
-                cast_node = node.inputs[0]
-            else:
-                cast_node = ctx.insert_new_node_on_input(node, "Cast", node.input[0])
-                nodes.insert(0, cast_node)
-            cast_node.set_attr("to", onnx_pb.TensorProto.FLOAT)
-            ctx.set_dtype(cast_node.output[0], onnx_pb.TensorProto.FLOAT)
-            ctx.copy_shape(node.input[0], cast_node.output[0])
-            # undo the cast afer slice
-            name = utils.make_name(node.name)
-            cast_node = ctx.insert_new_node_on_output("Cast", nodes[-1].output[0], name)
-            cast_node.set_attr("to", input_dtype)
-            ctx.set_dtype(cast_node.output[0], input_dtype)
-            ctx.copy_shape(node.output[0], cast_node.output[0])
-            nodes.append(cast_node)
+        if ctx.opset < 9:
+            if input_dtype != onnx_pb.TensorProto.FLOAT:
+                if node.inputs[0].type == "Cast" and len(ctx.find_output_consumers(node.inputs[0].output[0])) == 1:
+                    # override the previous cast
+                    cast_node = node.inputs[0]
+                else:
+                    cast_node = ctx.insert_new_node_on_input(node, "Cast", node.input[0])
+                    nodes.insert(0, cast_node)
+                cast_node.set_attr("to", onnx_pb.TensorProto.FLOAT)
+                ctx.set_dtype(cast_node.output[0], onnx_pb.TensorProto.FLOAT)
+                ctx.copy_shape(node.input[0], cast_node.output[0])
+                # undo the cast afer slice
+                name = utils.make_name(node.name)
+                cast_node = ctx.insert_new_node_on_output("Cast", nodes[-1].output[0], name)
+                cast_node.set_attr("to", input_dtype)
+                ctx.set_dtype(cast_node.output[0], input_dtype)
+                ctx.copy_shape(node.output[0], cast_node.output[0])
+                nodes.append(cast_node)
 
     @classmethod
     def version_10(cls, ctx, node, **kwargs):
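
The added "if ctx.opset < 9" guard confines the float-Cast workaround to older opsets, so at opset 9 and above a Slice over a non-float tensor is no longer wrapped in a Cast / Cast-back pair. Below is a minimal sketch of the two node patterns, built with plain onnx.helper rather than tf2onnx internals (assumes the onnx package is installed; the tensor names "x"/"y" and the int32 dtype are illustrative):

    # Sketch only -- not tf2onnx code. Node pattern with and without the Cast
    # sandwich around an attribute-style Slice of a hypothetical int32 tensor "x".
    from onnx import helper, TensorProto

    # opset < 9: cast to float, slice, cast back to the original dtype
    pre_opset9_nodes = [
        helper.make_node("Cast", ["x"], ["x_float"], to=TensorProto.FLOAT),
        helper.make_node("Slice", ["x_float"], ["y_float"], axes=[0], starts=[0], ends=[1]),
        helper.make_node("Cast", ["y_float"], ["y"], to=TensorProto.INT32),
    ]

    # opset 9 after this change: the Slice consumes the int32 tensor directly
    opset9_nodes = [
        helper.make_node("Slice", ["x"], ["y"], axes=[0], starts=[0], ends=[1]),
    ]

Dropping the Cast pair around every non-float Slice is presumably among the "unnecessary ops" the commit title refers to for converted BERT graphs.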

tf2onnx/optimizer/back_to_back_optimizer.py

Lines changed: 34 additions & 3 deletions
@@ -36,7 +36,7 @@ def _optimize(self, graph):
     def _optimize_at_current_graph_level(self, g):
         for optype, handler in _func_map.items():
             # candidate nodes for removal/optimization
-            nodes = [n for n in g.get_nodes() if n.type == optype]
+            nodes = [n for n in g.get_nodes() if n.type in optype]
 
             # topological sort of candidates
             # simplifying assumption for back-to-back-optimizer is
@@ -51,8 +51,7 @@ def _optimize_at_current_graph_level(self, g):
         # q = starting nodes with no dependencies
         q = list(set(consumer_node_ids.keys()) - has_dependencies)
         while q:
-            nodeid = q[0]
-            q.remove(nodeid)
+            nodeid = q.pop(0)
             node = g.get_node_by_output(nodeid, False)
             consumer_nodes = consumer_node_ids[nodeid]
 
@@ -72,6 +71,7 @@ def _optimize_at_current_graph_level(self, g):
     @staticmethod
     @_register_func("Cast")
     def _optimize_cast(g, node, consumer_nodes):
+        """remove long chains of cast ops"""
         q2 = []
         type1 = node.get_attr('to').i
         type1_name = ONNX_DTYPE_NAMES[type1] if type1 in ONNX_DTYPE_NAMES else ''
@@ -124,6 +124,7 @@ def _optimize_cast(g, node, consumer_nodes):
     @staticmethod
     @_register_func("Transpose")
     def _optimize_transpose(g, node, consumer_nodes):
+        """remove long chains of transpose ops"""
         t1 = list(node.get_attr('perm').ints)
         q2 = []
         for node2 in consumer_nodes:
@@ -146,3 +147,33 @@ def _optimize_transpose(g, node, consumer_nodes):
             q2.append(node2.output[0])
         g.remove_node(node.name)
         return q2
+
+    @staticmethod
+    @_register_func(('Squeeze', 'Unsqueeze'))
+    def _optimize_squeeze_unsqueeze(g, node, consumer_nodes):
+        """remove pairs of squeeze-unsqueeze nodes"""
+
+        if node.type != 'Squeeze' or len(consumer_nodes) != 1:
+            # no need to return any value, since not removing long chain of nodes
+            return []
+
+        node2 = consumer_nodes[0]
+        if node2.type != 'Unsqueeze':
+            return []
+
+        axis1 = node.get_attr('axes').ints
+        axis2 = node2.get_attr('axes').ints
+
+        # if squeeze followed by unsqueeze is on diff axes, skip
+        if axis1 != axis2:
+            return []
+
+        # if unsqueeze output is graph output, skip
+        if set(node2.output) & set(g.outputs):
+            return []
+
+        node2_consumers = g.find_output_consumers(node2.output[0])
+        g.replace_all_inputs(node2_consumers, node2.output[0], node.input[0])
+        g.remove_node(node.name)
+        g.remove_node(node2.name)
+        return []
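
The new _optimize_squeeze_unsqueeze handler drops a Squeeze whose only consumer is an Unsqueeze on the same axes, rewiring the Unsqueeze's consumers straight to the Squeeze's input; such a pair is an identity as long as the squeezed axis has size 1, which Squeeze requires anyway. A small numpy check of that equivalence (a sketch; the shapes are illustrative):

    import numpy as np

    x = np.arange(15, dtype=np.float32).reshape(3, 1, 5)
    squeezed = np.squeeze(x, axis=1)        # Squeeze(axes=[1])   -> shape (3, 5)
    restored = np.expand_dims(squeezed, 1)  # Unsqueeze(axes=[1]) -> shape (3, 1, 5)
    assert restored.shape == x.shape and np.array_equal(restored, x)

Registering the handler under the tuple ('Squeeze', 'Unsqueeze') is what the earlier "n.type in optype" change enables: "in" is a membership test for tuple keys, while for the existing string keys such as "Cast" the substring test still picks up the op of that exact name.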
