@@ -27,25 +27,38 @@ class SqueezeUnsqueezeInputs(ExportPass):
2727 exir_ops .edge .aten .gelu .default ,
2828 }
2929
30+ def should_squeeze (self , op , shape : List [int ]) -> bool : # pyre-ignore
31+ if len (shape ) == 3 :
32+ return shape [1 ] == 1 and shape [0 ] > 1
33+ if len (shape ) == 4 :
34+ # No need to squeeze if all dims are 1 except the width dim
35+ if all (dim == 1 for dim in shape [:- 1 ]):
36+ return False
37+ # Otherwise, check for squeezable dim
38+ return 1 in shape [:- 1 ]
39+
40+ # Prefer not to introduce additional orchestration ops by default
41+ return False
42+
3043 def call_operator (
3144 self ,
3245 op , # pyre-ignore
3346 args : Tuple [Argument , ...],
3447 kwargs : Dict [str , Argument ],
3548 meta : NodeMetadata ,
3649 ) -> ProxyValue :
37- def _squeezable (shape : List [int ]) -> bool :
38- return len (shape ) > 2 and 1 in shape
39-
4050 if op not in self ._squeezable_ops :
4151 return super ().call_operator (op , args , kwargs , meta )
42-
4352 # pyre-ignore[16]: `None` has no attribute `node`
4453 input_shape = args [0 ].node .meta ["val" ].shape
4554 output_shape = meta ["val" ].shape
46- if not _squeezable (input_shape ):
55+
56+ if not self .should_squeeze (op , input_shape ):
4757 return super ().call_operator (op , args , kwargs , meta )
4858
59+ def _squeezable (shape : List [int ]) -> bool :
60+ return len (shape ) > 2 and 1 in shape
61+
4962 # squeeze input tensor
5063 squeeze_shape = list (input_shape )
5164 while _squeezable (squeeze_shape ):
0 commit comments