This PR adds aten.tanh.default support to the XNNPACK backend end to end: AOT node visitor, partitioner configs, serialization schemas, runtime compiler, and tests.

10 files changed: 145 additions, 0 deletions
lines changed Original file line number Diff line number Diff line change 5151 op_static_constant_pad ,
5252 op_static_resize_bilinear_2d ,
5353 op_sub ,
54+ op_tanh ,
5455 op_to_copy ,
5556)
backends/xnnpack/operators, new visitor for aten.tanh.default:

+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Dict
+
+import torch
+from executorch.backends.xnnpack.operators.node_visitor import (
+    NodeVisitor,
+    register_node_visitor,
+)
+from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
+    XNNGraph,
+    XNNTanh,
+    XNode,
+)
+from executorch.backends.xnnpack.utils.utils import get_input_node
+
+
+@register_node_visitor
+class TanhVisitor(NodeVisitor):
+    target = "aten.tanh.default"
+
+    def __init__(self, *args) -> None:
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        xnn_graph: XNNGraph,
+        vals_to_ids: Dict[torch.fx.Node, int],
+        debug_handle: int,
+    ) -> None:
+        self.define_nodes_tensor_inputs_outputs(node, xnn_graph, vals_to_ids)
+
+        # id of the single input tensor
+        input_id = vals_to_ids[get_input_node(node, 0)]
+
+        # id of the output tensor
+        output_id = vals_to_ids[node]
+
+        ser_node = XNode(
+            xnode_union=XNNTanh(
+                input_id=input_id,
+                output_id=output_id,
+                flags=0,
+            ),
+            debug_handle=debug_handle,
+        )
+        xnn_graph.xnodes.append(ser_node)
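
For reviewers who want to exercise the new visitor end to end, here is a minimal lowering sketch. It is not part of this diff: it assumes the standard torch.export and ExecuTorch to_edge_transform_and_lower / XnnpackPartitioner entry points, and the model and variable names are illustrative only.

import torch
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge_transform_and_lower


class TanhModel(torch.nn.Module):
    def forward(self, x):
        return torch.tanh(x)


# Export to an ATen graph containing a single aten.tanh.default node ...
exported = torch.export.export(TanhModel(), (torch.randn(4),))

# ... then partition and lower; the tanh node should be consumed by the
# XNNPACK delegate instead of remaining as an edge op.
edge = to_edge_transform_and_lower(exported, partitioner=[XnnpackPartitioner()])
et_program = edge.to_executorch()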
backends/xnnpack/partition/config, config registry:

     SoftmaxConfig,
     SquareRootConfig,
     SubConfig,
+    TanhConfig,
     UpsampleBilinear2dConfig,
 )
 from executorch.backends.xnnpack.partition.config.node_configs import (
...
     PreluConfig,
     ReciprocalSquareRootConfig,
     ReLUConfig,
+    TanhConfig,
     # SDPAConfig, TODO: D60553559: preserving SDPA for fairseq fails
     SigmoidConfig,
     SliceCopyConfig,
backends/xnnpack/partition/config, generic node configs:

@@ -378,6 +378,13 @@ def supported_precision_types(self) -> List[ConfigPrecisionType]:
         return [ConfigPrecisionType.FP32]


+class TanhConfig(GenericNodePartitionerConfig):
+    target_name = "tanh.default"
+
+    def supported_precision_types(self) -> List[ConfigPrecisionType]:
+        return [ConfigPrecisionType.FP32]
+
+
 class MeanDimConfig(GenericNodePartitionerConfig):
     target_name = "mean.dim"
backends/xnnpack/partition, supported-ops list:

     exir_ops.edge.aten.rsqrt.default,
     exir_ops.edge.aten.log.default,
     exir_ops.edge.aten.gelu.default,
+    exir_ops.edge.aten.tanh.default,
 ]

 SUPPORTED_MODULES = [
backends/xnnpack runtime compiler (C++):

@@ -1513,6 +1513,36 @@ Error defineGeluNode(
   return Error::Ok;
 }

+/*
+Define serialized tanh node into the subgraph, using the remapped ids
+to map the serialized ids, to the new ids generated when defining the
+tensor value
+*/
+Error defineTanhNode(
+    xnn_subgraph_t subgraph_ptr,
+    const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
+    const NodePtr node,
+    const fb_xnnpack::XNNGraph* graph) noexcept {
+  MAYBE_UNUSED(graph);
+
+  auto graph_node = node->xnode_union_as_XNNTanh();
+
+  xnn_status status = xnn_define_tanh(
+      subgraph_ptr,
+      remapped_ids.at(graph_node->input_id()),
+      remapped_ids.at(graph_node->output_id()),
+      graph_node->flags());
+
+  ET_CHECK_OR_RETURN_ERROR(
+      status == xnn_status_success,
+      Internal,
+      "Failed to create tanh node %i with code: %s",
+      node->debug_handle(),
+      xnn_status_to_string(status));
+
+  return Error::Ok;
+}
+
 /*
 Define serialized ceiling node into the subgraph, using the remapped ids
 to map the serialized ids, to the new ids generated when defining the
@@ -2108,6 +2138,7 @@ DefineNodeFunc getDefineNodeFunc(fb_xnnpack::XNodeUnion nodeType) {
   _DEFINE(Hardswish)
   _DEFINE(LeakyReLU)
   _DEFINE(Log)
+  _DEFINE(Tanh)
   _DEFINE(Maximum)
   _DEFINE(Negate)
   _DEFINE(Square)
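
For context: the second hunk registers the node type in getDefineNodeFunc, which switches on the serialized XNodeUnion tag. The _DEFINE(Tanh) entry maps a deserialized XNNTanh node to defineTanhNode, so the runtime rebuilds the op via xnn_define_tanh when it reconstructs the XNNPACK subgraph at load time.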
serialization schema (FlatBuffers):

@@ -146,6 +146,7 @@ union XNodeUnion {
  XNNReciprocalSquareRoot: _XNNNode1x1,
  XNNLog: _XNNNode1x1,
  XNNGelu: _XNNNode1x1,
+ XNNTanh: _XNNNode1x1,
 }

 union XValueUnion {
matching runtime copy of the schema (FlatBuffers):

@@ -142,6 +142,7 @@ union XNodeUnion {
  XNNReciprocalSquareRoot: _XNNNode1x1,
  XNNLog: _XNNNode1x1,
  XNNGelu: _XNNNode1x1,
+ XNNTanh: _XNNNode1x1,
 }

 union XValueUnion {
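
The two schema copies must stay in sync, and because FlatBuffers identifies union members by ordinal, XNNTanh is correctly appended after XNNGelu at the tail of XNodeUnion in each file. Inserting it mid-list would renumber the existing members and break previously serialized graphs.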
backends/xnnpack/serialization/xnnpack_graph_schema.py:

@@ -324,6 +324,11 @@ class XNNLog(XNNNode1x1):
     pass


+@dataclass
+class XNNTanh(XNNNode1x1):
+    pass
+
+
 @dataclass
 class XNNMaximum(XNNNode2x1):
     pass
@@ -396,6 +401,7 @@ class XNNScaledDotProductAttention:
     XNNReciprocalSquareRoot,
     XNNLog,
     XNNGelu,
+    XNNTanh,
 ]
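
A small consistency note: the XNNTanh dataclass mirrors the FlatBuffers member one-to-one, and it is also appended at the tail of the node-union type list near line 400. That list appears to mirror the XNodeUnion members from the .fbs schemas, so keeping the Python entry in the same relative position as the schema member looks intentional.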
backends/xnnpack/test, new operator test:

+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+
+import torch
+from executorch.backends.xnnpack.test.tester import Tester
+
+
+class TestTanh(unittest.TestCase):
+    def setUp(self):
+        torch._dynamo.reset()
+
+    class Tanh(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x):
+            return torch.tanh(x)
+
+    def run_tanh_test(self, inputs):
+        (
+            Tester(self.Tanh(), inputs)
+            .export()
+            .check_count({"torch.ops.aten.tanh.default": 1})
+            .to_edge_transform_and_lower()
+            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+            .check_not(["executorch_exir_dialects_edge__ops_aten_tanh_default"])
+            .to_executorch()
+            .serialize()
+            .run_method_and_compare_outputs()
+        )
+
+    def test_fp16_tanh(self):
+        inputs = (torch.randn(20).to(torch.float16),)
+        self.run_tanh_test(inputs)
+
+    def test_fp32_tanh(self):
+        inputs = (torch.randn(20),)
+        self.run_tanh_test(inputs)
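
The Tester pipeline follows the existing op-test pattern: export() traces the module, the first check_count asserts exactly one aten.tanh.default node, to_edge_transform_and_lower() runs the partitioner, the delegate-call count plus check_not confirm the tanh op was fully consumed by the XNNPACK partition, and run_method_and_compare_outputs() compares the serialized program's output against eager mode. Assuming the file lands alongside the other op tests, it should run under plain unittest (for example, python -m unittest against the module path).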