@@ -55,7 +55,7 @@ def nodes(self):
     def pre_optimize_action(self):
         # make Reshape into a const, which then can be fused into Conv's weight for mobilenet_v1_75_192
-        self._output_names = [name.split(":")[0] for name in self._g.outputs]
+        self._output_names = [self._g.get_node_by_output(out).name for out in self._g.outputs]
         ops = self.nodes
         constable_reshape_ops = [n for n in ops
                                  if (n.type == "Reshape"
@@ -179,6 +179,8 @@ def _optimize_at_current_graph_level(self, graph):
     def _initialize_handlers(self):
         self._handler_map = {
             "Add": self._add_handler,
+            "ArgMax": self._arg_min_max_handler,
+            "ArgMin": self._arg_min_max_handler,
             "Cast": self._simple_through_handler,
             "Clip": self._simple_through_handler,
             "Concat": self._concat_handler,
@@ -192,8 +194,14 @@ def _initialize_handlers(self):
             "Mul": self._mul_handler,
             "Pad": self._pad_handler,
             "Reciprocal": self._simple_through_handler,
-            "ReduceMean": self._reducemean_handler,
+            "ReduceLogSum": self._reduce_handler,
+            "ReduceLogSumExp": self._reduce_handler,
+            "ReduceMax": self._reduce_handler,
+            "ReduceMean": self._reduce_handler,
+            "ReduceMin": self._reduce_handler,
+            "ReduceProd": self._reduce_handler,
             "ReduceSum": self._reducesum_handler,
+            "ReduceSumSquare": self._reduce_handler,
             "Relu": self._simple_through_handler,
             "Shape": self._shape_handler,
             "Sigmoid": self._simple_through_handler,
@@ -258,7 +266,7 @@ def _get_input_index_for_trans(self, node, trans):
         return input_index

     # the assumption is: both node and trans have only 1 output
-    def _switch_transpose_and_node(self, node, trans):
+    def _switch_transpose_and_node(self, node, trans, update_shape=True):
         if not self._nodes_has_single_consumer_node([trans]):
             return False

@@ -271,7 +279,7 @@ def _switch_transpose_and_node(self, node, trans):
         # need to transpose node shape in backward direction as well after switch
         # otherwise, reshape added in post_optimize_action may not work correctly
         shape = self._g.get_shape(node.output[0])
-        if shape:
+        if update_shape and shape:
             # only nhwc transpose can reach here
             new_shape = [shape[i] for i in NHWC_TO_NCHW]
             self._g.set_shape(node.output[0], new_shape)
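
Annotation (not part of the diff): the new update_shape flag exists because this shape fix-up assumes the node's output is still rank 4, which no longer holds once a reduce with keepdims=0 sits on the other side of the swap. A minimal standalone sketch, assuming tf2onnx's usual NHWC_TO_NCHW value of [0, 3, 1, 2]; permute_shape and the sample shape are illustrative names, not code from the PR:

NHWC_TO_NCHW = [0, 3, 1, 2]

def permute_shape(shape, perm=NHWC_TO_NCHW):
    # Reorder a rank-4 NHWC shape into NCHW order.
    return [shape[i] for i in perm]

print(permute_shape([1, 192, 192, 24]))  # [1, 24, 192, 192]

# A Reduce* node with keepdims=0 drops axes, so its output is no longer
# rank 4 and this fixed permutation would be wrong; such callers pass
# update_shape=False and recompute the output shape themselves.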
@@ -700,31 +708,49 @@ def _pad_handler(self, trans, node):
         self._g.replace_input(node, node.input[1], new_pads.output[0], 1)
         return self._switch_transpose_and_node(node, trans)

-    def _reducemean_handler(self, trans, node):
-        axes = node.get_attr("axes").ints
-        keepdims = node.get_attr("keepdims")
+    def _arg_min_max_handler(self, trans, node):
+        axis = node.get_attr_value("axis", 0)
+        node.set_attr("axes", [axis])
+        result = self._reduce_handler(trans, node)
+        new_axis = node.get_attr_value("axes")[0]
+        node.set_attr("axis", new_axis)
+        del node.attr["axes"]
+        return result
+
+    def _reduce_handler(self, trans, node):
+        keepdims = node.get_attr_value("keepdims", 1)
         trans_rank = get_transpose_rank(trans)
-        # make sure keepdims is 1, then we can do the swap, otherwise, please don't, because
-        # once keepdims is not set, original dims are lost, so transpose back won't work well.
-        # by default, if keepdims is not specified, it is 1
-        if axes == list(range(1, trans_rank - 1)) and ((keepdims and keepdims.i == 1) or (not keepdims)):
-            node.set_attr("axes", list(range(2, trans_rank)))
-            return self._switch_transpose_and_node(node, trans)
-        return False
+        axes = node.get_attr_value("axes", list(range(trans_rank)))
+        perm = trans.get_attr("perm").ints
+        axes = [a + trans_rank if a < 0 else a for a in axes]
+        new_axes = [perm[a] for a in axes]
+        update_shape = keepdims == 1
+        shape = self._g.get_shape(node.output[0])
+        if not self._switch_transpose_and_node(node, trans, update_shape):
+            return False
+        node.set_attr("axes", new_axes)
+        if keepdims == 0:
+            remaining_axes = []
+            j = 0
+            for i in range(trans_rank):
+                if i in new_axes:
+                    remaining_axes.append(None)
+                else:
+                    remaining_axes.append(j)
+                    j += 1
+            new_perm = [remaining_axes[p] for p in perm if remaining_axes[p] is not None]
+            if shape:
+                new_shape = [shape[new_perm.index(i)] for i in range(len(new_perm))]
+                self._g.set_shape(node.output[0], new_shape)
+            trans.set_attr("perm", new_perm)
+        return True

     def _reducesum_handler(self, trans, node):
         keepdims = node.get_attr("keepdims")
-        # make sure keepdims is 1, then we can do the swap, otherwise, please don't, because
-        # once keepdims is not set, original dims are lost, so transpose back won't work well.
-        # by default, if keepdims is not specified, it is 1
+        if self._g.opset <= 12:
+            return self._reduce_handler(trans, node)
         if keepdims and keepdims.i == 0:
             return False
-        if self._g.opset <= 12:
-            axes = node.get_attr("axes").ints
-            perm = trans.get_attr('perm').ints
-            new_axes = [perm[axis] for axis in axes]
-            node.set_attr("axes", new_axes)
-            return self._switch_transpose_and_node(node, trans)
         if node.inputs[1].is_const():
             axes = node.inputs[1].get_tensor_value()
             perm = trans.get_attr('perm').ints
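
Annotation (not part of the diff): the core of the new _reduce_handler is remapping the reduce axes through the transpose's perm and, when keepdims=0, shrinking that perm to the surviving axes. A self-contained sketch of this bookkeeping under those assumptions; the helper names and example values below are illustrative, not code from the PR:

def remap_reduce_axes(perm, axes, rank):
    # After the swap the reduce sees pre-transpose data, so axis a in the
    # original (post-transpose) layout becomes perm[a].
    axes = [a + rank if a < 0 else a for a in axes]  # normalize negative axes
    return [perm[a] for a in axes]

def rebuild_perm(perm, new_axes):
    # With keepdims=0 the reduced axes disappear, so the trailing Transpose
    # needs a smaller perm built from the surviving axes only.
    remaining = []
    j = 0
    for i in range(len(perm)):
        if i in new_axes:
            remaining.append(None)   # axis removed by the reduce
        else:
            remaining.append(j)      # new index after removal
            j += 1
    return [remaining[p] for p in perm if remaining[p] is not None]

# Example: an NCHW->NHWC transpose (perm=[0, 2, 3, 1]) followed by a reduce
# over the spatial axes [1, 2] of its NHWC output.
perm = [0, 2, 3, 1]
new_axes = remap_reduce_axes(perm, [1, 2], rank=4)
print(new_axes)                      # [2, 3]: reduce H and W in NCHW space
print(rebuild_perm(perm, new_axes))  # [0, 1]: with keepdims=0 only N, C remain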