@@ -36,7 +36,7 @@ def _optimize(self, graph):
36
36
def _optimize_at_current_graph_level (self , g ):
37
37
for optype , handler in _func_map .items ():
38
38
# candidate nodes for removal/optimization
39
- nodes = [n for n in g .get_nodes () if n .type == optype ]
39
+ nodes = [n for n in g .get_nodes () if n .type in optype ]
40
40
41
41
# topological sort of candidates
42
42
# simplifying assumption for back-to-back-optimizer is
@@ -51,8 +51,7 @@ def _optimize_at_current_graph_level(self, g):
51
51
# q = starting nodes with no dependencies
52
52
q = list (set (consumer_node_ids .keys ()) - has_dependencies )
53
53
while q :
54
- nodeid = q [0 ]
55
- q .remove (nodeid )
54
+ nodeid = q .pop (0 )
56
55
node = g .get_node_by_output (nodeid , False )
57
56
consumer_nodes = consumer_node_ids [nodeid ]
58
57
@@ -72,6 +71,7 @@ def _optimize_at_current_graph_level(self, g):
72
71
@staticmethod
73
72
@_register_func ("Cast" )
74
73
def _optimize_cast (g , node , consumer_nodes ):
74
+ """remove long chains of cast ops"""
75
75
q2 = []
76
76
type1 = node .get_attr ('to' ).i
77
77
type1_name = ONNX_DTYPE_NAMES [type1 ] if type1 in ONNX_DTYPE_NAMES else ''
@@ -124,6 +124,7 @@ def _optimize_cast(g, node, consumer_nodes):
124
124
@staticmethod
125
125
@_register_func ("Transpose" )
126
126
def _optimize_transpose (g , node , consumer_nodes ):
127
+ """remove long chains of transpose ops"""
127
128
t1 = list (node .get_attr ('perm' ).ints )
128
129
q2 = []
129
130
for node2 in consumer_nodes :
@@ -146,3 +147,33 @@ def _optimize_transpose(g, node, consumer_nodes):
146
147
q2 .append (node2 .output [0 ])
147
148
g .remove_node (node .name )
148
149
return q2
150
+
151
+ @staticmethod
152
+ @_register_func (('Squeeze' , 'Unsqueeze' ))
153
+ def _optimize_squeeze_unsqueeze (g , node , consumer_nodes ):
154
+ """remove pairs of squeeze-unsqueeze nodes"""
155
+
156
+ if node .type != 'Squeeze' or len (consumer_nodes ) != 1 :
157
+ # no need to return any value, since not removing long chain of nodes
158
+ return []
159
+
160
+ node2 = consumer_nodes [0 ]
161
+ if node2 .type != 'Unsqueeze' :
162
+ return []
163
+
164
+ axis1 = node .get_attr ('axes' ).ints
165
+ axis2 = node2 .get_attr ('axes' ).ints
166
+
167
+ # if squeeze followed by unsqueeze is on diff axes, skip
168
+ if axis1 != axis2 :
169
+ return []
170
+
171
+ # if unsqueeze output is graph output, skip
172
+ if set (node2 .output ) & set (g .outputs ):
173
+ return []
174
+
175
+ node2_consumers = g .find_output_consumers (node2 .output [0 ])
176
+ g .replace_all_inputs (node2_consumers , node2 .output [0 ], node .input [0 ])
177
+ g .remove_node (node .name )
178
+ g .remove_node (node2 .name )
179
+ return []
0 commit comments