@@ -247,6 +247,14 @@ def _switch_transpose_and_node(self, node, trans):
247
247
shape = self ._g .get_shape (node .output [0 ])
248
248
if shape :
249
249
# only nhwc transpose can reach here
250
+ # if slicing results in non 4-D shape, we cannot infer the outputshape
251
+ # JRP
252
+ # new_shape = shape
253
+ # if len(shape) == len(NHWC_TO_NCHW):
254
+ # new_shape = [shape[i] for i in NHWC_TO_NCHW]
255
+ # else:
256
+ # new_shape = [-1]*len(shape)
257
+ # self.logger.warning("%s's shape is unknown, which may interfere further optimization", node.output[0])
250
258
new_shape = [shape [i ] for i in NHWC_TO_NCHW ]
251
259
self ._g .set_shape (node .output [0 ], new_shape )
252
260
return True
@@ -518,7 +526,7 @@ def _calculate_new_attr(ori_perm, ori_squeeze_axes):
518
526
squeeze_shape = self ._g .get_shape (node .output [0 ])
519
527
self ._g .set_shape (trans .output [0 ], squeeze_shape )
520
528
input_shape = self ._g .get_shape (node .input [0 ])
521
- if input_shape is not None :
529
+ if input_shape is not None : # JRP and len(input_shape) == 4:
522
530
new_squeeze_output_shape = [input_shape [i ] for i in range (4 ) if i not in new_squeeze_axes ]
523
531
else :
524
532
new_squeeze_output_shape = [- 1 ] * 4
@@ -568,6 +576,9 @@ def _slice_handler(self, trans, node):
568
576
else : # in opset 10, axes is input instead of an attribute.
569
577
if len (node .inputs ) >= 4 and node .inputs [3 ].is_const ():
570
578
axes = node .inputs [3 ].get_tensor_value (as_list = True )
579
+ # JRP
580
+ #shape = self._g.get_shape(node.output[0])
581
+ #if axes == [0, 1, 2, 3] and len(shape) == 4:
571
582
if axes == [0 , 1 , 2 , 3 ]:
572
583
# axes node might be shared
573
584
new_axes = np .array (NCHW_TO_NHWC , dtype = np .int64 )
0 commit comments