Skip to content

Commit c222a44

Browse files
Arm backend: Add rsqrt lowering (pytorch#5577)
Summary: Change-Id: Ibecbd05907af4971af21e57beb90a205fbcc36c4 Pull Request resolved: pytorch#5577 Reviewed By: digantdesai Differential Revision: D63637418 Pulled By: mergennachin fbshipit-source-id: 824cf70ef2e90ad160eddab56ddfe4aafbaf6bbf
1 parent 09f13c0 commit c222a44

File tree

5 files changed

+184
-1
lines changed

5 files changed

+184
-1
lines changed

backends/arm/arm_partitioner.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
5858
exir_ops.edge.aten.mm.default,
5959
exir_ops.edge.aten.repeat.default,
6060
exir_ops.edge.aten.relu.default,
61+
exir_ops.edge.aten.rsqrt.default,
6162
exir_ops.edge.aten._softmax.default,
6263
exir_ops.edge.aten.slice_copy.Tensor,
6364
exir_ops.edge.aten.sub.Tensor,

backends/arm/operators/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
op_quant,
2929
op_relu,
3030
op_repeat,
31+
op_rsqrt,
3132
op_sigmoid,
3233
op_slice,
3334
op_softmax,

backends/arm/operators/op_rsqrt.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
# Copyright 2024 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
from typing import List
6+
7+
import numpy as np
8+
import serializer.tosa_serializer as ts
9+
import torch
10+
from executorch.backends.arm.operators.node_visitor import (
11+
NodeVisitor,
12+
register_node_visitor,
13+
)
14+
from executorch.backends.arm.tosa_mapping import TosaArg
15+
from executorch.backends.arm.tosa_quant_utils import (
16+
dequantize_value,
17+
get_quant_node_args,
18+
QuantArgs,
19+
quantize_value,
20+
)
21+
from serializer.tosa_serializer import TosaOp
22+
23+
24+
@register_node_visitor
25+
class RsqrtVisitor(NodeVisitor):
26+
target = "aten.rsqrt.default"
27+
28+
def define_node(
29+
self,
30+
node: torch.fx.Node,
31+
tosa_graph: ts.TosaSerializer,
32+
inputs: List[TosaArg],
33+
output: TosaArg,
34+
is_quant_node: bool,
35+
) -> None:
36+
if is_quant_node:
37+
# Assume quantized input is 8 bit.
38+
# Create attribute for 8 bit table lookup.
39+
input_node = node.all_input_nodes[0]
40+
in_quantargs = get_quant_node_args(input_node)
41+
output_node = list(node.users)[0]
42+
out_quantargs = get_quant_node_args(output_node)
43+
table = rsqrt_table_8bit(in_quantargs, out_quantargs)
44+
table_attr = ts.TosaSerializerAttribute()
45+
table_attr.TableAttribute(table)
46+
tosa_graph.addOperator(
47+
TosaOp.Op().TABLE, [inputs[0].name], [output.name], table_attr
48+
)
49+
else:
50+
tosa_graph.addOperator(TosaOp.Op().RSQRT, [inputs[0].name], [output.name])
51+
52+
53+
def rsqrt_table_8bit(in_quantargs: QuantArgs, out_quantargs: QuantArgs):
54+
"""
55+
Returns a table mapping 256 entries to rqsrt([qmin,qmax])
56+
Reference: https://www.mlplatform.org/tosa/tosa_spec.html#_rsqrt
57+
"""
58+
59+
def rqsrt(x):
60+
# Convert quantized input to floating point rqsrt input space.
61+
v = dequantize_value(x, in_quantargs)
62+
# Compute rqsrt.
63+
v = 1 / np.sqrt(v)
64+
# Convert rqsrt output back to quantized space.
65+
return quantize_value(v, out_quantargs)
66+
67+
return [
68+
rqsrt(x)
69+
for x in np.linspace(in_quantargs.qmin, in_quantargs.qmax, 256, dtype=np.int8)
70+
]

backends/arm/quantizer/quantization_annotation/one_to_one_annotator.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,11 @@ def _annotate_one_to_one(
3535
Typical ops are ops implemented with a lookup table.
3636
"""
3737
annotated_partitions = []
38-
one_to_one_ops = (torch.ops.aten.exp.default, torch.ops.aten.log.default)
38+
one_to_one_ops = {
39+
torch.ops.aten.exp.default,
40+
torch.ops.aten.log.default,
41+
torch.ops.aten.rsqrt.default,
42+
}
3943
for node in gm.graph.nodes:
4044
if node.op != "call_function" or node.target not in one_to_one_ops:
4145
continue
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
# Copyright 2024 Arm Limited and/or its affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
#
7+
# Tests the rsqrt op.
8+
#
9+
10+
import unittest
11+
12+
import torch
13+
from executorch.backends.arm.test import common
14+
from executorch.backends.arm.test.tester.arm_tester import ArmTester
15+
from executorch.exir.backend.compile_spec_schema import CompileSpec
16+
from parameterized import parameterized
17+
18+
19+
class TestRsqrt(unittest.TestCase):
20+
class Rsqrt(torch.nn.Module):
21+
test_parameters = [
22+
(torch.ones(1, 10, 10, 10),),
23+
(torch.rand(1, 10, 10, 10),),
24+
(torch.rand(1, 5, 10, 20),),
25+
(torch.rand(5, 10, 20),),
26+
]
27+
28+
def forward(self, x: torch.Tensor):
29+
return x.rsqrt()
30+
31+
def _test_rsqrt_tosa_MI_pipeline(
32+
self, module: torch.nn.Module, test_data: tuple[torch.Tensor]
33+
):
34+
(
35+
ArmTester(
36+
module,
37+
example_inputs=test_data,
38+
compile_spec=common.get_tosa_compile_spec(),
39+
)
40+
.export()
41+
.check_count({"torch.ops.aten.rsqrt.default": 1})
42+
.to_edge()
43+
.partition()
44+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
45+
.to_executorch()
46+
.run_method_and_compare_outputs(inputs=test_data)
47+
)
48+
49+
def _test_rsqrt_tosa_BI_pipeline(
50+
self, module: torch.nn.Module, test_data: tuple[torch.Tensor]
51+
):
52+
(
53+
ArmTester(
54+
module,
55+
example_inputs=test_data,
56+
compile_spec=common.get_tosa_compile_spec(),
57+
)
58+
.quantize()
59+
.export()
60+
.check_count({"torch.ops.aten.rsqrt.default": 1})
61+
.to_edge()
62+
.partition()
63+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
64+
.to_executorch()
65+
.run_method_and_compare_outputs(inputs=test_data)
66+
)
67+
68+
def _test_rsqrt_ethosu_BI_pipeline(
69+
self,
70+
compile_spec: CompileSpec,
71+
module: torch.nn.Module,
72+
test_data: tuple[torch.Tensor],
73+
):
74+
(
75+
ArmTester(
76+
module,
77+
example_inputs=test_data,
78+
compile_spec=compile_spec,
79+
)
80+
.quantize()
81+
.export()
82+
.check_count({"torch.ops.aten.rsqrt.default": 1})
83+
.to_edge()
84+
.partition()
85+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
86+
.to_executorch()
87+
)
88+
89+
@parameterized.expand(Rsqrt.test_parameters)
90+
def test_rsqrt_tosa_MI(self, test_tensor: torch.Tensor):
91+
self._test_rsqrt_tosa_MI_pipeline(self.Rsqrt(), (test_tensor,))
92+
93+
@parameterized.expand(Rsqrt.test_parameters)
94+
def test_rsqrt_tosa_BI(self, test_tensor: torch.Tensor):
95+
self._test_rsqrt_tosa_BI_pipeline(self.Rsqrt(), (test_tensor,))
96+
97+
@parameterized.expand(Rsqrt.test_parameters)
98+
def test_rsqrt_u55_BI(self, test_tensor: torch.Tensor):
99+
self._test_rsqrt_ethosu_BI_pipeline(
100+
common.get_u55_compile_spec(), self.Rsqrt(), (test_tensor,)
101+
)
102+
103+
@parameterized.expand(Rsqrt.test_parameters)
104+
def test_rsqrt_u85_BI(self, test_tensor: torch.Tensor):
105+
self._test_rsqrt_ethosu_BI_pipeline(
106+
common.get_u85_compile_spec(), self.Rsqrt(), (test_tensor,)
107+
)

0 commit comments

Comments
 (0)