1111#include  < executorch/backends/xnnpack/serialization/schema_generated.h> 
1212#include  < executorch/extension/threadpool/threadpool.h> 
1313#include  < executorch/runtime/executor/pte_data_map.h> 
14+ #include  < xnnpack.h> 
1415#include  < string> 
1516#include  < unordered_map> 
1617#include  < vector> 
@@ -615,10 +616,14 @@ Error defineAddNode(
615616
616617  std::pair<float , float > min_max = getOutputMinMax (node);
617618  auto  graph_node = node->xnode_union_as_XNNAdd ();
618-   xnn_status status = xnn_define_add2 (
619+ 
620+   struct  xnn_binary_params  params = {
621+       .output_min  = min_max.first , .output_max  = min_max.second };
622+ 
623+   xnn_status status = xnn_define_binary (
619624      subgraph_ptr,
620-       min_max. first ,
621-       min_max. second ,
625+       xnn_binary_add ,
626+       ¶ms ,
622627      remapped_ids.at (graph_node->input1_id ()),
623628      remapped_ids.at (graph_node->input2_id ()),
624629      remapped_ids.at (graph_node->output_id ()),
@@ -644,8 +649,11 @@ Error defineMinimumNode(
644649  MAYBE_UNUSED (graph);
645650
646651  auto  graph_node = node->xnode_union_as_XNNMinimum ();
647-   xnn_status status = xnn_define_minimum2 (
652+ 
653+   xnn_status status = xnn_define_binary (
648654      subgraph_ptr,
655+       xnn_binary_minimum,
656+       nullptr ,
649657      remapped_ids.at (graph_node->input1_id ()),
650658      remapped_ids.at (graph_node->input2_id ()),
651659      remapped_ids.at (graph_node->output_id ()),
@@ -673,10 +681,14 @@ Error defineSubtractNode(
673681
674682  auto  graph_node = node->xnode_union_as_XNNSubtract ();
675683  std::pair<float , float > min_max = getOutputMinMax (node);
676-   xnn_status status = xnn_define_subtract (
684+ 
685+   struct  xnn_binary_params  params = {
686+       .output_min  = min_max.first , .output_max  = min_max.second };
687+ 
688+   xnn_status status = xnn_define_binary (
677689      subgraph_ptr,
678-       min_max. first ,
679-       min_max. second ,
690+       xnn_binary_subtract ,
691+       ¶ms ,
680692      remapped_ids.at (graph_node->input1_id ()),
681693      remapped_ids.at (graph_node->input2_id ()),
682694      remapped_ids.at (graph_node->output_id ()),
@@ -704,10 +716,14 @@ Error defineMultiplyNode(
704716
705717  auto  graph_node = node->xnode_union_as_XNNMultiply ();
706718  std::pair<float , float > min_max = getOutputMinMax (node);
707-   xnn_status status = xnn_define_multiply2 (
719+ 
720+   struct  xnn_binary_params  params = {
721+       .output_min  = min_max.first , .output_max  = min_max.second };
722+ 
723+   xnn_status status = xnn_define_binary (
708724      subgraph_ptr,
709-       min_max. first ,
710-       min_max. second ,
725+       xnn_binary_multiply ,
726+       ¶ms ,
711727      remapped_ids.at (graph_node->input1_id ()),
712728      remapped_ids.at (graph_node->input2_id ()),
713729      remapped_ids.at (graph_node->output_id ()),
@@ -857,10 +873,14 @@ Error defineClampNode(
857873
858874  std::pair<float , float > min_max = getOutputMinMax (node);
859875  auto  graph_node = node->xnode_union_as_XNNClamp ();
860-   xnn_status status = xnn_define_clamp (
876+ 
877+   union  xnn_unary_params params = {
878+       .clamp  = {.min  = min_max.first , .max  = min_max.second }};
879+ 
880+   xnn_status status = xnn_define_unary (
861881      subgraph_ptr,
862-       min_max. first ,
863-       min_max. second ,
882+       xnn_unary_clamp ,
883+       ¶ms ,
864884      remapped_ids.at (graph_node->input_id ()),
865885      remapped_ids.at (graph_node->output_id ()),
866886      graph_node->flags ());
@@ -916,8 +936,10 @@ Error defineSigmoidNode(
916936  MAYBE_UNUSED (graph);
917937
918938  auto  graph_node = node->xnode_union_as_XNNSigmoid ();
919-   xnn_status status = xnn_define_sigmoid (
939+   xnn_status status = xnn_define_unary (
920940      subgraph_ptr,
941+       xnn_unary_sigmoid,
942+       nullptr ,
921943      remapped_ids.at (graph_node->input_id ()),
922944      remapped_ids.at (graph_node->output_id ()),
923945      graph_node->flags ());
@@ -944,8 +966,10 @@ Error defineFloorNode(
944966  MAYBE_UNUSED (graph);
945967
946968  auto  graph_node = node->xnode_union_as_XNNFloor ();
947-   xnn_status status = xnn_define_floor (
969+   xnn_status status = xnn_define_unary (
948970      subgraph_ptr,
971+       xnn_unary_floor,
972+       nullptr ,
949973      remapped_ids.at (graph_node->input_id ()),
950974      remapped_ids.at (graph_node->output_id ()),
951975      graph_node->flags ());
@@ -1167,10 +1191,14 @@ Error defineDivNode(
11671191
11681192  auto  graph_node = node->xnode_union_as_XNNDiv ();
11691193  std::pair<float , float > min_max = getOutputMinMax (node);
1170-   xnn_status status = xnn_define_divide (
1194+ 
1195+   struct  xnn_binary_params  params = {
1196+       .output_min  = min_max.first , .output_max  = min_max.second };
1197+ 
1198+   xnn_status status = xnn_define_binary (
11711199      subgraph_ptr,
1172-       min_max. first ,
1173-       min_max. second ,
1200+       xnn_binary_divide ,
1201+       ¶ms ,
11741202      remapped_ids.at (graph_node->input1_id ()),
11751203      remapped_ids.at (graph_node->input2_id ()),
11761204      remapped_ids.at (graph_node->output_id ()),
@@ -1415,8 +1443,10 @@ Error defineSquareRootNode(
14151443
14161444  auto  graph_node = node->xnode_union_as_XNNSquareRoot ();
14171445
1418-   xnn_status status = xnn_define_square_root (
1446+   xnn_status status = xnn_define_unary (
14191447      subgraph_ptr,
1448+       xnn_unary_square_root,
1449+       nullptr ,
14201450      remapped_ids.at (graph_node->input_id ()),
14211451      remapped_ids.at (graph_node->output_id ()),
14221452      graph_node->flags ());
@@ -1445,8 +1475,10 @@ Error defineReciprocalSquareRootNode(
14451475
14461476  auto  graph_node = node->xnode_union_as_XNNReciprocalSquareRoot ();
14471477
1448-   xnn_status status = xnn_define_reciprocal_square_root (
1478+   xnn_status status = xnn_define_unary (
14491479      subgraph_ptr,
1480+       xnn_unary_reciprocal_square_root,
1481+       nullptr ,
14501482      remapped_ids.at (graph_node->input_id ()),
14511483      remapped_ids.at (graph_node->output_id ()),
14521484      graph_node->flags ());
@@ -1475,8 +1507,10 @@ Error defineLogNode(
14751507
14761508  auto  graph_node = node->xnode_union_as_XNNLog ();
14771509
1478-   xnn_status status = xnn_define_log (
1510+   xnn_status status = xnn_define_unary (
14791511      subgraph_ptr,
1512+       xnn_unary_log,
1513+       nullptr ,
14801514      remapped_ids.at (graph_node->input_id ()),
14811515      remapped_ids.at (graph_node->output_id ()),
14821516      graph_node->flags ());
@@ -1505,8 +1539,10 @@ Error defineGeluNode(
15051539
15061540  auto  graph_node = node->xnode_union_as_XNNGelu ();
15071541
1508-   xnn_status status = xnn_define_gelu (
1542+   xnn_status status = xnn_define_unary (
15091543      subgraph_ptr,
1544+       xnn_unary_gelu,
1545+       nullptr ,
15101546      remapped_ids.at (graph_node->input_id ()),
15111547      remapped_ids.at (graph_node->output_id ()),
15121548      graph_node->flags ());
@@ -1535,8 +1571,10 @@ Error defineCeilingNode(
15351571
15361572  auto  graph_node = node->xnode_union_as_XNNCeiling ();
15371573
1538-   xnn_status status = xnn_define_ceiling (
1574+   xnn_status status = xnn_define_unary (
15391575      subgraph_ptr,
1576+       xnn_unary_ceiling,
1577+       nullptr ,
15401578      remapped_ids.at (graph_node->input_id ()),
15411579      remapped_ids.at (graph_node->output_id ()),
15421580      graph_node->flags ());
@@ -1565,8 +1603,10 @@ Error defineHardswishNode(
15651603
15661604  auto  graph_node = node->xnode_union_as_XNNHardswish ();
15671605
1568-   xnn_status status = xnn_define_hardswish (
1606+   xnn_status status = xnn_define_unary (
15691607      subgraph_ptr,
1608+       xnn_unary_hardswish,
1609+       nullptr ,
15701610      remapped_ids.at (graph_node->input_id ()),
15711611      remapped_ids.at (graph_node->output_id ()),
15721612      graph_node->flags ());
@@ -1595,9 +1635,13 @@ Error defineLeakyReLUNode(
15951635
15961636  auto  graph_node = node->xnode_union_as_XNNLeakyReLU ();
15971637
1598-   xnn_status status = xnn_define_leaky_relu (
1638+   union  xnn_unary_params params = {
1639+       .leaky_relu  = {.negative_slope  = graph_node->negative_slope ()}};
1640+ 
1641+   xnn_status status = xnn_define_unary (
15991642      subgraph_ptr,
1600-       graph_node->negative_slope (),
1643+       xnn_unary_leaky_relu,
1644+       ¶ms,
16011645      remapped_ids.at (graph_node->input_id ()),
16021646      remapped_ids.at (graph_node->output_id ()),
16031647      graph_node->flags ());
@@ -1626,8 +1670,10 @@ Error defineMaximumNode(
16261670
16271671  auto  graph_node = node->xnode_union_as_XNNMaximum ();
16281672
1629-   xnn_status status = xnn_define_maximum2 (
1673+   xnn_status status = xnn_define_binary (
16301674      subgraph_ptr,
1675+       xnn_binary_maximum,
1676+       nullptr ,
16311677      remapped_ids.at (graph_node->input1_id ()),
16321678      remapped_ids.at (graph_node->input2_id ()),
16331679      remapped_ids.at (graph_node->output_id ()),
@@ -1656,8 +1702,10 @@ Error defineNegateNode(
16561702
16571703  auto  graph_node = node->xnode_union_as_XNNNegate ();
16581704
1659-   xnn_status status = xnn_define_negate (
1705+   xnn_status status = xnn_define_unary (
16601706      subgraph_ptr,
1707+       xnn_unary_negate,
1708+       nullptr ,
16611709      remapped_ids.at (graph_node->input_id ()),
16621710      remapped_ids.at (graph_node->output_id ()),
16631711      graph_node->flags ());
@@ -1685,8 +1733,10 @@ Error defineSquareNode(
16851733
16861734  auto  graph_node = node->xnode_union_as_XNNSquare ();
16871735
1688-   xnn_status status = xnn_define_square (
1736+   xnn_status status = xnn_define_unary (
16891737      subgraph_ptr,
1738+       xnn_unary_square,
1739+       nullptr ,
16901740      remapped_ids.at (graph_node->input_id ()),
16911741      remapped_ids.at (graph_node->output_id ()),
16921742      graph_node->flags ());
@@ -1714,9 +1764,12 @@ Error defineELUNode(
17141764
17151765  auto  graph_node = node->xnode_union_as_XNNELU ();
17161766
1717-   xnn_status status = xnn_define_elu (
1767+   union  xnn_unary_params params = {.elu  = {.alpha  = graph_node->alpha ()}};
1768+ 
1769+   xnn_status status = xnn_define_unary (
17181770      subgraph_ptr,
1719-       graph_node->alpha (),
1771+       xnn_unary_elu,
1772+       ¶ms,
17201773      remapped_ids.at (graph_node->input_id ()),
17211774      remapped_ids.at (graph_node->output_id ()),
17221775      graph_node->flags ());
@@ -1744,8 +1797,10 @@ Error defineAbsNode(
17441797
17451798  auto  graph_node = node->xnode_union_as_XNNAbs ();
17461799
1747-   xnn_status status = xnn_define_abs (
1800+   xnn_status status = xnn_define_unary (
17481801      subgraph_ptr,
1802+       xnn_unary_abs,
1803+       nullptr ,
17491804      remapped_ids.at (graph_node->input_id ()),
17501805      remapped_ids.at (graph_node->output_id ()),
17511806      graph_node->flags ());
0 commit comments