@@ -500,23 +500,27 @@ def _calculate_new_attr(ori_perm, ori_squeeze_axes):
500
500
return False
501
501
502
502
if node .get_attr ("axes" ):
503
- squeeze_axes = sorted (list (node .get_attr ("axes" ).ints ))
504
- trans_perm = list (trans .get_attr ("perm" ).ints )
505
- squeeze_shape = self ._g .get_shape (node .output [0 ])
506
503
# switch tran and squeeze
507
504
# 1 switch
508
505
ops = self ._g .get_nodes ()
509
506
self ._g .replace_all_inputs (ops , node .output [0 ], trans .output [0 ])
510
507
node .input [0 ] = trans .input [0 ]
511
508
trans .input [0 ] = node .output [0 ]
512
509
# 2 correct attr of nodes
510
+ squeeze_axes = sorted (list (node .get_attr ("axes" ).ints ))
511
+ trans_perm = list (trans .get_attr ("perm" ).ints )
513
512
new_perm , new_squeeze_axes = _calculate_new_attr (ori_perm = trans_perm , ori_squeeze_axes = squeeze_axes )
514
513
trans .set_attr ("perm" , new_perm )
515
514
node .set_attr ("axes" , new_squeeze_axes )
516
515
# 3 set shape
516
+ squeeze_shape = self ._g .get_shape (node .output [0 ])
517
517
self ._g .set_shape (trans .output [0 ], squeeze_shape )
518
518
input_shape = self ._g .get_shape (node .input [0 ])
519
- new_squeeze_output_shape = [input_shape [i ] for i in range (4 ) if i not in new_squeeze_axes ]
519
+ if input_shape is not None :
520
+ new_squeeze_output_shape = [input_shape [i ] for i in range (4 ) if i not in new_squeeze_axes ]
521
+ else :
522
+ new_squeeze_output_shape = [- 1 ]* 4
523
+ self .logger .warning ("%s's shape is unknown, which may interfere further optimization" , node .input [0 ])
520
524
self ._g .set_shape (node .output [0 ], new_squeeze_output_shape )
521
525
return True
522
526
return False
0 commit comments