Commit dcf549f
Add var decomposition for Arm backend
- Add unittests for var, var.dim and var.correction
- Add decomposition for var.correction. var and var.dim are both converted to var.correction earlier in the lowering.

Signed-off-by: Oscar Andersson <[email protected]>
Change-Id: Iafd90f13e762d9b198b674f6b4d5c7c4927a1bbc
1 parent fd363e0 commit dcf549f
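For reference, the three var overloads mentioned above map onto the following PyTorch calls. This is a small illustrative snippet, not part of the commit; the tensor shape and dims are arbitrary.

import torch

x = torch.randn(2, 3, 4)

# torch.var with no dim (aten.var); converted to var.correction earlier in the lowering
v_all = torch.var(x)
# aten.var.dim; also converted to var.correction earlier in the lowering
v_dim = torch.var(x, dim=1, keepdim=True)
# aten.var.correction; handled directly by DecomposeVarPass
v_corr = torch.var(x, dim=(1, 2), correction=0, keepdim=True)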

6 files changed, +329 -12 lines changed

backends/arm/_passes/arm_pass_manager.py

Lines changed: 4 additions & 0 deletions
@@ -20,6 +20,7 @@
 )
 from executorch.backends.arm._passes.decompose_div_pass import DecomposeDivPass
 from executorch.backends.arm._passes.decompose_meandim_pass import DecomposeMeanDimPass
+from executorch.backends.arm._passes.decompose_var_pass import DecomposeVarPass
 from executorch.backends.arm._passes.insert_squeeze_after_sum_pass import (
     InsertSqueezeAfterSumPass,
 )
@@ -49,6 +50,7 @@ def transform_to_backend_pipeline(
         self.add_pass(SizeAdjustConv2DPass())
         self.add_pass(RemoveClonePass())
         self.add_pass(ConvertExpandCopyToRepeatPass())
+        self.add_pass(DecomposeVarPass())
         self.add_pass(ConvertMeanDimToAveragePool())
         self.add_pass(DecomposeMeanDimPass())
         self.add_pass(DecomposeDivPass())
@@ -63,6 +65,8 @@ def transform_to_backend_pipeline(
         return self._transform(exported_program.graph_module)
 
     def transform_for_annotation_pipeline(self, graph_module: torch.fx.GraphModule):
+        self.add_pass(DecomposeVarPass())
+        self.add_pass(DecomposeMeanDimPass())
         self.add_pass(DecomposeDivPass())
         self.add_pass(ScalarsToAttributePass())
         self.add_pass(DecomposeMeanDimPass())
backends/arm/_passes/decompose_var_pass.py

Lines changed: 83 additions & 0 deletions
@@ -0,0 +1,83 @@
+# Copyright 2024 Arm Limited and/or its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import torch
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass
+
+
+def get_var_decomposition(op) -> tuple:
+    if op == exir_ops.edge.aten.var.correction:
+        return (
+            exir_ops.edge.aten.mean.dim,
+            exir_ops.edge.aten.sub.Tensor,
+            exir_ops.edge.aten.mul.Tensor,
+            exir_ops.edge.aten.sum.dim_IntList,
+            exir_ops.edge.aten.full.default,
+        )
+    if op in (torch.ops.aten.var.correction, torch.ops.aten.var.dim):
+        return (
+            torch.ops.aten.mean.dim,
+            torch.ops.aten.sub.Tensor,
+            torch.ops.aten.mul.Tensor,
+            torch.ops.aten.sum.dim_IntList,
+            torch.ops.aten.full,
+        )
+    raise RuntimeError(f"Can't get var decomposition for op {op}")
+
+
+class DecomposeVarPass(ExportPass):
+    """
+    This pass decomposes var.correction and var.dim into smaller ops (see https://pytorch.org/docs/stable/generated/torch.var.html)
+
+    Example:
+        y = var_correction(x, dim, keepdim, correction)
+    Becomes:
+        mean = mean(x, dim)
+        diff = sub(x, mean)
+        squared_diff = mul(diff, diff)
+        sum = sum(squared_diff, dim)
+        y = div(sum, max(0, N-correction))
+    """
+
+    def call_operator(self, op, args, kwargs, meta):
+        if op not in (
+            exir_ops.edge.aten.var.correction,
+            torch.ops.aten.var.correction,
+            torch.ops.aten.var.dim,
+        ):
+            return super().call_operator(op, args, kwargs, meta)
+        shape = meta["val"].size()
+        dtype = meta["val"].dtype
+        dim = args[1] if len(args) > 1 else list(range(len(shape)))
+        if op == torch.ops.aten.var.dim:
+            correction = args[-2]
+            keepdim = args[-1]
+        else:
+            correction = kwargs["correction"]
+            keepdim = kwargs.get("keepdim", False)
+        if not keepdim:
+            return super().call_operator(op, args, kwargs, meta)
+
+        x = args[0]
+        input_shape = x.data.size()
+        N = 1
+        for d in dim:
+            N *= input_shape[d]
+
+        mean_op, diff_op, mul_op, sum_op, full_op = get_var_decomposition(op)
+        mean = super().call_operator(mean_op, (x, dim, keepdim), {}, meta)
+        diff = super().call_operator(diff_op, (x, mean), {}, meta)
+        squared_diff = super().call_operator(mul_op, (diff, diff), {}, meta)
+        sum = super().call_operator(sum_op, (squared_diff, dim, keepdim), {}, meta)
+        full = super().call_operator(
+            full_op,
+            ([1 for _ in shape], 1 / max(0, N - correction)),
+            {"dtype": dtype},
+            meta,
+        )
+        return super().call_operator(mul_op, (sum, full), {}, meta)
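As a sanity check of the arithmetic this pass emits, the same mean/sub/mul/sum/full chain can be reproduced with plain tensor ops and compared against torch.var. This is a standalone sketch, not part of the patch; decomposed_var is an illustrative name.

import torch

def decomposed_var(x, dim, correction=1, keepdim=True):
    # Mirror the op chain emitted by DecomposeVarPass:
    # mean.dim -> sub -> mul (square) -> sum.dim_IntList -> mul by 1 / (N - correction)
    # The pass only fires for keepdim=True, so the mean keeps the reduced dims
    # and the subtraction below broadcasts correctly.
    mean = x.mean(dim=dim, keepdim=True)
    diff = x - mean
    squared_diff = diff * diff
    summed = squared_diff.sum(dim=dim, keepdim=keepdim)
    N = 1
    for d in dim:
        N *= x.shape[d]
    return summed * (1.0 / max(0, N - correction))

x = torch.randn(2, 3, 4)
expected = torch.var(x, dim=(1, 2), correction=1, keepdim=True)
actual = decomposed_var(x, dim=(1, 2), correction=1, keepdim=True)
assert torch.allclose(expected, actual, atol=1e-6)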

backends/arm/arm_partitioner.py

Lines changed: 4 additions & 0 deletions
@@ -67,6 +67,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
             exir_ops.edge.aten.view_copy.default,
             exir_ops.edge.aten.clone.default,
             exir_ops.edge.aten.mean.dim,
+            exir_ops.edge.aten.var.correction,
             exir_ops.edge.aten.unsqueeze_copy.default,
             exir_ops.edge.aten.squeeze_copy.dims,
             operator.getitem,
@@ -87,6 +88,9 @@ def is_node_supported_custom(self, node: torch.fx.Node) -> bool:
         if node.target == exir_ops.edge.aten.mean.dim:
             keep_dim = node.args[2] if len(node.args) > 2 else False
             return keep_dim
+        if node.target == exir_ops.edge.aten.var.correction:
+            keep_dim = node.kwargs.get("keepdim", False)
+            return keep_dim
         return True
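Both the decomposition pass and the partitioner gate on keepdim: DecomposeVarPass leaves keepdim=False nodes untouched, and is_node_supported_custom keeps them out of the Arm partition. A small plain-PyTorch illustration of the shape difference behind that gate (not part of the patch):

import torch

x = torch.randn(2, 3, 4)

# keepdim=True keeps the reduced axes as size-1 dims, matching the output of
# the decomposed mean/sub/mul/sum chain used by the Arm backend.
print(torch.var(x, dim=(1, 2), keepdim=True).shape)   # torch.Size([2, 1, 1])

# keepdim=False drops those axes; such nodes are left to the default lowering.
print(torch.var(x, dim=(1, 2), keepdim=False).shape)  # torch.Size([2])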
backends/arm/operators/op_full.py

Lines changed: 4 additions & 2 deletions
@@ -52,5 +52,7 @@ def define_node(
             dtype = ts.DType.FP32
             data = np.full(shape, value, dtype=np.float32)
 
-        tosa_graph.addConst(shape, dtype, data, "full-const")
-        tosa_graph.addOperator(ts.TosaOp.Op.IDENTITY, ["full-const"], [output.name])
+        tosa_graph.addConst(shape, dtype, data, node.name + "full-const")
+        tosa_graph.addOperator(
+            ts.TosaOp.Op.IDENTITY, [node.name + "full-const"], [output.name]
+        )

backends/arm/quantizer/quantization_annotation/sub_annotator.py

Lines changed: 5 additions & 10 deletions
@@ -6,8 +6,6 @@
 
 # pyre-unsafe
 
-import itertools
-import operator
 from typing import Callable, List, Optional
 
 import torch
@@ -16,7 +14,6 @@
 from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
 from torch.ao.quantization.quantizer import QuantizationAnnotation
 from torch.fx import GraphModule, Node
-from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
 
 
 @register_annotator("sub")
@@ -25,14 +22,12 @@ def _annotate_sub(
     quantization_config: QuantizationConfig,
     filter_fn: Optional[Callable[[Node], bool]] = None,
 ) -> Optional[List[List[Node]]]:
-    sub_partitions = get_source_partitions(
-        gm.graph, [operator.sub, torch.sub, operator.isub], filter_fn
-    )
-    sub_partitions = list(itertools.chain.from_iterable(sub_partitions.values()))
     annotated_partitions = []
-    for sub_partition in sub_partitions:
-        annotated_partitions.append(sub_partition.nodes)
-        sub_node = sub_partition.output_nodes[0]
+    for node in gm.graph.nodes:
+        if node.target not in (torch.ops.aten.sub.Tensor,):
+            continue
+        annotated_partitions.append(node)
+        sub_node = node
         if arm_quantizer_utils.is_annotated(sub_node):
             continue
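The annotator now walks graph nodes directly instead of relying on get_source_partitions. The matching pattern is generic; below is a minimal standalone sketch of the same idea (the toy module and variable names are illustrative, and the real annotator runs on the graph produced by the quantization flow rather than on a fresh export).

import torch

class SubModule(torch.nn.Module):
    def forward(self, x, y):
        return x - y

# torch.export captures the module into an FX graph whose call_function
# nodes target ATen ops, which is the representation the annotator matches on.
ep = torch.export.export(SubModule(), (torch.randn(4), torch.randn(4)))
gm = ep.graph_module

# Same filtering idea as the updated _annotate_sub: iterate nodes and
# keep those whose target is aten.sub.Tensor.
sub_nodes = [
    n
    for n in gm.graph.nodes
    if n.op == "call_function" and n.target in (torch.ops.aten.sub.Tensor,)
]
print(sub_nodes)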