10 files changed, +146 -0 lines

File 1/10: operator visitor registry for the XNNPACK backend, adding the op_gelu import.

     op_dynamic_quantize_ops,
     op_elu,
     op_floor,
+    op_gelu,
     op_hardswish,
     op_hardtanh,
     op_leaky_relu,
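Importing op_gelu alongside the other op modules is what actually enables the new visitor: the module's @register_node_visitor decorator (in the new file below) runs at import time and adds GeluVisitor to the backend's visitor registry.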
File 2/10: new file, the GELU node visitor. It serializes a lowered aten.gelu.default call into an XNNGelu node in the XNNPACK graph (all lines added).

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict

import torch
from executorch.backends.xnnpack.operators.node_visitor import (
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
    XNNGelu,
    XNNGraph,
    XNode,
)
from executorch.backends.xnnpack.utils.utils import get_input_node


@register_node_visitor
class GeluVisitor(NodeVisitor):
    target = "aten.gelu.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        xnn_graph: XNNGraph,
        vals_to_ids: Dict[torch.fx.Node, int],
        debug_handle: int,
    ) -> None:
        self.define_nodes_tensor_inputs_outputs(node, xnn_graph, vals_to_ids)

        # input
        input_id = vals_to_ids[get_input_node(node, 0)]

        # output
        output_id = vals_to_ids[node]

        ser_node = XNode(
            xnode_union=XNNGelu(
                input_id=input_id,
                output_id=output_id,
                flags=0,
            ),
            debug_handle=debug_handle,
        )
        xnn_graph.xnodes.append(ser_node)
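For reference, the op this visitor targets, aten.gelu.default in its exact (non tanh-approximate) form, computes 0.5 * x * (1 + erf(x / sqrt(2))). A quick eager-mode sanity check of that identity, using only plain PyTorch (nothing backend-specific is assumed here):

import math

import torch

x = torch.randn(8)
# Exact GELU definition: 0.5 * x * (1 + erf(x / sqrt(2)))
expected = 0.5 * x * (1.0 + torch.erf(x / math.sqrt(2.0)))
# The same op the visitor lowers (torch.ops.aten.gelu.default under the hood)
actual = torch.nn.functional.gelu(x)
assert torch.allclose(actual, expected, atol=1e-6)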
File 3/10: partitioner config registry, importing GeluConfig and adding it to the list of enabled configs.

     DeQuantizedPerTensorConfig,
     DivConfig,
     FloorConfig,
+    GeluConfig,
     HardswishConfig,
     # EluConfig,
     HardtanhConfig,

(second hunk, further down in the same file:)

     DivConfig,
     # EluConfig,  # Waiting for PyTorch Pin Update
     FloorConfig,
+    GeluConfig,
     HardtanhConfig,
     HardswishConfig,
     LeakyReLUConfig,
File 4/10: generic node partitioner configs, defining GeluConfig.

@@ -343,6 +343,13 @@ def supported_precision_types(self) -> List[ConfigPrecisionType]:
         return [ConfigPrecisionType.FP32]


+class GeluConfig(GenericNodePartitionerConfig):
+    target_name = "gelu.default"
+
+    def supported_precision_types(self) -> List[ConfigPrecisionType]:
+        return [ConfigPrecisionType.FP32]
+
+
 class HardswishConfig(GenericNodePartitionerConfig):
     target_name = "hardswish.default"
File 5/10: static list of ops the XNNPACK delegate accepts, adding the edge-dialect GELU op.

     exir_ops.edge.aten.addmm.default,  # TODO(T163877189) add constraint for addmm
     exir_ops.edge.aten.rsqrt.default,
     exir_ops.edge.aten.log.default,
+    exir_ops.edge.aten.gelu.default,
 ]

 SUPPORTED_MODULES = [
File 6/10: C++ runtime graph builder, adding defineGeluNode (which turns a serialized XNNGelu node into an xnn_define_gelu call) and registering it in the node dispatch table.

@@ -1448,6 +1448,36 @@ Error defineLogNode(
   return Error::Ok;
 }

+/*
+Define serialized gelu node into the subgraph, using the remapped ids
+to map the serialized ids, to the new ids generated when defining the
+tensor value
+*/
+Error defineGeluNode(
+    xnn_subgraph_t subgraph_ptr,
+    const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
+    const NodePtr node,
+    const fb_xnnpack::XNNGraph* graph) noexcept {
+  MAYBE_UNUSED(graph);
+
+  auto graph_node = node->xnode_union_as_XNNGelu();
+
+  xnn_status status = xnn_define_gelu(
+      subgraph_ptr,
+      remapped_ids.at(graph_node->input_id()),
+      remapped_ids.at(graph_node->output_id()),
+      graph_node->flags());
+
+  ET_CHECK_OR_RETURN_ERROR(
+      status == xnn_status_success,
+      Internal,
+      "Failed to create gelu node %i with code: %s",
+      node->debug_handle(),
+      xnn_status_to_string(status));
+
+  return Error::Ok;
+}
+
 /*
 Define serialized ceiling node into the subgraph, using the remapped ids
 to map the serialized ids, to the new ids generated when defining the

@@ -2009,6 +2039,7 @@ DefineNodeFunc getDefineNodeFunc(fb_xnnpack::XNodeUnion nodeType) {
     _DEFINE(SquareRoot)
     _DEFINE(ReciprocalSquareRoot)
     _DEFINE(Ceiling)
+    _DEFINE(Gelu)
     _DEFINE(Hardswish)
     _DEFINE(LeakyReLU)
     _DEFINE(Log)
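The _DEFINE(Gelu) entry is what hooks the new builder into getDefineNodeFunc: presumably the macro expands to a case that maps the XNNGelu union tag to defineGeluNode, so the runtime invokes it whenever it encounters a serialized GELU node while rebuilding the xnn_subgraph.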
File 7/10: flatbuffer schema (first of the two copies in the tree), adding XNNGelu to the node union as a standard one-input/one-output (_XNNNode1x1) node.

@@ -140,6 +140,7 @@ union XNodeUnion {
   XNNConvTranspose2d: _XNNNodeConv,
   XNNReciprocalSquareRoot: _XNNNode1x1,
   XNNLog: _XNNNode1x1,
+  XNNGelu: _XNNNode1x1,
 }

 union XValueUnion {
File 8/10: flatbuffer schema (second copy), with the same XNNGelu entry.

@@ -136,6 +136,7 @@ union XNodeUnion {
   XNNConvTranspose2d: _XNNNodeConv,
   XNNReciprocalSquareRoot: _XNNNode1x1,
   XNNLog: _XNNNode1x1,
+  XNNGelu: _XNNNode1x1,
 }

 union XValueUnion {
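The schema exists in two copies (presumably one used by the Python serializer and one compiled into the runtime), and they have to stay in sync: the union tag the serializer writes for XNNGelu must be the tag the runtime reads back via xnode_union_as_XNNGelu().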
File 9/10: Python serialization schema, adding the XNNGelu dataclass (another XNNNode1x1) and including it in the node union.

@@ -291,6 +291,11 @@ class XNNCeiling(XNNNode1x1):
     pass


+@dataclass
+class XNNGelu(XNNNode1x1):
+    pass
+
+
 @dataclass
 class XNNHardswish(XNNNode1x1):
     pass

@@ -385,6 +390,7 @@ class XNNScaledDotProductAttention:
     XNNBatchMatrixMultiply,
     XNNReciprocalSquareRoot,
     XNNLog,
+    XNNGelu,
 ]
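Since XNNGelu subclasses XNNNode1x1, it carries only the three fields the visitor already fills in. A minimal construction sketch mirroring what GeluVisitor emits (the ids and debug handle below are made-up placeholders):

from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import XNNGelu, XNode

# input_id/output_id would be whatever tensor-value ids the graph assigned; 7 is a fake debug handle.
gelu_node = XNode(xnode_union=XNNGelu(input_id=0, output_id=1, flags=0), debug_handle=7)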
File 10/10: new file, a unit test that exports a small GELU module, lowers it through the XNNPACK delegate, and compares the delegated outputs against eager mode for fp16 and fp32 inputs (all lines added).

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch.backends.xnnpack.test.tester import Tester


class TestGelu(unittest.TestCase):
    def setUp(self):
        torch._dynamo.reset()

    class Gelu(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.gelu = torch.nn.GELU()

        def forward(self, x):
            return self.gelu(x)

    def run_gelu_test(self, inputs):
        (
            Tester(self.Gelu(), inputs)
            .export()
            .check_count({"torch.ops.aten.gelu.default": 1})
            .to_edge_transform_and_lower()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .check_not(["executorch_exir_dialects_edge__ops_aten_gelu_default"])
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs()
        )

    def test_fp16_gelu(self):
        inputs = (torch.randn(20).to(torch.float16),)
        self.run_gelu_test(inputs)

    def test_fp32_gelu(self):
        inputs = (torch.randn(20),)
        self.run_gelu_test(inputs)
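Outside the test harness, the same lowering path would look roughly like the sketch below. This is only an illustration: the import paths and call signatures (XnnpackPartitioner, to_edge_transform_and_lower) are assumed from typical ExecuTorch usage and may differ between releases.

import torch
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge_transform_and_lower


class Gelu(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.gelu(x)


# Export the eager module, then let the XNNPACK partitioner claim the GELU node.
exported = torch.export.export(Gelu(), (torch.randn(20),))
program = to_edge_transform_and_lower(
    exported, partitioner=[XnnpackPartitioner()]
).to_executorch()

# The resulting .pte file now contains a delegate payload with the serialized XNNGelu node.
with open("gelu_xnnpack.pte", "wb") as f:
    f.write(program.buffer)

The tests themselves should run under any standard unittest or pytest runner once the file is placed alongside the other XNNPACK op tests.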