File tree Expand file tree Collapse file tree 10 files changed +145
-1
lines changed Expand file tree Collapse file tree 10 files changed +145
-1
lines changed Original file line number Diff line number Diff line change 19
19
op_dynamic_dequantize_ops ,
20
20
op_dynamic_quantize_ops ,
21
21
op_elu ,
22
+ op_exp ,
22
23
op_floor ,
23
24
op_gelu ,
24
25
op_hardswish ,
Original file line number Diff line number Diff line change
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import Dict
8
+
9
+ import torch
10
+ from executorch .backends .xnnpack .operators .node_visitor import (
11
+ NodeVisitor ,
12
+ register_node_visitor ,
13
+ )
14
+ from executorch .backends .xnnpack .serialization .xnnpack_graph_schema import (
15
+ XNNExp ,
16
+ XNNGraph ,
17
+ XNode ,
18
+ )
19
+ from executorch .backends .xnnpack .utils .utils import get_input_node
20
+
21
+
22
+ @register_node_visitor
23
+ class ExpVisitor (NodeVisitor ):
24
+ target = "aten.exp.default"
25
+
26
+ def __init__ (self , * args ) -> None :
27
+ super ().__init__ (* args )
28
+
29
+ def define_node (
30
+ self ,
31
+ node : torch .fx .Node ,
32
+ xnn_graph : XNNGraph ,
33
+ vals_to_ids : Dict [torch .fx .Node , int ],
34
+ debug_handle : int ,
35
+ ) -> None :
36
+ self .define_nodes_tensor_inputs_outputs (node , xnn_graph , vals_to_ids )
37
+
38
+ # input
39
+ input_id = vals_to_ids [get_input_node (node , 0 )]
40
+
41
+ # output
42
+ output_id = vals_to_ids [node ]
43
+
44
+ ser_node = XNode (
45
+ xnode_union = XNNExp (
46
+ input_id = input_id ,
47
+ output_id = output_id ,
48
+ flags = 0 ,
49
+ ),
50
+ debug_handle = debug_handle ,
51
+ )
52
+ xnn_graph .xnodes .append (ser_node )
Original file line number Diff line number Diff line change 25
25
ConstantPadConfig ,
26
26
DeQuantizedPerTensorConfig ,
27
27
DivConfig ,
28
+ # EluConfig,
29
+ ExpConfig ,
28
30
FloorConfig ,
29
31
GeluConfig ,
30
32
HardswishConfig ,
31
- # EluConfig,
32
33
HardtanhConfig ,
33
34
LeakyReLUConfig ,
34
35
LogConfig ,
80
81
ClampConfig ,
81
82
DivConfig ,
82
83
# EluConfig, # Waiting for PyTorch Pin Update
84
+ ExpConfig ,
83
85
FloorConfig ,
84
86
GeluConfig ,
85
87
HardtanhConfig ,
Original file line number Diff line number Diff line change @@ -348,6 +348,13 @@ def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
348
348
return torch .ops .aten .upsample_bilinear2d .vec
349
349
350
350
351
+ class ExpConfig (GenericNodePartitionerConfig ):
352
+ target_name = "exp.default"
353
+
354
+ def supported_precision_types (self ) -> List [ConfigPrecisionType ]:
355
+ return [ConfigPrecisionType .FP32 ]
356
+
357
+
351
358
class FloorConfig (GenericNodePartitionerConfig ):
352
359
target_name = "floor.default"
353
360
Original file line number Diff line number Diff line change 67
67
exir_ops .edge .aten .log .default ,
68
68
exir_ops .edge .aten .gelu .default ,
69
69
exir_ops .edge .aten .tanh .default ,
70
+ exir_ops .edge .aten .exp .default ,
70
71
]
71
72
72
73
SUPPORTED_MODULES = [
Original file line number Diff line number Diff line change @@ -1162,6 +1162,36 @@ Error defineArgMaxPooling2dNode(
1162
1162
return Error::Ok;
1163
1163
}
1164
1164
1165
+ /*
1166
+ Define serialized exp node into the subgraph, using the remapped ids
1167
+ to map the serialized ids, to the new ids generated when defining the
1168
+ tensor value
1169
+ */
1170
+ Error defineExpNode (
1171
+ xnn_subgraph_t subgraph_ptr,
1172
+ const std::unordered_map<uint32_t , uint32_t >& remapped_ids,
1173
+ const NodePtr node,
1174
+ const fb_xnnpack::XNNGraph* graph) noexcept {
1175
+ MAYBE_UNUSED (graph);
1176
+
1177
+ auto graph_node = node->xnode_union_as_XNNExp ();
1178
+
1179
+ xnn_status status = xnn_define_exp (
1180
+ subgraph_ptr,
1181
+ remapped_ids.at (graph_node->input_id ()),
1182
+ remapped_ids.at (graph_node->output_id ()),
1183
+ graph_node->flags ());
1184
+
1185
+ ET_CHECK_OR_RETURN_ERROR (
1186
+ status == xnn_status_success,
1187
+ Internal,
1188
+ " Failed to create exp node %i with code: %s" ,
1189
+ node->debug_handle (),
1190
+ xnn_status_to_string (status));
1191
+
1192
+ return Error::Ok;
1193
+ }
1194
+
1165
1195
/*
1166
1196
Define serialized tanh node into the subgraph, using the remapped ids
1167
1197
to map the serialized ids, to the new ids generated when defining the
@@ -1733,6 +1763,7 @@ DefineNodeFunc getDefineNodeFunc(fb_xnnpack::XNodeUnion nodeType) {
1733
1763
_DEFINE (Clamp)
1734
1764
_DEFINE (LeakyReLU)
1735
1765
_DEFINE (ELU)
1766
+ _DEFINE (Exp)
1736
1767
_DEFINE (Abs)
1737
1768
_DEFINE (Floor)
1738
1769
_DEFINE (PReLU)
Original file line number Diff line number Diff line change @@ -155,6 +155,7 @@ union XNodeUnion {
155
155
XNNLog: _XNNNode1x1,
156
156
XNNGelu: _XNNNode1x1,
157
157
XNNTanh: _XNNNode1x1,
158
+ XNNExp: _XNNNode1x1,
158
159
}
159
160
160
161
union XValueUnion {
Original file line number Diff line number Diff line change @@ -151,6 +151,7 @@ union XNodeUnion {
151
151
XNNLog: _XNNNode1x1,
152
152
XNNGelu: _XNNNode1x1,
153
153
XNNTanh: _XNNNode1x1,
154
+ XNNExp: _XNNNode1x1,
154
155
}
155
156
156
157
union XValueUnion {
Original file line number Diff line number Diff line change @@ -291,6 +291,11 @@ class XNNCeiling(XNNNode1x1):
291
291
pass
292
292
293
293
294
+ @dataclass
295
+ class XNNExp (XNNNode1x1 ):
296
+ pass
297
+
298
+
294
299
@dataclass
295
300
class XNNGelu (XNNNode1x1 ):
296
301
pass
Original file line number Diff line number Diff line change
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import unittest
8
+
9
+ import torch
10
+ from executorch .backends .xnnpack .test .tester import Tester
11
+
12
+
13
+ class TestExp (unittest .TestCase ):
14
+ def setUp (self ):
15
+ torch ._dynamo .reset ()
16
+
17
+ class Exp (torch .nn .Module ):
18
+ def __init__ (self ):
19
+ super ().__init__ ()
20
+
21
+ def forward (self , x ):
22
+ return torch .exp (x )
23
+
24
+ def run_exp_test (self , inputs ):
25
+ (
26
+ Tester (self .Exp (), inputs )
27
+ .export ()
28
+ .check_count ({"torch.ops.aten.exp.default" : 1 })
29
+ .to_edge_transform_and_lower ()
30
+ .check_count ({"torch.ops.higher_order.executorch_call_delegate" : 1 })
31
+ .check_not (["executorch_exir_dialects_edge__ops_aten_exp_default" ])
32
+ .to_executorch ()
33
+ .serialize ()
34
+ .run_method_and_compare_outputs ()
35
+ )
36
+
37
+ def test_fp16_exp (self ):
38
+ inputs = (torch .randn (20 ).to (torch .float16 ),)
39
+ self .run_exp_test (inputs )
40
+
41
+ def test_fp32_exp (self ):
42
+ inputs = (torch .randn (20 ),)
43
+ self .run_exp_test (inputs )
You can’t perform that action at this time.
0 commit comments