Skip to content

Commit 803fd80

Browse files
committed
Support exp op in XNNPACK backend
1 parent 5136175 commit 803fd80

File tree

12 files changed

+506
-3
lines changed

12 files changed

+506
-3
lines changed

backends/xnnpack/operators/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
op_dynamic_dequantize_ops,
2020
op_dynamic_quantize_ops,
2121
op_elu,
22+
op_exp,
2223
op_floor,
2324
op_gelu,
2425
op_hardswish,
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from typing import Dict
8+
9+
import torch
10+
from executorch.backends.xnnpack.operators.node_visitor import (
11+
NodeVisitor,
12+
register_node_visitor,
13+
)
14+
from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
15+
XNNExp,
16+
XNNGraph,
17+
XNode,
18+
)
19+
from executorch.backends.xnnpack.utils.utils import get_input_node
20+
21+
22+
@register_node_visitor
class ExpVisitor(NodeVisitor):
    """Serializes an ``aten.exp.default`` FX node into an XNNPACK ``XNNExp`` XNode."""

    target = "aten.exp.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        xnn_graph: XNNGraph,
        vals_to_ids: Dict[torch.fx.Node, int],
        debug_handle: int,
    ) -> None:
        """Register the node's tensors and append the serialized exp XNode.

        Args:
            node: FX node being lowered; exp takes a single tensor operand.
            xnn_graph: Graph under construction; the new XNode is appended
                to its ``xnodes`` list.
            vals_to_ids: Mapping from FX nodes to XNNPACK tensor-value ids.
            debug_handle: Debug id carried onto the serialized node.
        """
        # Make sure both operand and result tensors have serialized ids
        # before we reference them below.
        self.define_nodes_tensor_inputs_outputs(node, xnn_graph, vals_to_ids)

        # Tensor-value id of the single input operand.
        exp_input = vals_to_ids[get_input_node(node, 0)]
        # Tensor-value id of this node's output.
        exp_output = vals_to_ids[node]

        xnn_graph.xnodes.append(
            XNode(
                xnode_union=XNNExp(
                    input_id=exp_input,
                    output_id=exp_output,
                    flags=0,
                ),
                debug_handle=debug_handle,
            )
        )

backends/xnnpack/partition/config/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,11 @@
2525
ConstantPadConfig,
2626
DeQuantizedPerTensorConfig,
2727
DivConfig,
28+
# EluConfig,
29+
ExpConfig,
2830
FloorConfig,
2931
GeluConfig,
3032
HardswishConfig,
31-
# EluConfig,
3233
HardtanhConfig,
3334
LeakyReLUConfig,
3435
LogConfig,
@@ -79,6 +80,7 @@
7980
ClampConfig,
8081
DivConfig,
8182
# EluConfig, # Waiting for PyTorch Pin Update
83+
ExpConfig,
8284
FloorConfig,
8385
GeluConfig,
8486
HardtanhConfig,

backends/xnnpack/partition/config/generic_node_configs.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,13 @@ def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
336336
return torch.ops.aten.upsample_bilinear2d.vec
337337

338338

339+
class ExpConfig(GenericNodePartitionerConfig):
    """Partitioner config that marks ``exp.default`` nodes as delegatable to XNNPACK."""

    # Edge-dialect op name this config matches.
    target_name = "exp.default"

    def supported_precision_types(self) -> List[ConfigPrecisionType]:
        """Exp is only delegated in fp32; no quantized precision is declared."""
        return [ConfigPrecisionType.FP32]
344+
345+
339346
class FloorConfig(GenericNodePartitionerConfig):
340347
target_name = "floor.default"
341348

backends/xnnpack/partition/configs.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@
6666
exir_ops.edge.aten.rsqrt.default,
6767
exir_ops.edge.aten.log.default,
6868
exir_ops.edge.aten.gelu.default,
69+
exir_ops.edge.aten.exp.default,
6970
]
7071

7172
SUPPORTED_MODULES = [

0 commit comments

Comments
 (0)