File tree Expand file tree Collapse file tree 8 files changed +95
-0
lines changed
Expand file tree Collapse file tree 8 files changed +95
-0
lines changed Original file line number Diff line number Diff line change 3030 # EluConfig,
3131 HardtanhConfig ,
3232 LeakyReLUConfig ,
33+ LogConfig ,
3334 MaximumConfig ,
3435 MaxPool2dConfig ,
3536 MeanDimConfig ,
8283 HardswishConfig ,
8384 LeakyReLUConfig ,
8485 LinearConfig ,
86+ LogConfig ,
8587 MaxDimConfig ,
8688 MaximumConfig ,
8789 MaxPool2dConfig ,
Original file line number Diff line number Diff line change @@ -357,6 +357,13 @@ def supported_precision_types(self) -> List[ConfigPrecisionType]:
357357 return [ConfigPrecisionType .FP32 ]
358358
359359
360+ class LogConfig (GenericNodePartitionerConfig ):
361+ target_name = "log.default"
362+
363+ def supported_precision_types (self ) -> List [ConfigPrecisionType ]:
364+ return [ConfigPrecisionType .FP32 ]
365+
366+
360367class MeanDimConfig (GenericNodePartitionerConfig ):
361368 target_name = "mean.dim"
362369
Original file line number Diff line number Diff line change 6464 exir_ops .edge .aten .leaky_relu .default ,
6565 exir_ops .edge .aten .addmm .default , # TODO(T163877189) add constraint for addmm
6666 exir_ops .edge .aten .rsqrt .default ,
67+ exir_ops .edge .aten .log .default ,
6768]
6869
6970SUPPORTED_MODULES = [
Original file line number Diff line number Diff line change @@ -1418,6 +1418,37 @@ Error defineReciprocalSquareRootNode(
14181418 return Error::Ok;
14191419}
14201420
1421+ /*
1422+ Define serialized log node into the subgraph, using the remapped ids
1423+ to map the serialized ids, to the new ids generated when defining the
1424+ tensor value
1425+ */
1426+ Error defineLogNode (
1427+ xnn_subgraph_t subgraph_ptr,
1428+ const std::unordered_map<uint32_t , uint32_t >& remapped_ids,
1429+ const NodePtr node,
1430+ const fb_xnnpack::XNNGraph* graph) noexcept {
1431+ MAYBE_UNUSED (graph);
1432+
1433+ auto graph_node = node->xnode_union_as_XNNLog ();
1434+
1435+ xnn_status status = xnn_define_log (
1436+ subgraph_ptr,
1437+ remapped_ids.at (graph_node->input_id ()),
1438+ remapped_ids.at (graph_node->output_id ()),
1439+ graph_node->flags ());
1440+
1441+ ET_CHECK_OR_RETURN_ERROR (
1442+ status == xnn_status_success,
1443+ Internal,
1444+ " Failed to create log node %i with code: %s" ,
1445+ node->debug_handle (),
1446+ xnn_status_to_string (status));
1447+
1448+ return Error::Ok;
1449+ }
1450+
1451+
14211452/*
14221453Define serialized ceiling node into the subgraph, using the remapped ids
14231454to map the serialized ids, to the new ids generated when defining the
@@ -1981,6 +2012,7 @@ DefineNodeFunc getDefineNodeFunc(fb_xnnpack::XNodeUnion nodeType) {
19812012 _DEFINE (Ceiling)
19822013 _DEFINE (Hardswish)
19832014 _DEFINE (LeakyReLU)
2015+ _DEFINE (Log)
19842016 _DEFINE (Maximum)
19852017 _DEFINE (Negate)
19862018 _DEFINE (Square)
Original file line number Diff line number Diff line change @@ -139,6 +139,7 @@ union XNodeUnion {
139139 XNNConcatenate5: _XNNCat,
140140 XNNConvTranspose2d: _XNNNodeConv,
141141 XNNReciprocalSquareRoot: _XNNNode1x1,
142+ XNNLog: _XNNNode1x1,
142143}
143144
144145union XValueUnion {
Original file line number Diff line number Diff line change @@ -135,6 +135,7 @@ union XNodeUnion {
135135 XNNConcatenate5: _XNNCat,
136136 XNNConvTranspose2d: _XNNNodeConv,
137137 XNNReciprocalSquareRoot: _XNNNode1x1,
138+ XNNLog: _XNNNode1x1,
138139}
139140
140141union XValueUnion {
Original file line number Diff line number Diff line change @@ -309,6 +309,11 @@ class XNNLeakyReLU:
309309 flags : int
310310
311311
312+ @dataclass
313+ class XNNLog (XNNNode1x1 ):
314+ pass
315+
316+
312317@dataclass
313318class XNNMaximum (XNNNode2x1 ):
314319 pass
@@ -379,6 +384,7 @@ class XNNScaledDotProductAttention:
379384 XNNScaledDotProductAttention ,
380385 XNNBatchMatrixMultiply ,
381386 XNNReciprocalSquareRoot ,
387+ XNNLog ,
382388]
383389
384390
Original file line number Diff line number Diff line change 1+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2+ # All rights reserved.
3+ #
4+ # This source code is licensed under the BSD-style license found in the
5+ # LICENSE file in the root directory of this source tree.
6+
7+ import unittest
8+
9+ import torch
10+ from executorch .backends .xnnpack .test .tester import Tester
11+
12+
13+ class TestLog (unittest .TestCase ):
14+ def setUp (self ):
15+ torch ._dynamo .reset ()
16+
17+ class Log (torch .nn .Module ):
18+ def __init__ (self ):
19+ super ().__init__ ()
20+
21+ def forward (self , x ):
22+ x = torch .abs (x )
23+ z = torch .log (x )
24+ return z
25+
26+ def run_log_test (self , inputs ):
27+ (
28+ Tester (self .Log (), inputs )
29+ .export ()
30+ .check_count ({"torch.ops.aten.log.default" : 1 })
31+ .to_edge_transform_and_lower ()
32+ .check_count ({"torch.ops.higher_order.executorch_call_delegate" : 1 })
33+ .check_not (["executorch_exir_dialects_edge__ops_aten_log_default" ])
34+ .to_executorch ()
35+ .serialize ()
36+ .run_method_and_compare_outputs ()
37+ )
38+
39+ def test_fp16_log (self ):
40+ inputs = (torch .randn (20 ).to (torch .float16 ),)
41+ self .run_log_test (inputs )
42+
43+ def test_fp32_log (self ):
44+ inputs = (torch .randn (20 ),)
45+ self .run_log_test (inputs )
You can’t perform that action at this time.
0 commit comments