@@ -10,14 +10,18 @@
 
 import torch
 import torch.fx
-from executorch.backends.arm._passes.arm_pass_utils import create_node
+from executorch.backends.arm._passes.arm_pass_utils import (
+    create_node,
+    get_node_arg,
+    set_node_arg,
+)
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass, PassResult
 
 
-class InsertSqueezeAfterSumPass(ExportPass):
+class KeepDimsFalseToSqueezePass(ExportPass):
     """
-    In Pytorch, the default behaviour of Tensor.sum is to squeeze
+    In PyTorch, the default behaviour of, for example, Tensor.sum is to squeeze
     the dimension that is summed (keep_dim = False).
     However, in TOSA, REDUCE_SUM always preserves the
     rank of the input (keep_dim = True).
@@ -31,28 +35,52 @@ class InsertSqueezeAfterSumPass(ExportPass):
         squeeze(dim = dims)
     """
 
+    # CURRENTLY NOT HANDLED OPS
+    # exir_ops.edge.aten.amax,
+    # exir_ops.edge.aten.amin,
+    # exir_ops.edge.aten.any.dim,
+    # exir_ops.edge.aten.any.dims,
+    # exir_ops.edge.aten.argmax,
+    # exir_ops.edge.aten.argmin,
+    # exir_ops.edge.aten.max.dim,
+    # exir_ops.edge.aten.min.dim,
+    # exir_ops.edge.aten.prod.dim_int,
+
+    # HANDLED OPS
+    # exir_ops.edge.aten.sum.dim_IntList
+    # exir_ops.edge.aten.var.correction (decomposed in decompose_var_pass)
+    # exir_ops.edge.aten.var.dim (decomposed in decompose_var_pass)
+    # exir_ops.edge.aten.mean.dim (decomposed in decompose_meandim_pass)
+
     def call(self, graph_module: torch.fx.GraphModule):
         for node in graph_module.graph.nodes:
+            keep_dim_index = None
+
             if node.op != "call_function":
                 continue
-            if node.target != exir_ops.edge.aten.sum.dim_IntList:
+            if node.target == exir_ops.edge.aten.sum.dim_IntList:
+                keep_dim_index = 2
+            else:
                 continue
+
             sum_node = cast(torch.fx.Node, node)
-            keep_dim = cast(bool, sum_node.args[2] if len(sum_node.args) > 2 else False)
+            keep_dim = get_node_arg(sum_node.args, keep_dim_index, False)
+
             if keep_dim:
                 continue
 
-            dim_list = cast(list[int], sum_node.args[1])
+            dim_list = get_node_arg(sum_node.args, 1, [0])
 
             # Add keep_dim = True arg to sum node.
-            sum_node.args = sum_node.args[0:2] + (True,)
+            set_node_arg(sum_node, 2, True)
 
             with graph_module.graph.inserting_after(sum_node):
                 squeeze_node = create_node(
                     graph_module.graph, exir_ops.edge.aten.squeeze_copy.dims, ()
                 )
                 sum_node.replace_all_uses_with(squeeze_node)
                 squeeze_node.args = (sum_node, dim_list)
+
         graph_module.graph.eliminate_dead_code()
         graph_module.recompile()
         graph_module = super().call(graph_module).graph_module
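
For context on the docstring's claim, here is a minimal, self-contained sketch in plain PyTorch (not part of the patch; squeezing several dims at once needs PyTorch >= 2.0) of the equivalence the rewrite relies on:

import torch

x = torch.rand(3, 4, 5)
dims = (1, 2)

# Default PyTorch behaviour: the summed dims are squeezed away (keep_dim = False).
default = torch.sum(x, dims)             # shape: (3,)

# TOSA-style REDUCE_SUM behaviour: the rank is preserved (keep_dim = True).
kept = torch.sum(x, dims, keepdim=True)  # shape: (3, 1, 1)

# An explicit squeeze over the same dims restores the default shape, i.e.
# the sum(keep_dim = True) + squeeze(dim = dims) rewrite from the docstring.
assert torch.equal(default, kept.squeeze(dims))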
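And a rough, runnable torch.fx sketch of the same graph rewrite, using vanilla torch.sum/torch.squeeze instead of the edge-dialect ops, and FX's stock node API in place of the arm_pass_utils helpers (everything below is illustrative, not ExecuTorch API):

import torch
import torch.fx


class SumModule(torch.nn.Module):
    def forward(self, x):
        return torch.sum(x, [1, 2])  # keep_dim defaults to False


gm = torch.fx.symbolic_trace(SumModule())

for node in gm.graph.nodes:
    if node.op != "call_function" or node.target is not torch.sum:
        continue
    dims = node.args[1]
    # Counterpart of set_node_arg(sum_node, 2, True): force keepdim = True.
    node.kwargs = dict(node.kwargs, keepdim=True)
    # As in the pass, create the squeeze with empty args first, so that
    # replace_all_uses_with cannot rewire the squeeze onto itself.
    with gm.graph.inserting_after(node):
        squeeze_node = gm.graph.call_function(torch.squeeze, ())
    node.replace_all_uses_with(squeeze_node)
    squeeze_node.args = (node, tuple(dims))

gm.recompile()

x = torch.rand(2, 3, 4)
assert torch.equal(gm(x), torch.sum(x, [1, 2]))  # same values and shape as before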