1010
1111import torch
1212import torch .fx
13- from executorch .backends .arm ._passes .arm_pass_utils import create_node
13+ from executorch .backends .arm ._passes .arm_pass_utils import (
14+ create_node ,
15+ get_node_arg ,
16+ set_node_arg ,
17+ )
1418from executorch .exir .dialects ._ops import ops as exir_ops
1519from executorch .exir .pass_base import ExportPass , PassResult
1620
1721
18- class InsertSqueezeAfterSumPass (ExportPass ):
22+ class KeepDimsFalseToSqueezePass (ExportPass ):
1923 """
20- In Pytorch, the default behaviour of Tensor.sum is to squeeze
24+ In Pytorch, the default behaviour of for example Tensor.sum is to squeeze
2125 the dimension that is summed (keep_dim = False).
2226 However, in TOSA, REDUCE_SUM always preserves the
2327 rank of the input (keep_dim = True).
@@ -31,28 +35,52 @@ class InsertSqueezeAfterSumPass(ExportPass):
3135 squeeze(dim = dims)
3236 """
3337
38+ # CURRENTLY NOT HANDLED OPS
39+ # exir_ops.edge.aten.amax,
40+ # exir_ops.edge.aten.amin,
41+ # exir_ops.edge.aten.any.dim,
42+ # exir_ops.edge.aten.any.dims,
43+ # exir_ops.edge.aten.argmax,
44+ # exir_ops.edge.aten.argmin,
45+ # exir_ops.edge.aten.max.dim,
46+ # exir_ops.edge.aten.min.dim,
47+ # exir_ops.edge.aten.prod.dim_int,
48+
49+ # HANDLED OPS
50+ # exir_ops.edge.aten.sum.dim_IntList
51+ # exir_ops.edge.aten.var.correction (decomposed in decompose_var_pass)
52+ # exir_ops.edge.aten.var.dim (decomposed in decompose_var_pass)
53+ # exir_ops.edge.aten.mean.dim (decomposed in decompose_meandim_pass)
54+
3455 def call (self , graph_module : torch .fx .GraphModule ):
3556 for node in graph_module .graph .nodes :
57+ keep_dim_index = None
58+
3659 if node .op != "call_function" :
3760 continue
38- if node .target != exir_ops .edge .aten .sum .dim_IntList :
61+ if node .target == exir_ops .edge .aten .sum .dim_IntList :
62+ keep_dim_index = 2
63+ else :
3964 continue
65+
4066 sum_node = cast (torch .fx .Node , node )
41- keep_dim = cast (bool , sum_node .args [2 ] if len (sum_node .args ) > 2 else False )
67+ keep_dim = get_node_arg (sum_node .args , keep_dim_index , False )
68+
4269 if keep_dim :
4370 continue
4471
45- dim_list = cast ( list [ int ], sum_node .args [ 1 ])
72+ dim_list = get_node_arg ( sum_node .args , 1 , [ 0 ])
4673
4774 # Add keep_dim = True arg to sum node.
48- sum_node . args = sum_node . args [ 0 : 2 ] + ( True , )
75+ set_node_arg ( sum_node , 2 , True )
4976
5077 with graph_module .graph .inserting_after (sum_node ):
5178 squeeze_node = create_node (
5279 graph_module .graph , exir_ops .edge .aten .squeeze_copy .dims , ()
5380 )
5481 sum_node .replace_all_uses_with (squeeze_node )
5582 squeeze_node .args = (sum_node , dim_list )
83+
5684 graph_module .graph .eliminate_dead_code ()
5785 graph_module .recompile ()
5886 graph_module = super ().call (graph_module ).graph_module
0 commit comments