@@ -174,6 +174,7 @@ def _initialize_handlers(self):
174
174
"Cast" : self ._simple_through_handler ,
175
175
"Clip" : self ._simple_through_handler ,
176
176
"Concat" : self ._concat_handler ,
177
+ "Elu" : self ._simple_through_handler ,
177
178
"Identity" : self ._identity_handler ,
178
179
"LeakyRelu" : self ._simple_through_handler ,
179
180
"Max" : self ._maxmin_handler ,
@@ -247,14 +248,6 @@ def _switch_transpose_and_node(self, node, trans):
247
248
shape = self ._g .get_shape (node .output [0 ])
248
249
if shape :
249
250
# only nhwc transpose can reach here
250
- # if slicing results in non 4-D shape, we cannot infer the outputshape
251
- # JRP
252
- # new_shape = shape
253
- # if len(shape) == len(NHWC_TO_NCHW):
254
- # new_shape = [shape[i] for i in NHWC_TO_NCHW]
255
- # else:
256
- # new_shape = [-1]*len(shape)
257
- # self.logger.warning("%s's shape is unknown, which may interfere further optimization", node.output[0])
258
251
new_shape = [shape [i ] for i in NHWC_TO_NCHW ]
259
252
self ._g .set_shape (node .output [0 ], new_shape )
260
253
return True
@@ -526,7 +519,7 @@ def _calculate_new_attr(ori_perm, ori_squeeze_axes):
526
519
squeeze_shape = self ._g .get_shape (node .output [0 ])
527
520
self ._g .set_shape (trans .output [0 ], squeeze_shape )
528
521
input_shape = self ._g .get_shape (node .input [0 ])
529
- if input_shape is not None : # JRP and len(input_shape) == 4:
522
+ if input_shape is not None :
530
523
new_squeeze_output_shape = [input_shape [i ] for i in range (4 ) if i not in new_squeeze_axes ]
531
524
else :
532
525
new_squeeze_output_shape = [- 1 ] * 4
@@ -576,9 +569,6 @@ def _slice_handler(self, trans, node):
576
569
else : # in opset 10, axes is input instead of an attribute.
577
570
if len (node .inputs ) >= 4 and node .inputs [3 ].is_const ():
578
571
axes = node .inputs [3 ].get_tensor_value (as_list = True )
579
- # JRP
580
- #shape = self._g.get_shape(node.output[0])
581
- #if axes == [0, 1, 2, 3] and len(shape) == 4:
582
572
if axes == [0 , 1 , 2 , 3 ]:
583
573
# axes node might be shared
584
574
new_axes = np .array (NCHW_TO_NHWC , dtype = np .int64 )
0 commit comments