@@ -707,6 +707,117 @@ def get_permutation(self, permute_node: torch.fx.Node) -> list[int]:
707707 return cast (list [int ], permute_node .kwargs ["dim" ])
708708
709709
@register_cadence_pass(CadencePassAttribute(opt_level=2))
class RemoveSqueezeUnsqueezeAroundElementwiseOps(ExportPass):
    """
    Looks for subgraphs of the form (in dataflow order):
        squeeze -> [op] -> unsqueeze
    where the unsqueeze re-inserts the exact dim the squeeze removed, and
    removes both view nodes, updating the intermediate ops so they operate on
    the un-squeezed shape directly. Only handles a simple chain of ops as
    intermediate for now.

    The pass works on view ops instead of unsqueeze and squeeze directly, thus it
    should be run after the squeeze/unsqueeze->view lowering.
    """

    # Ops allowed between the squeeze and the unsqueeze. The quantize/
    # dequantize ops are elementwise and need no changes when the extra dim is
    # kept; slice_copy carries an explicit `dim` argument that must be shifted
    # (see `call` below).
    intermediate_ops: set[EdgeOpOverload] = {
        exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
        exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
        exir_ops.edge.cadence.quantize_per_tensor.default,
        exir_ops.edge.cadence.dequantize_per_tensor.default,
        # Ops that require special handling:
        exir_ops.edge.aten.slice_copy.Tensor,
    }

    def find_unsqueeze_dim(self, view_node: Node) -> Optional[int]:
        """
        Return the unsqueeze dim if the given view_copy op unsqueezes the input
        tensor (output shape equals input shape with a single size-1 dim
        inserted), if not return None.
        """
        input_node = cast(Node, get_arg(view_node, 0, "input"))
        input_shape = input_node.meta["val"].shape
        output_shape = view_node.meta["val"].shape
        # An unsqueeze adds exactly one dim.
        if len(output_shape) != len(input_shape) + 1:
            return None
        # Locate the position where a size-1 dim was inserted.
        for dim in range(len(output_shape)):
            if output_shape == input_shape[:dim] + (1,) + input_shape[dim:]:
                return dim
        return None

    def find_ancestor_squeeze(self, node: Node, squeeze_dim: int) -> Optional[Node]:
        """
        Traverse up from the given node until finding a squeeze node with the given
        squeeze_dim. If no such node is found, return None.

        Only walks through ops listed in `intermediate_ops`, and only when the
        chain is linear (every node has a single user), so the removal in
        `call` cannot affect other consumers.
        """
        while True:
            # Only handle simple chains for now
            if len(node.users) != 1:
                return None
            if node.target in self.intermediate_ops:
                # Allowed intermediate op: keep walking up through its input.
                node = cast(Node, get_arg(node, 0, "input"))
            elif node.target == exir_ops.edge.aten.view_copy.default:
                input_node = cast(Node, get_arg(node, 0, "input"))
                input_shape = input_node.meta["val"].shape
                output_shape = node.meta["val"].shape
                # Check if the node is a squeeze op: the input must equal the
                # output with a single size-1 dim at exactly `squeeze_dim`.
                if (
                    len(input_shape) != len(output_shape) + 1
                    or input_shape
                    != output_shape[:squeeze_dim] + (1,) + output_shape[squeeze_dim:]
                ):
                    return None
                return node
            else:
                # Any other op ends the search: the chain is not removable.
                return None

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """Find and remove matching squeeze/unsqueeze view pairs; return the
        (possibly modified) graph module and whether anything changed."""
        changed = False

        # Traverse the graph looking for unsqueeze-like view ops.
        for node in graph_module.graph.find_nodes(
            op="call_function", target=exir_ops.edge.aten.view_copy.default
        ):
            unsqueeze_dim = self.find_unsqueeze_dim(node)
            if unsqueeze_dim is None:
                continue

            input_node = cast(Node, get_arg(node, 0, "input"))
            # The ancestor squeeze must remove the same dim this view inserts.
            squeeze_node = self.find_ancestor_squeeze(input_node, unsqueeze_dim)
            if squeeze_node is None:
                continue

            # Chain is found. Remove view ops and update the intermediate ops traversing
            # the chain.
            # Guaranteed by find_ancestor_squeeze's single-user check.
            assert len(squeeze_node.users) == 1
            node = next(iter(squeeze_node.users))

            # Skip first view_copy: route consumers straight to the squeeze's
            # input, so intermediates now see the un-squeezed tensor.
            squeeze_node.replace_all_uses_with(
                cast(Node, get_arg(squeeze_node, 0, "input"))
            )

            # Go down the chain and update the intermediate ops if needed.
            # NOTE(review): intermediate nodes' meta["val"] shapes are not
            # refreshed here — presumably recomputed by a later pass; verify.
            while node.target != exir_ops.edge.aten.view_copy.default:
                if node.target == exir_ops.edge.aten.slice_copy.Tensor:
                    slice_dim = cast(int, get_arg(node, 1, "dim", default=0))
                    # Normalize a negative dim against the pre-transform
                    # (squeezed) rank stored in meta.
                    if slice_dim < 0:
                        slice_dim += len(node.meta["val"].shape)
                    # Dims at or after the re-inserted size-1 dim shift by one.
                    if slice_dim >= unsqueeze_dim:
                        set_arg(node, 1, "dim", slice_dim + 1)
                # Chain linearity was verified by find_ancestor_squeeze.
                assert len(node.users) == 1
                node = next(iter(node.users))

            # Skip final view_copy: its input already has the output shape.
            node.replace_all_uses_with(cast(Node, get_arg(node, 0, "input")))

            changed = True

        if changed:
            graph_module.graph.eliminate_dead_code()
            graph_module.recompile()

        return PassResult(graph_module, changed)
819+
820+
710821@register_cadence_pass (CadencePassAttribute (opt_level = 1 ))
711822class RemoveBranchedQuantDequant (ExportPass ):
712823 """
0 commit comments