From 45b8900aca576536c3dc9af0e874fa67a3d92904 Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Mon, 27 Jan 2025 17:27:05 -0800 Subject: [PATCH 1/2] Update [ghstack-poisoned] --- backends/xnnpack/operators/__init__.py | 1 + backends/xnnpack/operators/op_rsqrt.py | 52 +++++++++++++++++++ backends/xnnpack/partition/config/__init__.py | 2 + .../partition/config/generic_node_configs.py | 7 +++ backends/xnnpack/partition/configs.py | 1 + backends/xnnpack/runtime/XNNCompiler.cpp | 31 +++++++++++ .../xnnpack/serialization/runtime_schema.fbs | 1 + backends/xnnpack/serialization/schema.fbs | 1 + .../serialization/xnnpack_graph_schema.py | 6 +++ backends/xnnpack/test/ops/test_rsqrt.py | 42 +++++++++++++++ 10 files changed, 144 insertions(+) create mode 100644 backends/xnnpack/operators/op_rsqrt.py create mode 100644 backends/xnnpack/test/ops/test_rsqrt.py diff --git a/backends/xnnpack/operators/__init__.py b/backends/xnnpack/operators/__init__.py index b2653a5fdc7..e199c95b0f0 100644 --- a/backends/xnnpack/operators/__init__.py +++ b/backends/xnnpack/operators/__init__.py @@ -37,6 +37,7 @@ op_prelu, op_quantize_per_tensor, op_relu, + op_rsqrt, op_sdpa, op_sigmoid, op_skip_ops, diff --git a/backends/xnnpack/operators/op_rsqrt.py b/backends/xnnpack/operators/op_rsqrt.py new file mode 100644 index 00000000000..451e202549a --- /dev/null +++ b/backends/xnnpack/operators/op_rsqrt.py @@ -0,0 +1,52 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +from typing import Dict + +import torch +from executorch.backends.xnnpack.operators.node_visitor import ( + NodeVisitor, + register_node_visitor, +) +from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import ( + XNNGraph, + XNNReciprocalSquareRoot, + XNode, +) +from executorch.backends.xnnpack.utils.utils import get_input_node + + +@register_node_visitor +class SquareRootVisitor(NodeVisitor): + target = "aten.rsqrt.default" + + def __init__(self, *args) -> None: + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + xnn_graph: XNNGraph, + vals_to_ids: Dict[torch.fx.Node, int], + debug_handle: int, + ) -> None: + self.define_nodes_tensor_inputs_outputs(node, xnn_graph, vals_to_ids) + + # input + input_id = vals_to_ids[get_input_node(node, 0)] + + # output + output_id = vals_to_ids[node] + + ser_node = XNode( + xnode_union=XNNReciprocalSquareRoot( + input_id=input_id, + output_id=output_id, + flags=0, + ), + debug_handle=debug_handle, + ) + xnn_graph.xnodes.append(ser_node) diff --git a/backends/xnnpack/partition/config/__init__.py b/backends/xnnpack/partition/config/__init__.py index ed105dc1f53..0ffb80f4e3c 100644 --- a/backends/xnnpack/partition/config/__init__.py +++ b/backends/xnnpack/partition/config/__init__.py @@ -39,6 +39,7 @@ PermuteConfig, PowConfig, QuantizedPerTensorConfig, + ReciprocalSquareRootConfig, ReLUConfig, # SDPAConfig, TODO: D60553559: preserving SDPA for fairseq fails SigmoidConfig, @@ -92,6 +93,7 @@ PermuteConfig, PowConfig, PreluConfig, + ReciprocalSquareRootConfig, ReLUConfig, # SDPAConfig, TODO: D60553559: preserving SDPA for fairseq fails SigmoidConfig, diff --git a/backends/xnnpack/partition/config/generic_node_configs.py b/backends/xnnpack/partition/config/generic_node_configs.py index dbcb5c92035..c0e707474d3 100644 --- a/backends/xnnpack/partition/config/generic_node_configs.py +++ b/backends/xnnpack/partition/config/generic_node_configs.py @@ -482,6 +482,13 @@ def 
supported_precision_types(self) -> List[ConfigPrecisionType]: return [ConfigPrecisionType.FP32] +class ReciprocalSquareRootConfig(GenericNodePartitionerConfig): + target_name = "rsqrt.default" + + def supported_precision_types(self) -> List[ConfigPrecisionType]: + return [ConfigPrecisionType.FP32] + + class ConstantPadConfig(GenericNodePartitionerConfig): target_name = "constant_pad_nd.default" diff --git a/backends/xnnpack/partition/configs.py b/backends/xnnpack/partition/configs.py index ad4af24d3fc..65fb5ee48e4 100644 --- a/backends/xnnpack/partition/configs.py +++ b/backends/xnnpack/partition/configs.py @@ -63,6 +63,7 @@ exir_ops.edge.aten.avg_pool2d.default, exir_ops.edge.aten.leaky_relu.default, exir_ops.edge.aten.addmm.default, # TODO(T163877189) add constraint for addmm + exir_ops.edge.aten.rsqrt.default, ] SUPPORTED_MODULES = [ diff --git a/backends/xnnpack/runtime/XNNCompiler.cpp b/backends/xnnpack/runtime/XNNCompiler.cpp index dd8c66b3bfb..8d8e9a13152 100644 --- a/backends/xnnpack/runtime/XNNCompiler.cpp +++ b/backends/xnnpack/runtime/XNNCompiler.cpp @@ -1345,6 +1345,36 @@ Error defineSquareRootNode( return Error::Ok; } +/* +Define serialized reciprocal square root node into the subgraph, using the remapped ids +to map the serialized ids, to the new ids generated when defining the +tensor value +*/ +Error defineReciprocalSquareRootNode( + xnn_subgraph_t subgraph_ptr, + const std::unordered_map& remapped_ids, + const NodePtr node, + const fb_xnnpack::XNNGraph* graph) noexcept { + MAYBE_UNUSED(graph); + + auto graph_node = node->xnode_union_as_XNNReciprocalSquareRoot(); + + xnn_status status = xnn_define_reciprocal_square_root( + subgraph_ptr, + remapped_ids.at(graph_node->input_id()), + remapped_ids.at(graph_node->output_id()), + graph_node->flags()); + + ET_CHECK_OR_RETURN_ERROR( + status == xnn_status_success, + Internal, + "Failed to create reciprocal square root node %i with code: %s", + node->debug_handle(), + xnn_status_to_string(status)); + + return 
Error::Ok; +} + /* Define serialized ceiling node into the subgraph, using the remapped ids to map the serialized ids, to the new ids generated when defining the @@ -1904,6 +1934,7 @@ DefineNodeFunc getDefineNodeFunc(fb_xnnpack::XNodeUnion nodeType) { _DEFINE(StaticReshape) _DEFINE(ArgMaxPooling2d) _DEFINE(SquareRoot) + _DEFINE(ReciprocalSquareRoot) _DEFINE(Ceiling) _DEFINE(Hardswish) _DEFINE(LeakyReLU) diff --git a/backends/xnnpack/serialization/runtime_schema.fbs b/backends/xnnpack/serialization/runtime_schema.fbs index 8ba346d9bc0..11cb48430ed 100644 --- a/backends/xnnpack/serialization/runtime_schema.fbs +++ b/backends/xnnpack/serialization/runtime_schema.fbs @@ -138,6 +138,7 @@ union XNodeUnion { XNNBatchMatrixMultiply: _XNNNode2x1, XNNConcatenate5: _XNNCat, XNNConvTranspose2d: _XNNNodeConv, + XNNReciprocalSquareRoot: _XNNNode1x1, } union XValueUnion { diff --git a/backends/xnnpack/serialization/schema.fbs b/backends/xnnpack/serialization/schema.fbs index 81263825ff5..5a43481b98d 100644 --- a/backends/xnnpack/serialization/schema.fbs +++ b/backends/xnnpack/serialization/schema.fbs @@ -134,6 +134,7 @@ union XNodeUnion { XNNBatchMatrixMultiply: _XNNNode2x1, XNNConcatenate5: _XNNCat, XNNConvTranspose2d: _XNNNodeConv, + XNNReciprocalSquareRoot: _XNNNode1x1, } union XValueUnion { diff --git a/backends/xnnpack/serialization/xnnpack_graph_schema.py b/backends/xnnpack/serialization/xnnpack_graph_schema.py index 7c23a75507d..3276dac7869 100644 --- a/backends/xnnpack/serialization/xnnpack_graph_schema.py +++ b/backends/xnnpack/serialization/xnnpack_graph_schema.py @@ -281,6 +281,11 @@ class XNNSquareRoot(XNNNode1x1): pass +@dataclass +class XNNReciprocalSquareRoot(XNNNode1x1): + pass + + @dataclass class XNNCeiling(XNNNode1x1): pass @@ -373,6 +378,7 @@ class XNNScaledDotProductAttention: XNNStaticSlice, XNNScaledDotProductAttention, XNNBatchMatrixMultiply, + XNNReciprocalSquareRoot, ] diff --git a/backends/xnnpack/test/ops/test_rsqrt.py 
b/backends/xnnpack/test/ops/test_rsqrt.py new file mode 100644 index 00000000000..e5d704a0467 --- /dev/null +++ b/backends/xnnpack/test/ops/test_rsqrt.py @@ -0,0 +1,42 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +import torch +from executorch.backends.xnnpack.test.tester import Tester + + +class TestRsqrt(unittest.TestCase): + class Rsqrt(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + x = torch.abs(x) + z = torch.rsqrt(x) + return z + + def _test_rsqrt(self, inputs): + ( + Tester(self.Rsqrt(), inputs) + .export() + .check_count({"torch.ops.aten.rsqrt.default": 1}) + .to_edge_transform_and_lower() + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .check_not(["executorch_exir_dialects_edge__ops_aten_rsqrt_default"]) + .to_executorch() + .serialize() + .run_method_and_compare_outputs() + ) + + def test_fp16_rsqrt(self): + inputs = (torch.randn(20).to(torch.float16),) + self._test_rsqrt(inputs) + + def test_fp32_rsqrt(self): + inputs = (torch.randn(20),) + self._test_rsqrt(inputs) From 3e124c02aec614e3d9f2e1165f57aa459ec36397 Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Tue, 28 Jan 2025 07:54:09 -0800 Subject: [PATCH 2/2] Update [ghstack-poisoned] --- backends/xnnpack/operators/op_rsqrt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backends/xnnpack/operators/op_rsqrt.py b/backends/xnnpack/operators/op_rsqrt.py index 451e202549a..6ff8ab42c68 100644 --- a/backends/xnnpack/operators/op_rsqrt.py +++ b/backends/xnnpack/operators/op_rsqrt.py @@ -20,7 +20,7 @@ @register_node_visitor -class SquareRootVisitor(NodeVisitor): +class ReciprocalSquareRootVisitor(NodeVisitor): target = "aten.rsqrt.default" def __init__(self, *args) -> None: