10 files changed, +145 -1 lines changed
In the operator visitor registry, the new op module is imported alongside the other unary ops:

     op_dynamic_dequantize_ops,
     op_dynamic_quantize_ops,
     op_elu,
+    op_exp,
     op_floor,
     op_gelu,
     op_hardswish,
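Importing op_exp here is what makes the new operator visible: the @register_node_visitor decorator in the module (shown next) runs at import time, so without this line the visitor would never be registered.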
New file: the XNNPACK node visitor that serializes aten.exp.default into the delegate graph:

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict

import torch
from executorch.backends.xnnpack.operators.node_visitor import (
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
    XNNExp,
    XNNGraph,
    XNode,
)
from executorch.backends.xnnpack.utils.utils import get_input_node


@register_node_visitor
class ExpVisitor(NodeVisitor):
    target = "aten.exp.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        xnn_graph: XNNGraph,
        vals_to_ids: Dict[torch.fx.Node, int],
        debug_handle: int,
    ) -> None:
        self.define_nodes_tensor_inputs_outputs(node, xnn_graph, vals_to_ids)

        # input
        input_id = vals_to_ids[get_input_node(node, 0)]

        # output
        output_id = vals_to_ids[node]

        ser_node = XNode(
            xnode_union=XNNExp(
                input_id=input_id,
                output_id=output_id,
                flags=0,
            ),
            debug_handle=debug_handle,
        )
        xnn_graph.xnodes.append(ser_node)
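The visitor keys on the target string "aten.exp.default". As a quick, self-contained illustration of where that target comes from (plain PyTorch, no ExecuTorch needed; the module M is purely illustrative):

import torch

class M(torch.nn.Module):
    def forward(self, x):
        return torch.exp(x)

# torch.export traces the module into an ATen-level graph; the printed
# graph contains a node whose target is torch.ops.aten.exp.default,
# the overload name ExpVisitor matches on.
ep = torch.export.export(M(), (torch.randn(4),))
print(ep.graph)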
In the partitioner config registry, ExpConfig is imported (the commented-out EluConfig entry simply moves so the list stays alphabetized) and appended to the enabled configs:

     ConstantPadConfig,
     DeQuantizedPerTensorConfig,
     DivConfig,
+    # EluConfig,
+    ExpConfig,
     FloorConfig,
     GeluConfig,
     HardswishConfig,
-    # EluConfig,
     HardtanhConfig,
     LeakyReLUConfig,
     LogConfig,
...
     ClampConfig,
     DivConfig,
     # EluConfig,  # Waiting for PyTorch Pin Update
+    ExpConfig,
     FloorConfig,
     GeluConfig,
     HardtanhConfig,
The config itself, defined next to the other unary-op configs:

@@ -336,6 +336,13 @@ def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
         return torch.ops.aten.upsample_bilinear2d.vec
 
 
+class ExpConfig(GenericNodePartitionerConfig):
+    target_name = "exp.default"
+
+    def supported_precision_types(self) -> List[ConfigPrecisionType]:
+        return [ConfigPrecisionType.FP32]
+
+
 class FloorConfig(GenericNodePartitionerConfig):
     target_name = "floor.default"
 
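ExpConfig claims only ConfigPrecisionType.FP32; no quantized variant is registered in this diff. Note that the fp16 test at the bottom still expects the node to be delegated, which suggests the FP32 precision type in this partitioner covers half-precision floating point as well.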
The exp overload also joins the backend's list of edge-dialect ops:

     exir_ops.edge.aten.sigmoid.default,
     exir_ops.edge.aten._softmax.default,
     exir_ops.edge.aten.cat.default,
+    exir_ops.edge.aten.exp.default,
     exir_ops.edge.aten.elu.default,
     exir_ops.edge.aten.avg_pool2d.default,
     exir_ops.edge.aten.leaky_relu.default,
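This list appears to keep these ops intact through edge lowering, rather than letting them decompose, so the partitioner can claim them whole. To inspect the exact overload the new entry refers to, assuming an executorch install and that exir_ops is the alias conventionally imported from executorch.exir.dialects._ops (the import itself is not shown in this diff):

from executorch.exir.dialects._ops import ops as exir_ops

# the edge-dialect counterpart of torch.ops.aten.exp.default
print(exir_ops.edge.aten.exp.default)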
On the runtime side, a definition function deserializes the node and defines it on the XNNPACK subgraph:

@@ -1723,6 +1723,36 @@ Error defineELUNode(
   return Error::Ok;
 }
 
+/*
+Defines a serialized exp node into the subgraph, using the remapped ids
+to map the serialized ids to the new ids generated when defining the
+tensor value.
+*/
+Error defineExpNode(
+    xnn_subgraph_t subgraph_ptr,
+    const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
+    const NodePtr node,
+    const fb_xnnpack::XNNGraph* graph) noexcept {
+  MAYBE_UNUSED(graph);
+
+  auto graph_node = node->xnode_union_as_XNNExp();
+
+  xnn_status status = xnn_define_exp(
+      subgraph_ptr,
+      remapped_ids.at(graph_node->input_id()),
+      remapped_ids.at(graph_node->output_id()),
+      graph_node->flags());
+
+  ET_CHECK_OR_RETURN_ERROR(
+      status == xnn_status_success,
+      Internal,
+      "Failed to create exp node %i with code: %s",
+      node->debug_handle(),
+      xnn_status_to_string(status));
+
+  return Error::Ok;
+}
+
 /*
 Defines absolute value node into subgraph using the remapped ids to map the
 serialized ids to the new ids generated when defining the tensor value
@@ -2082,6 +2112,7 @@ DefineNodeFunc getDefineNodeFunc(fb_xnnpack::XNodeUnion nodeType) {
   _DEFINE(Negate)
   _DEFINE(Square)
   _DEFINE(ELU)
+  _DEFINE(Exp)
   _DEFINE(Abs)
   _DEFINE(PReLU)
   _DEFINE(Concatenate2)
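The actual XNNPACK work happens in xnn_define_exp, which the library provides for this unary op. The _DEFINE(Exp) entry wires the flatbuffer union tag to the definition function; judging by the surrounding entries, the macro presumably expands to a switch case that returns defineExpNode when nodeType is XNNExp.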
The flatbuffer node union gains a matching entry (first schema copy):

@@ -132,6 +132,7 @@ union XNodeUnion {
   XNNNegate: _XNNNode1x1,
   XNNSquare: _XNNNode1x1,
   XNNELU,
+  XNNExp: _XNNNode1x1,
   XNNAbs: _XNNNode1x1,
   XNNPReLU: _XNNNode2x1,
   XNNConcatenate2: _XNNCat,
And the second schema copy gets the identical entry:

@@ -128,6 +128,7 @@ union XNodeUnion {
   XNNNegate: _XNNNode1x1,
   XNNSquare: _XNNNode1x1,
   XNNELU,
+  XNNExp: _XNNNode1x1,
   XNNAbs: _XNNNode1x1,
   XNNPReLU: _XNNNode2x1,
   XNNConcatenate2: _XNNCat,
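Both copies of the schema declare XNNExp as an _XNNNode1x1, which (judging by the other entries and the Python dataclass below) describes a node with one input id and one output id; the two files have to stay in sync for the union tag to deserialize correctly.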
The Python serialization schema mirrors the union entry with a dataclass:

@@ -291,6 +291,11 @@ class XNNCeiling(XNNNode1x1):
     pass
 
 
+@dataclass
+class XNNExp(XNNNode1x1):
+    pass
+
+
 @dataclass
 class XNNGelu(XNNNode1x1):
     pass
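XNNExp inherits its fields (input_id, output_id, flags) from XNNNode1x1, which is why ExpVisitor above can construct it with exactly those three arguments.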
New file: the operator test, exercising the full export-partition-serialize-run pipeline in fp32 and fp16:

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch.backends.xnnpack.test.tester import Tester


class TestExp(unittest.TestCase):
    def setUp(self):
        torch._dynamo.reset()

    class Exp(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x):
            return torch.exp(x)

    def run_exp_test(self, inputs):
        (
            Tester(self.Exp(), inputs)
            .export()
            .check_count({"torch.ops.aten.exp.default": 1})
            .to_edge_transform_and_lower()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .check_not(["executorch_exir_dialects_edge__ops_aten_exp_default"])
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs()
        )

    def test_fp16_exp(self):
        inputs = (torch.randn(20).to(torch.float16),)
        self.run_exp_test(inputs)

    def test_fp32_exp(self):
        inputs = (torch.randn(20),)
        self.run_exp_test(inputs)
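A side note on the fp16 case: exp overflows half precision quickly, since float16 tops out at 65504 and exp(x) exceeds that once x is greater than about 11.09. Sampling inputs from torch.randn keeps values small enough for the fp16 and reference outputs to stay comparable. A standalone sanity check (plain PyTorch):

import torch

x = torch.randn(20)
y = torch.exp(x.to(torch.float16))
# randn draws are almost always within a few standard deviations of 0,
# far below the ~11.09 threshold where exp overflows float16
assert torch.isfinite(y).all()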