File tree Expand file tree Collapse file tree 10 files changed +147
-0
lines changed Expand file tree Collapse file tree 10 files changed +147
-0
lines changed Original file line number Diff line number Diff line change 2424 op_hardtanh ,
2525 op_leaky_relu ,
2626 op_linear ,
27+ op_log ,
2728 op_matrix_multiplication ,
2829 op_max_dim ,
2930 op_max_pool2d ,
Original file line number Diff line number Diff line change 1+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2+ # All rights reserved.
3+ #
4+ # This source code is licensed under the BSD-style license found in the
5+ # LICENSE file in the root directory of this source tree.
6+
7+ from typing import Dict
8+
9+ import torch
10+ from executorch .backends .xnnpack .operators .node_visitor import (
11+ NodeVisitor ,
12+ register_node_visitor ,
13+ )
14+ from executorch .backends .xnnpack .serialization .xnnpack_graph_schema import (
15+ XNNGraph ,
16+ XNNLog ,
17+ XNode ,
18+ )
19+ from executorch .backends .xnnpack .utils .utils import get_input_node
20+
21+
22+ @register_node_visitor
23+ class LogVisitor (NodeVisitor ):
24+ target = "aten.log.default"
25+
26+ def __init__ (self , * args ) -> None :
27+ super ().__init__ (* args )
28+
29+ def define_node (
30+ self ,
31+ node : torch .fx .Node ,
32+ xnn_graph : XNNGraph ,
33+ vals_to_ids : Dict [torch .fx .Node , int ],
34+ debug_handle : int ,
35+ ) -> None :
36+ self .define_nodes_tensor_inputs_outputs (node , xnn_graph , vals_to_ids )
37+
38+ # input
39+ input_id = vals_to_ids [get_input_node (node , 0 )]
40+
41+ # output
42+ output_id = vals_to_ids [node ]
43+
44+ ser_node = XNode (
45+ xnode_union = XNNLog (
46+ input_id = input_id ,
47+ output_id = output_id ,
48+ flags = 0 ,
49+ ),
50+ debug_handle = debug_handle ,
51+ )
52+ xnn_graph .xnodes .append (ser_node )
Original file line number Diff line number Diff line change 3030 # EluConfig,
3131 HardtanhConfig ,
3232 LeakyReLUConfig ,
33+ LogConfig ,
3334 MaximumConfig ,
3435 MaxPool2dConfig ,
3536 MeanDimConfig ,
8283 HardswishConfig ,
8384 LeakyReLUConfig ,
8485 LinearConfig ,
86+ LogConfig ,
8587 MaxDimConfig ,
8688 MaximumConfig ,
8789 MaxPool2dConfig ,
Original file line number Diff line number Diff line change @@ -357,6 +357,13 @@ def supported_precision_types(self) -> List[ConfigPrecisionType]:
357357 return [ConfigPrecisionType .FP32 ]
358358
359359
360+ class LogConfig (GenericNodePartitionerConfig ):
361+ target_name = "log.default"
362+
363+ def supported_precision_types (self ) -> List [ConfigPrecisionType ]:
364+ return [ConfigPrecisionType .FP32 ]
365+
366+
360367class MeanDimConfig (GenericNodePartitionerConfig ):
361368 target_name = "mean.dim"
362369
Original file line number Diff line number Diff line change 6464 exir_ops .edge .aten .leaky_relu .default ,
6565 exir_ops .edge .aten .addmm .default , # TODO(T163877189) add constraint for addmm
6666 exir_ops .edge .aten .rsqrt .default ,
67+ exir_ops .edge .aten .log .default ,
6768]
6869
6970SUPPORTED_MODULES = [
Original file line number Diff line number Diff line change @@ -1418,6 +1418,36 @@ Error defineReciprocalSquareRootNode(
14181418 return Error::Ok;
14191419}
14201420
1421+ /*
1422+ Define serialized log node into the subgraph, using the remapped ids
1423+ to map the serialized ids, to the new ids generated when defining the
1424+ tensor value
1425+ */
1426+ Error defineLogNode (
1427+ xnn_subgraph_t subgraph_ptr,
1428+ const std::unordered_map<uint32_t , uint32_t >& remapped_ids,
1429+ const NodePtr node,
1430+ const fb_xnnpack::XNNGraph* graph) noexcept {
1431+ MAYBE_UNUSED (graph);
1432+
1433+ auto graph_node = node->xnode_union_as_XNNLog ();
1434+
1435+ xnn_status status = xnn_define_log (
1436+ subgraph_ptr,
1437+ remapped_ids.at (graph_node->input_id ()),
1438+ remapped_ids.at (graph_node->output_id ()),
1439+ graph_node->flags ());
1440+
1441+ ET_CHECK_OR_RETURN_ERROR (
1442+ status == xnn_status_success,
1443+ Internal,
1444+ " Failed to create log node %i with code: %s" ,
1445+ node->debug_handle (),
1446+ xnn_status_to_string (status));
1447+
1448+ return Error::Ok;
1449+ }
1450+
14211451/*
14221452Define serialized ceiling node into the subgraph, using the remapped ids
14231453to map the serialized ids, to the new ids generated when defining the
@@ -1981,6 +2011,7 @@ DefineNodeFunc getDefineNodeFunc(fb_xnnpack::XNodeUnion nodeType) {
19812011 _DEFINE (Ceiling)
19822012 _DEFINE (Hardswish)
19832013 _DEFINE (LeakyReLU)
2014+ _DEFINE (Log)
19842015 _DEFINE (Maximum)
19852016 _DEFINE (Negate)
19862017 _DEFINE (Square)
Original file line number Diff line number Diff line change @@ -139,6 +139,7 @@ union XNodeUnion {
139139 XNNConcatenate5: _XNNCat,
140140 XNNConvTranspose2d: _XNNNodeConv,
141141 XNNReciprocalSquareRoot: _XNNNode1x1,
142+ XNNLog: _XNNNode1x1,
142143}
143144
144145union XValueUnion {
Original file line number Diff line number Diff line change @@ -135,6 +135,7 @@ union XNodeUnion {
135135 XNNConcatenate5: _XNNCat,
136136 XNNConvTranspose2d: _XNNNodeConv,
137137 XNNReciprocalSquareRoot: _XNNNode1x1,
138+ XNNLog: _XNNNode1x1,
138139}
139140
140141union XValueUnion {
Original file line number Diff line number Diff line change @@ -309,6 +309,11 @@ class XNNLeakyReLU:
309309 flags : int
310310
311311
312+ @dataclass
313+ class XNNLog (XNNNode1x1 ):
314+ pass
315+
316+
312317@dataclass
313318class XNNMaximum (XNNNode2x1 ):
314319 pass
@@ -379,6 +384,7 @@ class XNNScaledDotProductAttention:
379384 XNNScaledDotProductAttention ,
380385 XNNBatchMatrixMultiply ,
381386 XNNReciprocalSquareRoot ,
387+ XNNLog ,
382388]
383389
384390
Original file line number Diff line number Diff line change 1+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2+ # All rights reserved.
3+ #
4+ # This source code is licensed under the BSD-style license found in the
5+ # LICENSE file in the root directory of this source tree.
6+
7+ import unittest
8+
9+ import torch
10+ from executorch .backends .xnnpack .test .tester import Tester
11+
12+
13+ class TestLog (unittest .TestCase ):
14+ def setUp (self ):
15+ torch ._dynamo .reset ()
16+
17+ class Log (torch .nn .Module ):
18+ def __init__ (self ):
19+ super ().__init__ ()
20+
21+ def forward (self , x ):
22+ x = torch .abs (x )
23+ z = torch .log (x )
24+ return z
25+
26+ def run_log_test (self , inputs ):
27+ (
28+ Tester (self .Log (), inputs )
29+ .export ()
30+ .check_count ({"torch.ops.aten.log.default" : 1 })
31+ .to_edge_transform_and_lower ()
32+ .check_count ({"torch.ops.higher_order.executorch_call_delegate" : 1 })
33+ .check_not (["executorch_exir_dialects_edge__ops_aten_log_default" ])
34+ .to_executorch ()
35+ .serialize ()
36+ .run_method_and_compare_outputs ()
37+ )
38+
39+ def test_fp16_log (self ):
40+ inputs = (torch .randn (20 ).to (torch .float16 ),)
41+ self .run_log_test (inputs )
42+
43+ def test_fp32_log (self ):
44+ inputs = (torch .randn (20 ),)
45+ self .run_log_test (inputs )
You can’t perform that action at this time.
0 commit comments