
Commit 966b643
Support exp op in XNNPACK backend (#11803)
### Summary

This PR adds support for the `exp` operator in ExecuTorch via XNNPACK, enabling optimized execution of `torch.exp` on the XNNPACK backend. The implementation includes updates to operator configuration, serialization, and runtime handling. The `exp` operator is now registered in the XNNPACK partitioner configs and mapped to XNNPACK's `xnn_define_exp` subgraph API in the compiler.

### Test plan

I added a new test class, `TestExp`, that exercises a simple torch module containing an `exp` op. The tests assert that the op is lowered to the XNNPACK delegate rather than executed by the default torch `exp` kernel.
1 parent c70047f commit 966b643
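
For context, here is a minimal end-to-end sketch of lowering a `torch.exp` model through this path. It is not part of the PR; the entry points (`torch.export.export`, `to_edge_transform_and_lower`, `XnnpackPartitioner`) are assumed from ExecuTorch's current export APIs.

```python
# Minimal lowering sketch (assumed API surface, not part of this PR).
import torch

from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
    XnnpackPartitioner,
)
from executorch.exir import to_edge_transform_and_lower


class ExpModel(torch.nn.Module):
    def forward(self, x):
        return torch.exp(x)


example_inputs = (torch.randn(20),)
exported = torch.export.export(ExpModel(), example_inputs)

# With this PR, the partitioner can claim aten.exp.default, so the exp node
# is delegated to XNNPACK instead of falling back to the portable kernel.
edge = to_edge_transform_and_lower(exported, partitioner=[XnnpackPartitioner()])
executorch_program = edge.to_executorch()
```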

File tree

10 files changed: +145 −1 lines


backends/xnnpack/operators/__init__.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -19,6 +19,7 @@
     op_dynamic_dequantize_ops,
     op_dynamic_quantize_ops,
     op_elu,
+    op_exp,
     op_floor,
     op_gelu,
     op_hardswish,
```
backends/xnnpack/operators/op_exp.py

Lines changed: 52 additions & 0 deletions

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict

import torch
from executorch.backends.xnnpack.operators.node_visitor import (
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
    XNNExp,
    XNNGraph,
    XNode,
)
from executorch.backends.xnnpack.utils.utils import get_input_node


@register_node_visitor
class ExpVisitor(NodeVisitor):
    target = "aten.exp.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        xnn_graph: XNNGraph,
        vals_to_ids: Dict[torch.fx.Node, int],
        debug_handle: int,
    ) -> None:
        self.define_nodes_tensor_inputs_outputs(node, xnn_graph, vals_to_ids)

        # input
        input_id = vals_to_ids[get_input_node(node, 0)]

        # output
        output_id = vals_to_ids[node]

        ser_node = XNode(
            xnode_union=XNNExp(
                input_id=input_id,
                output_id=output_id,
                flags=0,
            ),
            debug_handle=debug_handle,
        )
        xnn_graph.xnodes.append(ser_node)
```
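
The visitor follows the same shape as XNNPACK's other unary-op visitors: a class keyed by its `target` string, registered via a decorator. As a rough sketch of the decorator-registry pattern involved (the names `_VISITOR_REGISTRY` and `get_visitor` are illustrative, not ExecuTorch's actual `node_visitor` internals):

```python
# Illustrative decorator-registry pattern only; _VISITOR_REGISTRY and
# get_visitor are hypothetical names, not ExecuTorch's real internals.
from typing import Dict, Type

_VISITOR_REGISTRY: Dict[str, Type] = {}


def register_node_visitor(cls):
    # Key each visitor class by the ATen target string it serializes,
    # e.g. "aten.exp.default" -> ExpVisitor.
    _VISITOR_REGISTRY[cls.target] = cls
    return cls


def get_visitor(target: str):
    # Serialization looks up the visitor matching each FX node's target.
    return _VISITOR_REGISTRY[target]
```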

backends/xnnpack/partition/config/__init__.py

Lines changed: 3 additions & 1 deletion

```diff
@@ -25,10 +25,11 @@
     ConstantPadConfig,
     DeQuantizedPerTensorConfig,
     DivConfig,
+    # EluConfig,
+    ExpConfig,
     FloorConfig,
     GeluConfig,
     HardswishConfig,
-    # EluConfig,
     HardtanhConfig,
     LeakyReLUConfig,
     LogConfig,
@@ -80,6 +81,7 @@
     ClampConfig,
     DivConfig,
     # EluConfig,  # Waiting for PyTorch Pin Update
+    ExpConfig,
     FloorConfig,
     GeluConfig,
     HardtanhConfig,
```

backends/xnnpack/partition/config/generic_node_configs.py

Lines changed: 7 additions & 0 deletions

```diff
@@ -348,6 +348,13 @@ def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
         return torch.ops.aten.upsample_bilinear2d.vec


+class ExpConfig(GenericNodePartitionerConfig):
+    target_name = "exp.default"
+
+    def supported_precision_types(self) -> List[ConfigPrecisionType]:
+        return [ConfigPrecisionType.FP32]
+
+
 class FloorConfig(GenericNodePartitionerConfig):
     target_name = "floor.default"
```
backends/xnnpack/partition/configs.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -67,6 +67,7 @@
     exir_ops.edge.aten.log.default,
     exir_ops.edge.aten.gelu.default,
     exir_ops.edge.aten.tanh.default,
+    exir_ops.edge.aten.exp.default,
 ]

 SUPPORTED_MODULES = [
```

backends/xnnpack/runtime/XNNCompiler.cpp

Lines changed: 31 additions & 0 deletions

```diff
@@ -1162,6 +1162,36 @@ Error defineArgMaxPooling2dNode(
   return Error::Ok;
 }

+/*
+Define serialized exp node into the subgraph, using the remapped ids
+to map the serialized ids, to the new ids generated when defining the
+tensor value
+*/
+Error defineExpNode(
+    xnn_subgraph_t subgraph_ptr,
+    const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
+    const NodePtr node,
+    const fb_xnnpack::XNNGraph* graph) noexcept {
+  MAYBE_UNUSED(graph);
+
+  auto graph_node = node->xnode_union_as_XNNExp();
+
+  xnn_status status = xnn_define_exp(
+      subgraph_ptr,
+      remapped_ids.at(graph_node->input_id()),
+      remapped_ids.at(graph_node->output_id()),
+      graph_node->flags());
+
+  ET_CHECK_OR_RETURN_ERROR(
+      status == xnn_status_success,
+      Internal,
+      "Failed to create exp node %i with code: %s",
+      node->debug_handle(),
+      xnn_status_to_string(status));
+
+  return Error::Ok;
+}
+
 /*
 Define serialized tanh node into the subgraph, using the remapped ids
 to map the serialized ids, to the new ids generated when defining the
@@ -1733,6 +1763,7 @@ DefineNodeFunc getDefineNodeFunc(fb_xnnpack::XNodeUnion nodeType) {
   _DEFINE(Clamp)
   _DEFINE(LeakyReLU)
   _DEFINE(ELU)
+  _DEFINE(Exp)
   _DEFINE(Abs)
   _DEFINE(Floor)
   _DEFINE(PReLU)
```

backends/xnnpack/serialization/runtime_schema.fbs

Lines changed: 1 addition & 0 deletions

```diff
@@ -155,6 +155,7 @@ union XNodeUnion {
   XNNLog: _XNNNode1x1,
   XNNGelu: _XNNNode1x1,
   XNNTanh: _XNNNode1x1,
+  XNNExp: _XNNNode1x1,
 }

 union XValueUnion {
```

backends/xnnpack/serialization/schema.fbs

Lines changed: 1 addition & 0 deletions

```diff
@@ -151,6 +151,7 @@ union XNodeUnion {
   XNNLog: _XNNNode1x1,
   XNNGelu: _XNNNode1x1,
   XNNTanh: _XNNNode1x1,
+  XNNExp: _XNNNode1x1,
 }

 union XValueUnion {
```

backends/xnnpack/serialization/xnnpack_graph_schema.py

Lines changed: 5 additions & 0 deletions

```diff
@@ -291,6 +291,11 @@ class XNNCeiling(XNNNode1x1):
     pass


+@dataclass
+class XNNExp(XNNNode1x1):
+    pass
+
+
 @dataclass
 class XNNGelu(XNNNode1x1):
     pass
```
Lines changed: 43 additions & 0 deletions

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch.backends.xnnpack.test.tester import Tester


class TestExp(unittest.TestCase):
    def setUp(self):
        torch._dynamo.reset()

    class Exp(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x):
            return torch.exp(x)

    def run_exp_test(self, inputs):
        (
            Tester(self.Exp(), inputs)
            .export()
            .check_count({"torch.ops.aten.exp.default": 1})
            .to_edge_transform_and_lower()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .check_not(["executorch_exir_dialects_edge__ops_aten_exp_default"])
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs()
        )

    def test_fp16_exp(self):
        inputs = (torch.randn(20).to(torch.float16),)
        self.run_exp_test(inputs)

    def test_fp32_exp(self):
        inputs = (torch.randn(20),)
        self.run_exp_test(inputs)
```
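
To run the new cases locally, something like the following should work, assuming the file lands under ExecuTorch's XNNPACK test tree (the module path below is inferred; the diff header omits the filename):

```python
# Hypothetical invocation; the module path is inferred from ExecuTorch's
# test layout and is not stated in this diff.
import unittest

suite = unittest.defaultTestLoader.loadTestsFromName(
    "executorch.backends.xnnpack.test.ops.test_exp.TestExp"
)
unittest.TextTestRunner(verbosity=2).run(suite)
```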
