Commit 7ff87d3

Update base for Update on "Remove ExecuTorch copy of Vectorized"
All uses are outside ExecuTorch core, so we can just use ATen Vectorized.

Differential Revision: [D66396016](https://our.internmc.facebook.com/intern/diff/D66396016/)

[ghstack-poisoned]

2 parents: 82184bd + ddec0c7


60 files changed: +1798 −479 lines

backends/arm/TARGETS

Lines changed: 11 additions & 0 deletions
@@ -110,3 +110,14 @@ python_library(
         "//executorch/backends/arm/operators:node_visitor",
     ],
 )
+
+python_library(
+    name = "arm_model_evaluator",
+    src = [
+        "util/arm_model_evaluator.py",
+    ],
+    typing = True,
+    deps = [
+        "//caffe2:torch",
+    ]
+)

backends/arm/_passes/arm_pass_manager.py

Lines changed: 3 additions & 3 deletions
@@ -29,8 +29,8 @@
     DecomposeSoftmaxesPass,
 )
 from executorch.backends.arm._passes.decompose_var_pass import DecomposeVarPass
-from executorch.backends.arm._passes.insert_squeeze_after_sum_pass import (
-    InsertSqueezeAfterSumPass,
+from executorch.backends.arm._passes.keep_dims_false_to_squeeze_pass import (
+    KeepDimsFalseToSqueezePass,
 )
 from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass
 from executorch.backends.arm._passes.meandim_to_averagepool_pass import (
@@ -71,7 +71,7 @@ def transform_to_backend_pipeline(
         self.add_pass(DecomposeMeanDimPass())
         self.add_pass(MatchArgRanksPass(exported_program))
         self.add_pass(DecomposeDivPass())
-        self.add_pass(InsertSqueezeAfterSumPass())
+        self.add_pass(KeepDimsFalseToSqueezePass())
         self.add_pass(ConvertSplitToSlicePass())
         self.add_pass(Conv1dUnsqueezePass(exported_program))
         self.add_pass(DecomposeSoftmaxesPass())

backends/arm/_passes/arm_pass_utils.py

Lines changed: 58 additions & 0 deletions
@@ -7,6 +7,7 @@
 
 # pyre-unsafe
 
+from inspect import isclass
 from typing import Optional
 
 import torch
@@ -133,3 +134,60 @@ def get_first_fake_tensor(node: torch.fx.Node) -> FakeTensor:
         fake_tensor, FakeTensor
     ), f'Found {fake_tensor} in meta["val"] of {node}, expected to find FakeTensor.'
     return fake_tensor
+
+
+def get_node_arg(args: list | dict, key: int | str | type, default_value=None):
+    """
+    Helper function for getting a value from node.args/kwargs; three cases:
+    1. By position in node.args - Returns the arg at the given position, or default_value if the index is one out of bounds.
+    2. By key in node.kwargs - Returns the kwarg with the given key, or default_value if it does not exist.
+    3. By type in node.args - Returns the first arg of the given type. Useful for cases where arg positions may differ but types are unique.
+    """
+    if isinstance(key, int):
+        if 0 <= key < len(args):
+            return args[key]
+        elif key == len(args):
+            if default_value is not None:
+                return default_value
+            else:
+                raise RuntimeError(f"No default value given for index {key}")
+        else:
+            raise RuntimeError(
+                f"Out of bounds index {key} for getting value in args (of size {len(args)})"
+            )
+    elif isinstance(key, str):
+        return args.get(key, default_value)
+    elif isclass(key):
+        for arg in args:
+            if isinstance(arg, key):
+                return arg
+        if default_value is not None:
+            return default_value
+        else:
+            raise RuntimeError(f"No arg of type {key}")
+    else:
+        raise RuntimeError("Invalid type")
+
+
+def set_node_arg(node: torch.fx.Node, i: int | str, value):
+    """
+    Helper function for setting a value in node.args/kwargs. If the index is one larger than the list size, the value is appended to the list instead.
+    """
+    if isinstance(i, int):
+        if 0 <= i < len(node.args):
+            args = list(node.args)
+            args[i] = value
+            node.args = tuple(args)
+            return
+        elif i == len(node.args):
+            node.args = node.args + (value,)
+        else:
+            raise RuntimeError(
+                f"Out of bounds index {i} for setting value in {node} args (of size {len(node.args)})"
+            )
+    elif isinstance(i, str):
+        kwargs = dict(node.kwargs)
+        kwargs[i] = value
+        node.kwargs = kwargs
+    else:
+        raise RuntimeError("Invalid type")

backends/arm/_passes/decompose_meandim_pass.py

Lines changed: 7 additions & 6 deletions
@@ -7,6 +7,7 @@
 # pyre-unsafe
 
 import torch
+from executorch.backends.arm._passes.arm_pass_utils import get_node_arg
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass
 
@@ -42,16 +43,16 @@ def call_operator(self, op, args, kwargs, meta):
         if op not in (exir_ops.edge.aten.mean.dim, torch.ops.aten.mean.dim):
             return super().call_operator(op, args, kwargs, meta)
 
-        x = args[0]
-        dim = args[1]
-        keepdim = args[2] if len(args) > 2 else False
-        if not keepdim:
-            return super().call_operator(op, args, kwargs, meta)
-        # if keepdim == True and dim == [-1, -2], mean.dim can be
+        x = get_node_arg(args, 0)
+        dim = get_node_arg(args, 1)
+        keepdim = get_node_arg(args, 2, False)
+
+        # if dim == [-1, -2], mean.dim can be
         # decomposed to avg_pool2d. This is handled by ConvertMeanDimToAveragePool.
         if dim == [-1, -2]:
             # Simply return the mean.dim operator for future decomposition.
             return super().call_operator(op, args, kwargs, meta)
+
         shape = meta["val"].size()
         dtype = meta["val"].dtype
         input_shape = x.data.size()
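
The rest of the decomposition lies outside this hunk; as a rough eager-mode sketch (an assumption, not the pass's exact output), a mean.dim over dims other than [-1, -2] reduces to a sum over those dims scaled by one over the number of reduced elements:

import torch

x = torch.randn(2, 3, 8, 8)
dim = [1]  # any dims other than [-1, -2], which are routed to avg_pool2d instead

numel = 1
for d in dim:
    numel *= x.shape[d]
# Sum over the reduced dims, scaled by the number of reduced elements.
decomposed = x.sum(dim=dim, keepdim=True) / numel

assert torch.allclose(decomposed, x.mean(dim=dim, keepdim=True))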

backends/arm/_passes/decompose_var_pass.py

Lines changed: 16 additions & 11 deletions
@@ -8,6 +8,7 @@
 
 
 import torch
+from executorch.backends.arm._passes.arm_pass_utils import get_node_arg
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass
 
@@ -53,26 +54,30 @@ def call_operator(self, op, args, kwargs, meta):
             torch.ops.aten.var.dim,
         ):
             return super().call_operator(op, args, kwargs, meta)
-        shape = meta["val"].size()
+
+        x = args[0]
+        input_shape = x.data.size()
+        shape = list(meta["val"].size())
+        if shape == []:
+            shape = [1 for _ in input_shape]
+
         dtype = meta["val"].dtype
-        dim = args[1] if len(args) > 1 else list(range(len(shape)))
+        # Get dim from args based on argument type
+        dim = get_node_arg(args, key=list, default_value=list(range(len(shape))))
+
         if op == torch.ops.aten.var.dim:
-            correction = args[-2]
-            keepdim = args[-1]
+            keepdim = get_node_arg(args, bool, False)
+            correction = get_node_arg(args, int, 1)
         else:
-            correction = kwargs["correction"]
-            keepdim = kwargs.get("keepdim", False)
-        if not keepdim:
-            return super().call_operator(op, args, kwargs, meta)
+            correction = get_node_arg(kwargs, "correction", 1)
+            keepdim = get_node_arg(kwargs, "keepdim", False)
 
-        x = args[0]
-        input_shape = x.data.size()
         N = 1
         for d in dim:
             N *= input_shape[d]
 
         mean_op, diff_op, mul_op, sum_op, full_op = get_var_decomposition(op)
-        mean = super().call_operator(mean_op, (x, dim, keepdim), {}, meta)
+        mean = super().call_operator(mean_op, (x, dim, True), {}, meta)
         diff = super().call_operator(diff_op, (x, mean), {}, meta)
         squared_diff = super().call_operator(mul_op, (diff, diff), {}, meta)
         sum = super().call_operator(sum_op, (squared_diff, dim, keepdim), {}, meta)
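
An eager-mode sketch of the decomposition sequence above (mean, diff, square, sum, then scaling by 1/(N - correction)): the mean is now always taken with keepdim=True so it broadcasts against x, while keepdim only shapes the final sum. This sketch assumes a PyTorch version where Tensor.var accepts a correction argument:

import torch

x = torch.randn(4, 5)
dim, correction, keepdim = [1], 1, False

N = 1
for d in dim:
    N *= x.shape[d]

mean = x.mean(dim=dim, keepdim=True)   # keepdim=True so the subtraction broadcasts
diff = x - mean
squared_diff = diff * diff
var = squared_diff.sum(dim=dim, keepdim=keepdim) / (N - correction)

assert torch.allclose(var, x.var(dim=dim, correction=correction, keepdim=keepdim))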

backends/arm/_passes/insert_squeeze_after_sum_pass.py renamed to backends/arm/_passes/keep_dims_false_to_squeeze_pass.py

Lines changed: 35 additions & 7 deletions
@@ -10,14 +10,18 @@
 
 import torch
 import torch.fx
-from executorch.backends.arm._passes.arm_pass_utils import create_node
+from executorch.backends.arm._passes.arm_pass_utils import (
+    create_node,
+    get_node_arg,
+    set_node_arg,
+)
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass, PassResult
 
 
-class InsertSqueezeAfterSumPass(ExportPass):
+class KeepDimsFalseToSqueezePass(ExportPass):
     """
-    In Pytorch, the default behaviour of Tensor.sum is to squeeze
+    In Pytorch, the default behaviour of, for example, Tensor.sum is to squeeze
     the dimension that is summed (keep_dim = False).
     However, in TOSA, REDUCE_SUM always preserves the
     rank of the input (keep_dim = True).
@@ -31,28 +35,52 @@ class InsertSqueezeAfterSumPass(ExportPass):
     squeeze(dim = dims)
     """
 
+    # CURRENTLY NOT HANDLED OPS
+    # exir_ops.edge.aten.amax,
+    # exir_ops.edge.aten.amin,
+    # exir_ops.edge.aten.any.dim,
+    # exir_ops.edge.aten.any.dims,
+    # exir_ops.edge.aten.argmax,
+    # exir_ops.edge.aten.argmin,
+    # exir_ops.edge.aten.max.dim,
+    # exir_ops.edge.aten.min.dim,
+    # exir_ops.edge.aten.prod.dim_int,
+
+    # HANDLED OPS
+    # exir_ops.edge.aten.sum.dim_IntList
+    # exir_ops.edge.aten.var.correction (decomposed in decompose_var_pass)
+    # exir_ops.edge.aten.var.dim (decomposed in decompose_var_pass)
+    # exir_ops.edge.aten.mean.dim (decomposed in decompose_meandim_pass)
+
     def call(self, graph_module: torch.fx.GraphModule):
         for node in graph_module.graph.nodes:
+            keep_dim_index = None
+
             if node.op != "call_function":
                 continue
-            if node.target != exir_ops.edge.aten.sum.dim_IntList:
+            if node.target == exir_ops.edge.aten.sum.dim_IntList:
+                keep_dim_index = 2
+            else:
                 continue
+
             sum_node = cast(torch.fx.Node, node)
-            keep_dim = cast(bool, sum_node.args[2] if len(sum_node.args) > 2 else False)
+            keep_dim = get_node_arg(sum_node.args, keep_dim_index, False)
+
             if keep_dim:
                 continue
 
-            dim_list = cast(list[int], sum_node.args[1])
+            dim_list = get_node_arg(sum_node.args, 1, [0])
 
             # Add keep_dim = True arg to sum node.
-            sum_node.args = sum_node.args[0:2] + (True,)
+            set_node_arg(sum_node, 2, True)
 
             with graph_module.graph.inserting_after(sum_node):
                 squeeze_node = create_node(
                     graph_module.graph, exir_ops.edge.aten.squeeze_copy.dims, ()
                 )
                 sum_node.replace_all_uses_with(squeeze_node)
                 squeeze_node.args = (sum_node, dim_list)
+
         graph_module.graph.eliminate_dead_code()
         graph_module.recompile()
         graph_module = super().call(graph_module).graph_module
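
A small eager-mode check (not part of the commit) of the rewrite this pass performs on sum.dim_IntList: a sum with keep_dim=False is equivalent to a rank-preserving sum followed by a squeeze of the reduced dims:

import torch

x = torch.randn(2, 3, 4)
dims = [1]

out_original = x.sum(dim=dims)                                  # keep_dim=False, shape (2, 4)
out_rewritten = x.sum(dim=dims, keepdim=True).squeeze(dims[0])  # rank-preserving sum + squeeze

assert out_original.shape == out_rewritten.shape == (2, 4)
assert torch.allclose(out_original, out_rewritten)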

backends/arm/operator_support/__init__.py

Lines changed: 1 addition & 6 deletions
@@ -5,9 +5,4 @@
 
 # pyre-unsafe
 
-from . import (  # noqa
-    mean_dim_support,
-    right_shift_support,
-    tosa_supported_operators,
-    var_correction_support,
-)
+from . import right_shift_support, to_copy_support, tosa_supported_operators  # noqa

backends/arm/operator_support/mean_dim_support.py

Lines changed: 0 additions & 33 deletions
This file was deleted.
