Skip to content

Commit 464da84

Browse files
Use wrappers from xnnpack.h for unary and binary ops (pytorch#11584)
1 parent 289acbd commit 464da84

File tree

1 file changed

+86
-32
lines changed

1 file changed

+86
-32
lines changed

backends/xnnpack/runtime/XNNCompiler.cpp

Lines changed: 86 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -615,10 +615,14 @@ Error defineAddNode(
615615

616616
std::pair<float, float> min_max = getOutputMinMax(node);
617617
auto graph_node = node->xnode_union_as_XNNAdd();
618-
xnn_status status = xnn_define_add2(
618+
619+
struct xnn_binary_params params = {
620+
.output_min = min_max.first, .output_max = min_max.second};
621+
622+
xnn_status status = xnn_define_binary(
619623
subgraph_ptr,
620-
min_max.first,
621-
min_max.second,
624+
xnn_binary_add,
625+
&params,
622626
remapped_ids.at(graph_node->input1_id()),
623627
remapped_ids.at(graph_node->input2_id()),
624628
remapped_ids.at(graph_node->output_id()),
@@ -644,8 +648,11 @@ Error defineMinimumNode(
644648
MAYBE_UNUSED(graph);
645649

646650
auto graph_node = node->xnode_union_as_XNNMinimum();
647-
xnn_status status = xnn_define_minimum2(
651+
652+
xnn_status status = xnn_define_binary(
648653
subgraph_ptr,
654+
xnn_binary_minimum,
655+
nullptr,
649656
remapped_ids.at(graph_node->input1_id()),
650657
remapped_ids.at(graph_node->input2_id()),
651658
remapped_ids.at(graph_node->output_id()),
@@ -673,10 +680,14 @@ Error defineSubtractNode(
673680

674681
auto graph_node = node->xnode_union_as_XNNSubtract();
675682
std::pair<float, float> min_max = getOutputMinMax(node);
676-
xnn_status status = xnn_define_subtract(
683+
684+
struct xnn_binary_params params = {
685+
.output_min = min_max.first, .output_max = min_max.second};
686+
687+
xnn_status status = xnn_define_binary(
677688
subgraph_ptr,
678-
min_max.first,
679-
min_max.second,
689+
xnn_binary_subtract,
690+
&params,
680691
remapped_ids.at(graph_node->input1_id()),
681692
remapped_ids.at(graph_node->input2_id()),
682693
remapped_ids.at(graph_node->output_id()),
@@ -704,10 +715,14 @@ Error defineMultiplyNode(
704715

705716
auto graph_node = node->xnode_union_as_XNNMultiply();
706717
std::pair<float, float> min_max = getOutputMinMax(node);
707-
xnn_status status = xnn_define_multiply2(
718+
719+
struct xnn_binary_params params = {
720+
.output_min = min_max.first, .output_max = min_max.second};
721+
722+
xnn_status status = xnn_define_binary(
708723
subgraph_ptr,
709-
min_max.first,
710-
min_max.second,
724+
xnn_binary_multiply,
725+
&params,
711726
remapped_ids.at(graph_node->input1_id()),
712727
remapped_ids.at(graph_node->input2_id()),
713728
remapped_ids.at(graph_node->output_id()),
@@ -857,10 +872,14 @@ Error defineClampNode(
857872

858873
std::pair<float, float> min_max = getOutputMinMax(node);
859874
auto graph_node = node->xnode_union_as_XNNClamp();
860-
xnn_status status = xnn_define_clamp(
875+
876+
union xnn_unary_params params = {
877+
.clamp = {.min = min_max.first, .max = min_max.second}};
878+
879+
xnn_status status = xnn_define_unary(
861880
subgraph_ptr,
862-
min_max.first,
863-
min_max.second,
881+
xnn_unary_clamp,
882+
&params,
864883
remapped_ids.at(graph_node->input_id()),
865884
remapped_ids.at(graph_node->output_id()),
866885
graph_node->flags());
@@ -916,8 +935,10 @@ Error defineSigmoidNode(
916935
MAYBE_UNUSED(graph);
917936

918937
auto graph_node = node->xnode_union_as_XNNSigmoid();
919-
xnn_status status = xnn_define_sigmoid(
938+
xnn_status status = xnn_define_unary(
920939
subgraph_ptr,
940+
xnn_unary_sigmoid,
941+
nullptr,
921942
remapped_ids.at(graph_node->input_id()),
922943
remapped_ids.at(graph_node->output_id()),
923944
graph_node->flags());
@@ -944,8 +965,10 @@ Error defineFloorNode(
944965
MAYBE_UNUSED(graph);
945966

946967
auto graph_node = node->xnode_union_as_XNNFloor();
947-
xnn_status status = xnn_define_floor(
968+
xnn_status status = xnn_define_unary(
948969
subgraph_ptr,
970+
xnn_unary_floor,
971+
nullptr,
949972
remapped_ids.at(graph_node->input_id()),
950973
remapped_ids.at(graph_node->output_id()),
951974
graph_node->flags());
@@ -1167,10 +1190,14 @@ Error defineDivNode(
11671190

11681191
auto graph_node = node->xnode_union_as_XNNDiv();
11691192
std::pair<float, float> min_max = getOutputMinMax(node);
1170-
xnn_status status = xnn_define_divide(
1193+
1194+
struct xnn_binary_params params = {
1195+
.output_min = min_max.first, .output_max = min_max.second};
1196+
1197+
xnn_status status = xnn_define_binary(
11711198
subgraph_ptr,
1172-
min_max.first,
1173-
min_max.second,
1199+
xnn_binary_divide,
1200+
&params,
11741201
remapped_ids.at(graph_node->input1_id()),
11751202
remapped_ids.at(graph_node->input2_id()),
11761203
remapped_ids.at(graph_node->output_id()),
@@ -1415,8 +1442,10 @@ Error defineSquareRootNode(
14151442

14161443
auto graph_node = node->xnode_union_as_XNNSquareRoot();
14171444

1418-
xnn_status status = xnn_define_square_root(
1445+
xnn_status status = xnn_define_unary(
14191446
subgraph_ptr,
1447+
xnn_unary_square_root,
1448+
nullptr,
14201449
remapped_ids.at(graph_node->input_id()),
14211450
remapped_ids.at(graph_node->output_id()),
14221451
graph_node->flags());
@@ -1445,8 +1474,10 @@ Error defineReciprocalSquareRootNode(
14451474

14461475
auto graph_node = node->xnode_union_as_XNNReciprocalSquareRoot();
14471476

1448-
xnn_status status = xnn_define_reciprocal_square_root(
1477+
xnn_status status = xnn_define_unary(
14491478
subgraph_ptr,
1479+
xnn_unary_reciprocal_square_root,
1480+
nullptr,
14501481
remapped_ids.at(graph_node->input_id()),
14511482
remapped_ids.at(graph_node->output_id()),
14521483
graph_node->flags());
@@ -1475,8 +1506,10 @@ Error defineLogNode(
14751506

14761507
auto graph_node = node->xnode_union_as_XNNLog();
14771508

1478-
xnn_status status = xnn_define_log(
1509+
xnn_status status = xnn_define_unary(
14791510
subgraph_ptr,
1511+
xnn_unary_log,
1512+
nullptr,
14801513
remapped_ids.at(graph_node->input_id()),
14811514
remapped_ids.at(graph_node->output_id()),
14821515
graph_node->flags());
@@ -1505,8 +1538,10 @@ Error defineGeluNode(
15051538

15061539
auto graph_node = node->xnode_union_as_XNNGelu();
15071540

1508-
xnn_status status = xnn_define_gelu(
1541+
xnn_status status = xnn_define_unary(
15091542
subgraph_ptr,
1543+
xnn_unary_gelu,
1544+
nullptr,
15101545
remapped_ids.at(graph_node->input_id()),
15111546
remapped_ids.at(graph_node->output_id()),
15121547
graph_node->flags());
@@ -1535,8 +1570,10 @@ Error defineCeilingNode(
15351570

15361571
auto graph_node = node->xnode_union_as_XNNCeiling();
15371572

1538-
xnn_status status = xnn_define_ceiling(
1573+
xnn_status status = xnn_define_unary(
15391574
subgraph_ptr,
1575+
xnn_unary_ceiling,
1576+
nullptr,
15401577
remapped_ids.at(graph_node->input_id()),
15411578
remapped_ids.at(graph_node->output_id()),
15421579
graph_node->flags());
@@ -1565,8 +1602,10 @@ Error defineHardswishNode(
15651602

15661603
auto graph_node = node->xnode_union_as_XNNHardswish();
15671604

1568-
xnn_status status = xnn_define_hardswish(
1605+
xnn_status status = xnn_define_unary(
15691606
subgraph_ptr,
1607+
xnn_unary_hardswish,
1608+
nullptr,
15701609
remapped_ids.at(graph_node->input_id()),
15711610
remapped_ids.at(graph_node->output_id()),
15721611
graph_node->flags());
@@ -1595,9 +1634,13 @@ Error defineLeakyReLUNode(
15951634

15961635
auto graph_node = node->xnode_union_as_XNNLeakyReLU();
15971636

1598-
xnn_status status = xnn_define_leaky_relu(
1637+
union xnn_unary_params params = {
1638+
.leaky_relu = {.negative_slope = graph_node->negative_slope()}};
1639+
1640+
xnn_status status = xnn_define_unary(
15991641
subgraph_ptr,
1600-
graph_node->negative_slope(),
1642+
xnn_unary_leaky_relu,
1643+
&params,
16011644
remapped_ids.at(graph_node->input_id()),
16021645
remapped_ids.at(graph_node->output_id()),
16031646
graph_node->flags());
@@ -1626,8 +1669,10 @@ Error defineMaximumNode(
16261669

16271670
auto graph_node = node->xnode_union_as_XNNMaximum();
16281671

1629-
xnn_status status = xnn_define_maximum2(
1672+
xnn_status status = xnn_define_binary(
16301673
subgraph_ptr,
1674+
xnn_binary_maximum,
1675+
nullptr,
16311676
remapped_ids.at(graph_node->input1_id()),
16321677
remapped_ids.at(graph_node->input2_id()),
16331678
remapped_ids.at(graph_node->output_id()),
@@ -1656,8 +1701,10 @@ Error defineNegateNode(
16561701

16571702
auto graph_node = node->xnode_union_as_XNNNegate();
16581703

1659-
xnn_status status = xnn_define_negate(
1704+
xnn_status status = xnn_define_unary(
16601705
subgraph_ptr,
1706+
xnn_unary_negate,
1707+
nullptr,
16611708
remapped_ids.at(graph_node->input_id()),
16621709
remapped_ids.at(graph_node->output_id()),
16631710
graph_node->flags());
@@ -1685,8 +1732,10 @@ Error defineSquareNode(
16851732

16861733
auto graph_node = node->xnode_union_as_XNNSquare();
16871734

1688-
xnn_status status = xnn_define_square(
1735+
xnn_status status = xnn_define_unary(
16891736
subgraph_ptr,
1737+
xnn_unary_square,
1738+
nullptr,
16901739
remapped_ids.at(graph_node->input_id()),
16911740
remapped_ids.at(graph_node->output_id()),
16921741
graph_node->flags());
@@ -1714,9 +1763,12 @@ Error defineELUNode(
17141763

17151764
auto graph_node = node->xnode_union_as_XNNELU();
17161765

1717-
xnn_status status = xnn_define_elu(
1766+
union xnn_unary_params params = {.elu = {.alpha = graph_node->alpha()}};
1767+
1768+
xnn_status status = xnn_define_unary(
17181769
subgraph_ptr,
1719-
graph_node->alpha(),
1770+
xnn_unary_elu,
1771+
&params,
17201772
remapped_ids.at(graph_node->input_id()),
17211773
remapped_ids.at(graph_node->output_id()),
17221774
graph_node->flags());
@@ -1744,8 +1796,10 @@ Error defineAbsNode(
17441796

17451797
auto graph_node = node->xnode_union_as_XNNAbs();
17461798

1747-
xnn_status status = xnn_define_abs(
1799+
xnn_status status = xnn_define_unary(
17481800
subgraph_ptr,
1801+
xnn_unary_abs,
1802+
nullptr,
17491803
remapped_ids.at(graph_node->input_id()),
17501804
remapped_ids.at(graph_node->output_id()),
17511805
graph_node->flags());

0 commit comments

Comments
 (0)