Skip to content

Commit 92a98f4

Browse files
Use wrappers from xnnpack.h for unary and binary ops (#11584)
1 parent 289acbd commit 92a98f4

File tree

1 file changed

+87
-32
lines changed

1 file changed

+87
-32
lines changed

backends/xnnpack/runtime/XNNCompiler.cpp

Lines changed: 87 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include <executorch/backends/xnnpack/serialization/schema_generated.h>
1212
#include <executorch/extension/threadpool/threadpool.h>
1313
#include <executorch/runtime/executor/pte_data_map.h>
14+
#include <xnnpack.h>
1415
#include <string>
1516
#include <unordered_map>
1617
#include <vector>
@@ -615,10 +616,14 @@ Error defineAddNode(
615616

616617
std::pair<float, float> min_max = getOutputMinMax(node);
617618
auto graph_node = node->xnode_union_as_XNNAdd();
618-
xnn_status status = xnn_define_add2(
619+
620+
struct xnn_binary_params params = {
621+
.output_min = min_max.first, .output_max = min_max.second};
622+
623+
xnn_status status = xnn_define_binary(
619624
subgraph_ptr,
620-
min_max.first,
621-
min_max.second,
625+
xnn_binary_add,
626+
&params,
622627
remapped_ids.at(graph_node->input1_id()),
623628
remapped_ids.at(graph_node->input2_id()),
624629
remapped_ids.at(graph_node->output_id()),
@@ -644,8 +649,11 @@ Error defineMinimumNode(
644649
MAYBE_UNUSED(graph);
645650

646651
auto graph_node = node->xnode_union_as_XNNMinimum();
647-
xnn_status status = xnn_define_minimum2(
652+
653+
xnn_status status = xnn_define_binary(
648654
subgraph_ptr,
655+
xnn_binary_minimum,
656+
nullptr,
649657
remapped_ids.at(graph_node->input1_id()),
650658
remapped_ids.at(graph_node->input2_id()),
651659
remapped_ids.at(graph_node->output_id()),
@@ -673,10 +681,14 @@ Error defineSubtractNode(
673681

674682
auto graph_node = node->xnode_union_as_XNNSubtract();
675683
std::pair<float, float> min_max = getOutputMinMax(node);
676-
xnn_status status = xnn_define_subtract(
684+
685+
struct xnn_binary_params params = {
686+
.output_min = min_max.first, .output_max = min_max.second};
687+
688+
xnn_status status = xnn_define_binary(
677689
subgraph_ptr,
678-
min_max.first,
679-
min_max.second,
690+
xnn_binary_subtract,
691+
&params,
680692
remapped_ids.at(graph_node->input1_id()),
681693
remapped_ids.at(graph_node->input2_id()),
682694
remapped_ids.at(graph_node->output_id()),
@@ -704,10 +716,14 @@ Error defineMultiplyNode(
704716

705717
auto graph_node = node->xnode_union_as_XNNMultiply();
706718
std::pair<float, float> min_max = getOutputMinMax(node);
707-
xnn_status status = xnn_define_multiply2(
719+
720+
struct xnn_binary_params params = {
721+
.output_min = min_max.first, .output_max = min_max.second};
722+
723+
xnn_status status = xnn_define_binary(
708724
subgraph_ptr,
709-
min_max.first,
710-
min_max.second,
725+
xnn_binary_multiply,
726+
&params,
711727
remapped_ids.at(graph_node->input1_id()),
712728
remapped_ids.at(graph_node->input2_id()),
713729
remapped_ids.at(graph_node->output_id()),
@@ -857,10 +873,14 @@ Error defineClampNode(
857873

858874
std::pair<float, float> min_max = getOutputMinMax(node);
859875
auto graph_node = node->xnode_union_as_XNNClamp();
860-
xnn_status status = xnn_define_clamp(
876+
877+
union xnn_unary_params params = {
878+
.clamp = {.min = min_max.first, .max = min_max.second}};
879+
880+
xnn_status status = xnn_define_unary(
861881
subgraph_ptr,
862-
min_max.first,
863-
min_max.second,
882+
xnn_unary_clamp,
883+
&params,
864884
remapped_ids.at(graph_node->input_id()),
865885
remapped_ids.at(graph_node->output_id()),
866886
graph_node->flags());
@@ -916,8 +936,10 @@ Error defineSigmoidNode(
916936
MAYBE_UNUSED(graph);
917937

918938
auto graph_node = node->xnode_union_as_XNNSigmoid();
919-
xnn_status status = xnn_define_sigmoid(
939+
xnn_status status = xnn_define_unary(
920940
subgraph_ptr,
941+
xnn_unary_sigmoid,
942+
nullptr,
921943
remapped_ids.at(graph_node->input_id()),
922944
remapped_ids.at(graph_node->output_id()),
923945
graph_node->flags());
@@ -944,8 +966,10 @@ Error defineFloorNode(
944966
MAYBE_UNUSED(graph);
945967

946968
auto graph_node = node->xnode_union_as_XNNFloor();
947-
xnn_status status = xnn_define_floor(
969+
xnn_status status = xnn_define_unary(
948970
subgraph_ptr,
971+
xnn_unary_floor,
972+
nullptr,
949973
remapped_ids.at(graph_node->input_id()),
950974
remapped_ids.at(graph_node->output_id()),
951975
graph_node->flags());
@@ -1167,10 +1191,14 @@ Error defineDivNode(
11671191

11681192
auto graph_node = node->xnode_union_as_XNNDiv();
11691193
std::pair<float, float> min_max = getOutputMinMax(node);
1170-
xnn_status status = xnn_define_divide(
1194+
1195+
struct xnn_binary_params params = {
1196+
.output_min = min_max.first, .output_max = min_max.second};
1197+
1198+
xnn_status status = xnn_define_binary(
11711199
subgraph_ptr,
1172-
min_max.first,
1173-
min_max.second,
1200+
xnn_binary_divide,
1201+
&params,
11741202
remapped_ids.at(graph_node->input1_id()),
11751203
remapped_ids.at(graph_node->input2_id()),
11761204
remapped_ids.at(graph_node->output_id()),
@@ -1415,8 +1443,10 @@ Error defineSquareRootNode(
14151443

14161444
auto graph_node = node->xnode_union_as_XNNSquareRoot();
14171445

1418-
xnn_status status = xnn_define_square_root(
1446+
xnn_status status = xnn_define_unary(
14191447
subgraph_ptr,
1448+
xnn_unary_square_root,
1449+
nullptr,
14201450
remapped_ids.at(graph_node->input_id()),
14211451
remapped_ids.at(graph_node->output_id()),
14221452
graph_node->flags());
@@ -1445,8 +1475,10 @@ Error defineReciprocalSquareRootNode(
14451475

14461476
auto graph_node = node->xnode_union_as_XNNReciprocalSquareRoot();
14471477

1448-
xnn_status status = xnn_define_reciprocal_square_root(
1478+
xnn_status status = xnn_define_unary(
14491479
subgraph_ptr,
1480+
xnn_unary_reciprocal_square_root,
1481+
nullptr,
14501482
remapped_ids.at(graph_node->input_id()),
14511483
remapped_ids.at(graph_node->output_id()),
14521484
graph_node->flags());
@@ -1475,8 +1507,10 @@ Error defineLogNode(
14751507

14761508
auto graph_node = node->xnode_union_as_XNNLog();
14771509

1478-
xnn_status status = xnn_define_log(
1510+
xnn_status status = xnn_define_unary(
14791511
subgraph_ptr,
1512+
xnn_unary_log,
1513+
nullptr,
14801514
remapped_ids.at(graph_node->input_id()),
14811515
remapped_ids.at(graph_node->output_id()),
14821516
graph_node->flags());
@@ -1505,8 +1539,10 @@ Error defineGeluNode(
15051539

15061540
auto graph_node = node->xnode_union_as_XNNGelu();
15071541

1508-
xnn_status status = xnn_define_gelu(
1542+
xnn_status status = xnn_define_unary(
15091543
subgraph_ptr,
1544+
xnn_unary_gelu,
1545+
nullptr,
15101546
remapped_ids.at(graph_node->input_id()),
15111547
remapped_ids.at(graph_node->output_id()),
15121548
graph_node->flags());
@@ -1535,8 +1571,10 @@ Error defineCeilingNode(
15351571

15361572
auto graph_node = node->xnode_union_as_XNNCeiling();
15371573

1538-
xnn_status status = xnn_define_ceiling(
1574+
xnn_status status = xnn_define_unary(
15391575
subgraph_ptr,
1576+
xnn_unary_ceiling,
1577+
nullptr,
15401578
remapped_ids.at(graph_node->input_id()),
15411579
remapped_ids.at(graph_node->output_id()),
15421580
graph_node->flags());
@@ -1565,8 +1603,10 @@ Error defineHardswishNode(
15651603

15661604
auto graph_node = node->xnode_union_as_XNNHardswish();
15671605

1568-
xnn_status status = xnn_define_hardswish(
1606+
xnn_status status = xnn_define_unary(
15691607
subgraph_ptr,
1608+
xnn_unary_hardswish,
1609+
nullptr,
15701610
remapped_ids.at(graph_node->input_id()),
15711611
remapped_ids.at(graph_node->output_id()),
15721612
graph_node->flags());
@@ -1595,9 +1635,13 @@ Error defineLeakyReLUNode(
15951635

15961636
auto graph_node = node->xnode_union_as_XNNLeakyReLU();
15971637

1598-
xnn_status status = xnn_define_leaky_relu(
1638+
union xnn_unary_params params = {
1639+
.leaky_relu = {.negative_slope = graph_node->negative_slope()}};
1640+
1641+
xnn_status status = xnn_define_unary(
15991642
subgraph_ptr,
1600-
graph_node->negative_slope(),
1643+
xnn_unary_leaky_relu,
1644+
&params,
16011645
remapped_ids.at(graph_node->input_id()),
16021646
remapped_ids.at(graph_node->output_id()),
16031647
graph_node->flags());
@@ -1626,8 +1670,10 @@ Error defineMaximumNode(
16261670

16271671
auto graph_node = node->xnode_union_as_XNNMaximum();
16281672

1629-
xnn_status status = xnn_define_maximum2(
1673+
xnn_status status = xnn_define_binary(
16301674
subgraph_ptr,
1675+
xnn_binary_maximum,
1676+
nullptr,
16311677
remapped_ids.at(graph_node->input1_id()),
16321678
remapped_ids.at(graph_node->input2_id()),
16331679
remapped_ids.at(graph_node->output_id()),
@@ -1656,8 +1702,10 @@ Error defineNegateNode(
16561702

16571703
auto graph_node = node->xnode_union_as_XNNNegate();
16581704

1659-
xnn_status status = xnn_define_negate(
1705+
xnn_status status = xnn_define_unary(
16601706
subgraph_ptr,
1707+
xnn_unary_negate,
1708+
nullptr,
16611709
remapped_ids.at(graph_node->input_id()),
16621710
remapped_ids.at(graph_node->output_id()),
16631711
graph_node->flags());
@@ -1685,8 +1733,10 @@ Error defineSquareNode(
16851733

16861734
auto graph_node = node->xnode_union_as_XNNSquare();
16871735

1688-
xnn_status status = xnn_define_square(
1736+
xnn_status status = xnn_define_unary(
16891737
subgraph_ptr,
1738+
xnn_unary_square,
1739+
nullptr,
16901740
remapped_ids.at(graph_node->input_id()),
16911741
remapped_ids.at(graph_node->output_id()),
16921742
graph_node->flags());
@@ -1714,9 +1764,12 @@ Error defineELUNode(
17141764

17151765
auto graph_node = node->xnode_union_as_XNNELU();
17161766

1717-
xnn_status status = xnn_define_elu(
1767+
union xnn_unary_params params = {.elu = {.alpha = graph_node->alpha()}};
1768+
1769+
xnn_status status = xnn_define_unary(
17181770
subgraph_ptr,
1719-
graph_node->alpha(),
1771+
xnn_unary_elu,
1772+
&params,
17201773
remapped_ids.at(graph_node->input_id()),
17211774
remapped_ids.at(graph_node->output_id()),
17221775
graph_node->flags());
@@ -1744,8 +1797,10 @@ Error defineAbsNode(
17441797

17451798
auto graph_node = node->xnode_union_as_XNNAbs();
17461799

1747-
xnn_status status = xnn_define_abs(
1800+
xnn_status status = xnn_define_unary(
17481801
subgraph_ptr,
1802+
xnn_unary_abs,
1803+
nullptr,
17491804
remapped_ids.at(graph_node->input_id()),
17501805
remapped_ids.at(graph_node->output_id()),
17511806
graph_node->flags());

0 commit comments

Comments
 (0)