diff --git a/backends/xnnpack/operators/__init__.py b/backends/xnnpack/operators/__init__.py index f056ad8b086..ae3effc2ce7 100644 --- a/backends/xnnpack/operators/__init__.py +++ b/backends/xnnpack/operators/__init__.py @@ -24,6 +24,7 @@ op_hardtanh, op_leaky_relu, op_linear, + op_log, op_matrix_multiplication, op_max_dim, op_max_pool2d, diff --git a/backends/xnnpack/operators/op_log.py b/backends/xnnpack/operators/op_log.py new file mode 100644 index 00000000000..edafadf4a27 --- /dev/null +++ b/backends/xnnpack/operators/op_log.py @@ -0,0 +1,52 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Dict + +import torch +from executorch.backends.xnnpack.operators.node_visitor import ( + NodeVisitor, + register_node_visitor, +) +from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import ( + XNNGraph, + XNNLog, + XNode, +) +from executorch.backends.xnnpack.utils.utils import get_input_node + + +@register_node_visitor +class LogVisitor(NodeVisitor): + target = "aten.log.default" + + def __init__(self, *args) -> None: + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + xnn_graph: XNNGraph, + vals_to_ids: Dict[torch.fx.Node, int], + debug_handle: int, + ) -> None: + self.define_nodes_tensor_inputs_outputs(node, xnn_graph, vals_to_ids) + + # input + input_id = vals_to_ids[get_input_node(node, 0)] + + # output + output_id = vals_to_ids[node] + + ser_node = XNode( + xnode_union=XNNLog( + input_id=input_id, + output_id=output_id, + flags=0, + ), + debug_handle=debug_handle, + ) + xnn_graph.xnodes.append(ser_node) diff --git a/backends/xnnpack/partition/config/__init__.py b/backends/xnnpack/partition/config/__init__.py index 0ffb80f4e3c..a8bc9e6e4a0 100644 --- a/backends/xnnpack/partition/config/__init__.py +++ b/backends/xnnpack/partition/config/__init__.py @@ -30,6 +30,7 @@ # EluConfig, HardtanhConfig, LeakyReLUConfig, + LogConfig, MaximumConfig, MaxPool2dConfig, MeanDimConfig, @@ -82,6 +83,7 @@ HardswishConfig, LeakyReLUConfig, LinearConfig, + LogConfig, MaxDimConfig, MaximumConfig, MaxPool2dConfig, diff --git a/backends/xnnpack/partition/config/generic_node_configs.py b/backends/xnnpack/partition/config/generic_node_configs.py index c0e707474d3..e16698a3ae6 100644 --- a/backends/xnnpack/partition/config/generic_node_configs.py +++ b/backends/xnnpack/partition/config/generic_node_configs.py @@ -357,6 +357,13 @@ def supported_precision_types(self) -> List[ConfigPrecisionType]: return [ConfigPrecisionType.FP32] +class LogConfig(GenericNodePartitionerConfig): + target_name = "log.default" + + def supported_precision_types(self) -> List[ConfigPrecisionType]: + return [ConfigPrecisionType.FP32] + + class MeanDimConfig(GenericNodePartitionerConfig): target_name = "mean.dim" diff --git a/backends/xnnpack/partition/configs.py b/backends/xnnpack/partition/configs.py index 65fb5ee48e4..60b270134bd 100644 --- a/backends/xnnpack/partition/configs.py +++ b/backends/xnnpack/partition/configs.py @@ -64,6 +64,7 @@ exir_ops.edge.aten.leaky_relu.default, exir_ops.edge.aten.addmm.default, # TODO(T163877189) add constraint for addmm exir_ops.edge.aten.rsqrt.default, + exir_ops.edge.aten.log.default, ] SUPPORTED_MODULES = [ diff --git a/backends/xnnpack/runtime/XNNCompiler.cpp b/backends/xnnpack/runtime/XNNCompiler.cpp index 3b78c1a0b84..445744e9918 100644 --- a/backends/xnnpack/runtime/XNNCompiler.cpp +++ b/backends/xnnpack/runtime/XNNCompiler.cpp @@ -1418,6 +1418,36 @@ Error defineReciprocalSquareRootNode( return Error::Ok; } +/* +Define serialized log node into the subgraph, using the remapped ids +to map the serialized ids, to the new ids generated when defining the +tensor value +*/ +Error defineLogNode( + xnn_subgraph_t subgraph_ptr, + const std::unordered_map& remapped_ids, + const NodePtr node, + const fb_xnnpack::XNNGraph* graph) noexcept { + MAYBE_UNUSED(graph); + + auto graph_node = node->xnode_union_as_XNNLog(); + + xnn_status status = xnn_define_log( + subgraph_ptr, + remapped_ids.at(graph_node->input_id()), + remapped_ids.at(graph_node->output_id()), + graph_node->flags()); + + ET_CHECK_OR_RETURN_ERROR( + status == xnn_status_success, + Internal, + "Failed to create log node %i with code: %s", + node->debug_handle(), + xnn_status_to_string(status)); + + return Error::Ok; +} + /* Define serialized ceiling node into the subgraph, using the remapped ids to map the serialized ids, to the new ids generated when defining the @@ -1981,6 +2011,7 @@ DefineNodeFunc getDefineNodeFunc(fb_xnnpack::XNodeUnion nodeType) { _DEFINE(Ceiling) _DEFINE(Hardswish) _DEFINE(LeakyReLU) + _DEFINE(Log) _DEFINE(Maximum) _DEFINE(Negate) _DEFINE(Square) diff --git a/backends/xnnpack/serialization/runtime_schema.fbs b/backends/xnnpack/serialization/runtime_schema.fbs index 75074107c55..f10ba3d1b81 100644 --- a/backends/xnnpack/serialization/runtime_schema.fbs +++ b/backends/xnnpack/serialization/runtime_schema.fbs @@ -139,6 +139,7 @@ union XNodeUnion { XNNConcatenate5: _XNNCat, XNNConvTranspose2d: _XNNNodeConv, XNNReciprocalSquareRoot: _XNNNode1x1, + XNNLog: _XNNNode1x1, } union XValueUnion { diff --git a/backends/xnnpack/serialization/schema.fbs b/backends/xnnpack/serialization/schema.fbs index 193656c30b1..565eb4c3bba 100644 --- a/backends/xnnpack/serialization/schema.fbs +++ b/backends/xnnpack/serialization/schema.fbs @@ -135,6 +135,7 @@ union XNodeUnion { XNNConcatenate5: _XNNCat, XNNConvTranspose2d: _XNNNodeConv, XNNReciprocalSquareRoot: _XNNNode1x1, + XNNLog: _XNNNode1x1, } union XValueUnion { diff --git a/backends/xnnpack/serialization/xnnpack_graph_schema.py b/backends/xnnpack/serialization/xnnpack_graph_schema.py index 3cb572c66ef..2a3ccaf2a0a 100644 --- a/backends/xnnpack/serialization/xnnpack_graph_schema.py +++ b/backends/xnnpack/serialization/xnnpack_graph_schema.py @@ -309,6 +309,11 @@ class XNNLeakyReLU: flags: int +@dataclass +class XNNLog(XNNNode1x1): + pass + + @dataclass class XNNMaximum(XNNNode2x1): pass @@ -379,6 +384,7 @@ class XNNScaledDotProductAttention: XNNScaledDotProductAttention, XNNBatchMatrixMultiply, XNNReciprocalSquareRoot, + XNNLog, ] diff --git a/backends/xnnpack/test/ops/test_log.py b/backends/xnnpack/test/ops/test_log.py new file mode 100644 index 00000000000..a0670158b7f --- /dev/null +++ b/backends/xnnpack/test/ops/test_log.py @@ -0,0 +1,45 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +import torch +from executorch.backends.xnnpack.test.tester import Tester + + +class TestLog(unittest.TestCase): + def setUp(self): + torch._dynamo.reset() + + class Log(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + x = torch.abs(x) + z = torch.log(x) + return z + + def run_log_test(self, inputs): + ( + Tester(self.Log(), inputs) + .export() + .check_count({"torch.ops.aten.log.default": 1}) + .to_edge_transform_and_lower() + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .check_not(["executorch_exir_dialects_edge__ops_aten_log_default"]) + .to_executorch() + .serialize() + .run_method_and_compare_outputs() + ) + + def test_fp16_log(self): + inputs = (torch.randn(20).to(torch.float16),) + self.run_log_test(inputs) + + def test_fp32_log(self): + inputs = (torch.randn(20),) + self.run_log_test(inputs)