diff --git a/backends/xnnpack/operators/__init__.py b/backends/xnnpack/operators/__init__.py index 35ce639978d..7eafe91eaea 100644 --- a/backends/xnnpack/operators/__init__.py +++ b/backends/xnnpack/operators/__init__.py @@ -19,6 +19,7 @@ op_dynamic_dequantize_ops, op_dynamic_quantize_ops, op_elu, + op_exp, op_floor, op_gelu, op_hardswish, diff --git a/backends/xnnpack/operators/op_exp.py b/backends/xnnpack/operators/op_exp.py new file mode 100644 index 00000000000..e9c6444ad5a --- /dev/null +++ b/backends/xnnpack/operators/op_exp.py @@ -0,0 +1,52 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Dict + +import torch +from executorch.backends.xnnpack.operators.node_visitor import ( + NodeVisitor, + register_node_visitor, +) +from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import ( + XNNExp, + XNNGraph, + XNode, +) +from executorch.backends.xnnpack.utils.utils import get_input_node + + +@register_node_visitor +class ExpVisitor(NodeVisitor): + target = "aten.exp.default" + + def __init__(self, *args) -> None: + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + xnn_graph: XNNGraph, + vals_to_ids: Dict[torch.fx.Node, int], + debug_handle: int, + ) -> None: + self.define_nodes_tensor_inputs_outputs(node, xnn_graph, vals_to_ids) + + # input + input_id = vals_to_ids[get_input_node(node, 0)] + + # output + output_id = vals_to_ids[node] + + ser_node = XNode( + xnode_union=XNNExp( + input_id=input_id, + output_id=output_id, + flags=0, + ), + debug_handle=debug_handle, + ) + xnn_graph.xnodes.append(ser_node) diff --git a/backends/xnnpack/partition/config/__init__.py b/backends/xnnpack/partition/config/__init__.py index 207d2cfd713..1ccde96ec16 100644 --- a/backends/xnnpack/partition/config/__init__.py +++ b/backends/xnnpack/partition/config/__init__.py @@ -25,10 +25,11 @@ ConstantPadConfig, DeQuantizedPerTensorConfig, DivConfig, + # EluConfig, + ExpConfig, FloorConfig, GeluConfig, HardswishConfig, - # EluConfig, HardtanhConfig, LeakyReLUConfig, LogConfig, @@ -80,6 +81,7 @@ ClampConfig, DivConfig, # EluConfig, # Waiting for PyTorch Pin Update + ExpConfig, FloorConfig, GeluConfig, HardtanhConfig, diff --git a/backends/xnnpack/partition/config/generic_node_configs.py b/backends/xnnpack/partition/config/generic_node_configs.py index bfaa232cdd4..dd4f239ca48 100644 --- a/backends/xnnpack/partition/config/generic_node_configs.py +++ b/backends/xnnpack/partition/config/generic_node_configs.py @@ -348,6 +348,13 @@ def get_original_aten(self) -> Optional[torch._ops.OpOverload]: return torch.ops.aten.upsample_bilinear2d.vec +class ExpConfig(GenericNodePartitionerConfig): + target_name = "exp.default" + + def supported_precision_types(self) -> List[ConfigPrecisionType]: + return [ConfigPrecisionType.FP32] + + class FloorConfig(GenericNodePartitionerConfig): target_name = "floor.default" diff --git a/backends/xnnpack/partition/configs.py b/backends/xnnpack/partition/configs.py index 246f571c9c8..eb31384c7ec 100644 --- a/backends/xnnpack/partition/configs.py +++ b/backends/xnnpack/partition/configs.py @@ -67,6 +67,7 @@ exir_ops.edge.aten.log.default, exir_ops.edge.aten.gelu.default, exir_ops.edge.aten.tanh.default, + exir_ops.edge.aten.exp.default, ] SUPPORTED_MODULES = [ diff --git a/backends/xnnpack/runtime/XNNCompiler.cpp b/backends/xnnpack/runtime/XNNCompiler.cpp index b724ab1a9d9..e41933324e2 100644 --- a/backends/xnnpack/runtime/XNNCompiler.cpp +++ b/backends/xnnpack/runtime/XNNCompiler.cpp @@ -1162,6 +1162,36 @@ Error defineArgMaxPooling2dNode( return Error::Ok; } +/* +Define serialized exp node into the subgraph, using the remapped ids +to map the serialized ids, to the new ids generated when defining the +tensor value +*/ +Error defineExpNode( + xnn_subgraph_t subgraph_ptr, + const std::unordered_map& remapped_ids, + const NodePtr node, + const fb_xnnpack::XNNGraph* graph) noexcept { + MAYBE_UNUSED(graph); + + auto graph_node = node->xnode_union_as_XNNExp(); + + xnn_status status = xnn_define_exp( + subgraph_ptr, + remapped_ids.at(graph_node->input_id()), + remapped_ids.at(graph_node->output_id()), + graph_node->flags()); + + ET_CHECK_OR_RETURN_ERROR( + status == xnn_status_success, + Internal, + "Failed to create exp node %i with code: %s", + node->debug_handle(), + xnn_status_to_string(status)); + + return Error::Ok; +} + /* Define serialized tanh node into the subgraph, using the remapped ids to map the serialized ids, to the new ids generated when defining the @@ -1733,6 +1763,7 @@ DefineNodeFunc getDefineNodeFunc(fb_xnnpack::XNodeUnion nodeType) { _DEFINE(Clamp) _DEFINE(LeakyReLU) _DEFINE(ELU) + _DEFINE(Exp) _DEFINE(Abs) _DEFINE(Floor) _DEFINE(PReLU) diff --git a/backends/xnnpack/serialization/runtime_schema.fbs b/backends/xnnpack/serialization/runtime_schema.fbs index eea4cdb8b86..950318f18dc 100644 --- a/backends/xnnpack/serialization/runtime_schema.fbs +++ b/backends/xnnpack/serialization/runtime_schema.fbs @@ -155,6 +155,7 @@ union XNodeUnion { XNNLog: _XNNNode1x1, XNNGelu: _XNNNode1x1, XNNTanh: _XNNNode1x1, + XNNExp: _XNNNode1x1, } union XValueUnion { diff --git a/backends/xnnpack/serialization/schema.fbs b/backends/xnnpack/serialization/schema.fbs index ed444005c64..a4efc627cbb 100644 --- a/backends/xnnpack/serialization/schema.fbs +++ b/backends/xnnpack/serialization/schema.fbs @@ -151,6 +151,7 @@ union XNodeUnion { XNNLog: _XNNNode1x1, XNNGelu: _XNNNode1x1, XNNTanh: _XNNNode1x1, + XNNExp: _XNNNode1x1, } union XValueUnion { diff --git a/backends/xnnpack/serialization/xnnpack_graph_schema.py b/backends/xnnpack/serialization/xnnpack_graph_schema.py index 106eb6b81d9..99b64708f86 100644 --- a/backends/xnnpack/serialization/xnnpack_graph_schema.py +++ b/backends/xnnpack/serialization/xnnpack_graph_schema.py @@ -291,6 +291,11 @@ class XNNCeiling(XNNNode1x1): pass +@dataclass +class XNNExp(XNNNode1x1): + pass + + @dataclass class XNNGelu(XNNNode1x1): pass diff --git a/backends/xnnpack/test/ops/test_exp.py b/backends/xnnpack/test/ops/test_exp.py new file mode 100644 index 00000000000..8646a26cc62 --- /dev/null +++ b/backends/xnnpack/test/ops/test_exp.py @@ -0,0 +1,43 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +import torch +from executorch.backends.xnnpack.test.tester import Tester + + +class TestExp(unittest.TestCase): + def setUp(self): + torch._dynamo.reset() + + class Exp(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.exp(x) + + def run_exp_test(self, inputs): + ( + Tester(self.Exp(), inputs) + .export() + .check_count({"torch.ops.aten.exp.default": 1}) + .to_edge_transform_and_lower() + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .check_not(["executorch_exir_dialects_edge__ops_aten_exp_default"]) + .to_executorch() + .serialize() + .run_method_and_compare_outputs() + ) + + def test_fp16_exp(self): + inputs = (torch.randn(20).to(torch.float16),) + self.run_exp_test(inputs) + + def test_fp32_exp(self): + inputs = (torch.randn(20),) + self.run_exp_test(inputs)