From f39f3ee5c9d6ab762b77f0ef640c4b5830b039e2 Mon Sep 17 00:00:00 2001 From: Abhinay Kukkadapu Date: Fri, 13 Jun 2025 19:30:50 -0700 Subject: [PATCH] Use wrappers from xnnpack.h for unary and binary ops (#11584) --- .gitignore | 3 + backends/xnnpack/runtime/XNNCompiler.cpp | 951 +++++++---------------- 2 files changed, 293 insertions(+), 661 deletions(-) diff --git a/.gitignore b/.gitignore index c257883ee40..553729e9b68 100644 --- a/.gitignore +++ b/.gitignore @@ -42,6 +42,9 @@ xcuserdata/ *.xcworkspace/ *.xcframework/ +# clangd +.cache/ + # misc /.vscode/ *.so diff --git a/backends/xnnpack/runtime/XNNCompiler.cpp b/backends/xnnpack/runtime/XNNCompiler.cpp index 312cbc17b95..7241280ab35 100644 --- a/backends/xnnpack/runtime/XNNCompiler.cpp +++ b/backends/xnnpack/runtime/XNNCompiler.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #include #include #include @@ -601,128 +602,6 @@ Error defineTensor( #define MAYBE_UNUSED(x) (void)(x) -/* -Define serialized add node into the subgraph, using the remapped ids -to map the serialized ids, to the new ids generated when defining -the tensor value -*/ -Error defineAddNode( - xnn_subgraph_t subgraph_ptr, - const std::unordered_map& remapped_ids, - const NodePtr node, - const fb_xnnpack::XNNGraph* graph) noexcept { - MAYBE_UNUSED(graph); - - std::pair min_max = getOutputMinMax(node); - auto graph_node = node->xnode_union_as_XNNAdd(); - xnn_status status = xnn_define_add2( - subgraph_ptr, - min_max.first, - min_max.second, - remapped_ids.at(graph_node->input1_id()), - remapped_ids.at(graph_node->input2_id()), - remapped_ids.at(graph_node->output_id()), - graph_node->flags()); - ET_CHECK_OR_RETURN_ERROR( - status == xnn_status_success, - Internal, - "Failed to create add node %i with code: %s", - node->debug_handle(), - xnn_status_to_string(status)); - - return Error::Ok; -}; - -/* -Define Minimum operator Node into the subgraph -*/ -Error defineMinimumNode( - xnn_subgraph_t subgraph_ptr, - const std::unordered_map& remapped_ids, - const NodePtr node, - const fb_xnnpack::XNNGraph* graph) noexcept { - MAYBE_UNUSED(graph); - - auto graph_node = node->xnode_union_as_XNNMinimum(); - xnn_status status = xnn_define_minimum2( - subgraph_ptr, - remapped_ids.at(graph_node->input1_id()), - remapped_ids.at(graph_node->input2_id()), - remapped_ids.at(graph_node->output_id()), - graph_node->flags()); - - ET_CHECK_OR_RETURN_ERROR( - status == xnn_status_success, - Internal, - "Failed to create minumum node %i with code: %s", - node->debug_handle(), - xnn_status_to_string(status)); - - return Error::Ok; -}; - -/* -Define subtract operator Node into the subgraph -*/ -Error defineSubtractNode( - xnn_subgraph_t subgraph_ptr, - const std::unordered_map& remapped_ids, - const NodePtr node, - const fb_xnnpack::XNNGraph* graph) noexcept { - MAYBE_UNUSED(graph); - - auto graph_node = node->xnode_union_as_XNNSubtract(); - std::pair min_max = getOutputMinMax(node); - xnn_status status = xnn_define_subtract( - subgraph_ptr, - min_max.first, - min_max.second, - remapped_ids.at(graph_node->input1_id()), - remapped_ids.at(graph_node->input2_id()), - remapped_ids.at(graph_node->output_id()), - graph_node->flags()); - - ET_CHECK_OR_RETURN_ERROR( - status == xnn_status_success, - Internal, - "Failed to create subtract node %i with code: %s", - node->debug_handle(), - xnn_status_to_string(status)); - - return Error::Ok; -}; - -/* -Define Multiply operator Node into the subgraph -*/ -Error defineMultiplyNode( - xnn_subgraph_t subgraph_ptr, - const std::unordered_map& remapped_ids, - const NodePtr node, - const fb_xnnpack::XNNGraph* graph) noexcept { - MAYBE_UNUSED(graph); - - auto graph_node = node->xnode_union_as_XNNMultiply(); - std::pair min_max = getOutputMinMax(node); - xnn_status status = xnn_define_multiply2( - subgraph_ptr, - min_max.first, - min_max.second, - remapped_ids.at(graph_node->input1_id()), - remapped_ids.at(graph_node->input2_id()), - remapped_ids.at(graph_node->output_id()), - graph_node->flags()); - - ET_CHECK_OR_RETURN_ERROR( - status == xnn_status_success, - Internal, - "Failed to create multiply node %i with code: %s", - node->debug_handle(), - xnn_status_to_string(status)); - - return Error::Ok; -}; - #ifdef ENABLE_XNNPACK_KLEIDI bool isQP8(const fb_xnnpack::XNNGraph* graph, const NodePtr node) { assert(node->xnode_union_type() == fb_xnnpack::XNodeUnion::XNNConvert); @@ -843,38 +722,6 @@ Error defineFullyConnectedNode( return Error::Ok; }; -/* -Define serialized clamp node into the subgraph, using the remapped ids -to map the serialized ids, to the new ids generated when defining -the tensor value -*/ -Error defineClampNode( - xnn_subgraph_t subgraph_ptr, - const std::unordered_map& remapped_ids, - const NodePtr node, - const fb_xnnpack::XNNGraph* graph) noexcept { - MAYBE_UNUSED(graph); - - std::pair min_max = getOutputMinMax(node); - auto graph_node = node->xnode_union_as_XNNClamp(); - xnn_status status = xnn_define_clamp( - subgraph_ptr, - min_max.first, - min_max.second, - remapped_ids.at(graph_node->input_id()), - remapped_ids.at(graph_node->output_id()), - graph_node->flags()); - - ET_CHECK_OR_RETURN_ERROR( - status == xnn_status_success, - Internal, - "Failed to create hardtanh node %i with code: %s", - node->debug_handle(), - xnn_status_to_string(status)); - - return Error::Ok; -} - /* Define serialized softmax node into the subgraph, using the remapped ids to map the serialized ids, to the new ids generated when defining @@ -903,62 +750,6 @@ Error defineSoftmaxNode( return Error::Ok; } -/* -Define serialized sigmoid node into the subgraph, using the remapped ids -to map the serialized ids, to the new ids generated when defining -the tensor value -*/ -Error defineSigmoidNode( - xnn_subgraph_t subgraph_ptr, - const std::unordered_map& remapped_ids, - const NodePtr node, - const fb_xnnpack::XNNGraph* graph) noexcept { - MAYBE_UNUSED(graph); - - auto graph_node = node->xnode_union_as_XNNSigmoid(); - xnn_status status = xnn_define_sigmoid( - subgraph_ptr, - remapped_ids.at(graph_node->input_id()), - remapped_ids.at(graph_node->output_id()), - graph_node->flags()); - ET_CHECK_OR_RETURN_ERROR( - status == xnn_status_success, - Internal, - "Failed to create sigmoid node %i with code: %s", - node->debug_handle(), - xnn_status_to_string(status)); - - return Error::Ok; -} - -/* -Define serialized floor node into the subgraph, using the remapped ids -to map the serialized ids, to the new ids generated when defining -the tensor value -*/ -Error defineFloorNode( - xnn_subgraph_t subgraph_ptr, - const std::unordered_map& remapped_ids, - const NodePtr node, - const fb_xnnpack::XNNGraph* graph) noexcept { - MAYBE_UNUSED(graph); - - auto graph_node = node->xnode_union_as_XNNFloor(); - xnn_status status = xnn_define_floor( - subgraph_ptr, - remapped_ids.at(graph_node->input_id()), - remapped_ids.at(graph_node->output_id()), - graph_node->flags()); - ET_CHECK_OR_RETURN_ERROR( - status == xnn_status_success, - Internal, - "Failed to create floor node %i with code: %s", - node->debug_handle(), - xnn_status_to_string(status)); - - return Error::Ok; -} - Error defineGlobalAvgPooling2dNode( xnn_subgraph_t subgraph_ptr, const std::unordered_map& remapped_ids, @@ -1155,36 +946,6 @@ Error defineMaxPooling2dNode( return Error::Ok; } -/* -Define serialized div node into the subgraph -*/ -Error defineDivNode( - xnn_subgraph_t subgraph_ptr, - const std::unordered_map& remapped_ids, - const NodePtr node, - const fb_xnnpack::XNNGraph* graph) noexcept { - MAYBE_UNUSED(graph); - - auto graph_node = node->xnode_union_as_XNNDiv(); - std::pair min_max = getOutputMinMax(node); - xnn_status status = xnn_define_divide( - subgraph_ptr, - min_max.first, - min_max.second, - remapped_ids.at(graph_node->input1_id()), - remapped_ids.at(graph_node->input2_id()), - remapped_ids.at(graph_node->output_id()), - graph_node->flags()); - ET_CHECK_OR_RETURN_ERROR( - status == xnn_status_success, - Internal, - "Failed to create div node %i with code: %s", - node->debug_handle(), - xnn_status_to_string(status)); - - return Error::Ok; -} - /* Define serialized static transpose node into the subgraph, using the remapped ids to map the serialized ids, to the new ids generated when defining the @@ -1402,29 +1163,30 @@ Error defineArgMaxPooling2dNode( } /* -Define serialized square root node into the subgraph, using the remapped ids -to map the serialized ids, to the new ids generated when defining the -tensor value +Defines serialized prelu node into the subgraph, +using the remapped ids to map the serialized ids, +to the new ids generated when defining the tensor value */ -Error defineSquareRootNode( +Error definePReLUNode( xnn_subgraph_t subgraph_ptr, const std::unordered_map& remapped_ids, const NodePtr node, const fb_xnnpack::XNNGraph* graph) noexcept { MAYBE_UNUSED(graph); - auto graph_node = node->xnode_union_as_XNNSquareRoot(); + auto graph_node = node->xnode_union_as_XNNPReLU(); - xnn_status status = xnn_define_square_root( + xnn_status status = xnn_define_prelu( subgraph_ptr, - remapped_ids.at(graph_node->input_id()), + remapped_ids.at(graph_node->input1_id()), + remapped_ids.at(graph_node->input2_id()), remapped_ids.at(graph_node->output_id()), graph_node->flags()); ET_CHECK_OR_RETURN_ERROR( status == xnn_status_success, Internal, - "Failed to create square root node %i with code: %s", + "Failed to create prelu node %i with code: %s", node->debug_handle(), xnn_status_to_string(status)); @@ -1432,29 +1194,31 @@ Error defineSquareRootNode( } /* -Define serialized square root node into the subgraph, using the remapped ids -to map the serialized ids, to the new ids generated when defining the -tensor value +Defines serialized concatenate2 node into the subgraph, +using the remapped ids to map the serialized ids, +to the new ids generated when defining the tensor value */ -Error defineReciprocalSquareRootNode( +Error defineConcatenate2Node( xnn_subgraph_t subgraph_ptr, const std::unordered_map& remapped_ids, const NodePtr node, const fb_xnnpack::XNNGraph* graph) noexcept { MAYBE_UNUSED(graph); - auto graph_node = node->xnode_union_as_XNNReciprocalSquareRoot(); + auto graph_node = node->xnode_union_as_XNNConcatenate2(); - xnn_status status = xnn_define_reciprocal_square_root( + xnn_status status = xnn_define_concatenate2( subgraph_ptr, - remapped_ids.at(graph_node->input_id()), + graph_node->axis(), + remapped_ids.at(graph_node->input1_id()), + remapped_ids.at(graph_node->input2_id()), remapped_ids.at(graph_node->output_id()), graph_node->flags()); ET_CHECK_OR_RETURN_ERROR( status == xnn_status_success, Internal, - "Failed to create reciprocal square root node %i with code: %s", + "Failed to create cat2 node %i with code: %s", node->debug_handle(), xnn_status_to_string(status)); @@ -1462,29 +1226,32 @@ Error defineReciprocalSquareRootNode( } /* -Define serialized log node into the subgraph, using the remapped ids -to map the serialized ids, to the new ids generated when defining the -tensor value +Defines serialized concatenate3 node into the subgraph, +using the remapped ids to map the serialized ids, +to the new ids generated when defining the tensor value */ -Error defineLogNode( +Error defineConcatenate3Node( xnn_subgraph_t subgraph_ptr, const std::unordered_map& remapped_ids, const NodePtr node, const fb_xnnpack::XNNGraph* graph) noexcept { MAYBE_UNUSED(graph); - auto graph_node = node->xnode_union_as_XNNLog(); + auto graph_node = node->xnode_union_as_XNNConcatenate3(); - xnn_status status = xnn_define_log( + xnn_status status = xnn_define_concatenate3( subgraph_ptr, - remapped_ids.at(graph_node->input_id()), + graph_node->axis(), + remapped_ids.at(graph_node->input1_id()), + remapped_ids.at(graph_node->input2_id()), + remapped_ids.at(graph_node->input3_id()), remapped_ids.at(graph_node->output_id()), graph_node->flags()); ET_CHECK_OR_RETURN_ERROR( status == xnn_status_success, Internal, - "Failed to create log node %i with code: %s", + "Failed to create cat3 node %i with code: %s", node->debug_handle(), xnn_status_to_string(status)); @@ -1492,398 +1259,33 @@ Error defineLogNode( } /* -Define serialized gelu node into the subgraph, using the remapped ids -to map the serialized ids, to the new ids generated when defining the -tensor value +Defines serialized concatenate4 node into the subgraph, +using the remapped ids to map the serialized ids, +to the new ids generated when defining the tensor value */ -Error defineGeluNode( +Error defineConcatenate4Node( xnn_subgraph_t subgraph_ptr, const std::unordered_map& remapped_ids, const NodePtr node, const fb_xnnpack::XNNGraph* graph) noexcept { MAYBE_UNUSED(graph); - auto graph_node = node->xnode_union_as_XNNGelu(); + auto graph_node = node->xnode_union_as_XNNConcatenate4(); - xnn_status status = xnn_define_gelu( + xnn_status status = xnn_define_concatenate4( subgraph_ptr, - remapped_ids.at(graph_node->input_id()), + graph_node->axis(), + remapped_ids.at(graph_node->input1_id()), + remapped_ids.at(graph_node->input2_id()), + remapped_ids.at(graph_node->input3_id()), + remapped_ids.at(graph_node->input4_id()), remapped_ids.at(graph_node->output_id()), graph_node->flags()); ET_CHECK_OR_RETURN_ERROR( status == xnn_status_success, Internal, - "Failed to create gelu node %i with code: %s", - node->debug_handle(), - xnn_status_to_string(status)); - - return Error::Ok; -} - -/* -Define serialized ceiling node into the subgraph, using the remapped ids -to map the serialized ids, to the new ids generated when defining the -tensor value -*/ -Error defineCeilingNode( - xnn_subgraph_t subgraph_ptr, - const std::unordered_map& remapped_ids, - const NodePtr node, - const fb_xnnpack::XNNGraph* graph) noexcept { - MAYBE_UNUSED(graph); - - auto graph_node = node->xnode_union_as_XNNCeiling(); - - xnn_status status = xnn_define_ceiling( - subgraph_ptr, - remapped_ids.at(graph_node->input_id()), - remapped_ids.at(graph_node->output_id()), - graph_node->flags()); - - ET_CHECK_OR_RETURN_ERROR( - status == xnn_status_success, - Internal, - "Failed to create ceiling node %i with code: %s", - node->debug_handle(), - xnn_status_to_string(status)); - - return Error::Ok; -} - -/* -Define serialized hardswish node into the subgraph, using the remapped ids -to map the serialized ids, to the new ids generated when defining the -tensor value -*/ -Error defineHardswishNode( - xnn_subgraph_t subgraph_ptr, - const std::unordered_map& remapped_ids, - const NodePtr node, - const fb_xnnpack::XNNGraph* graph) noexcept { - MAYBE_UNUSED(graph); - - auto graph_node = node->xnode_union_as_XNNHardswish(); - - xnn_status status = xnn_define_hardswish( - subgraph_ptr, - remapped_ids.at(graph_node->input_id()), - remapped_ids.at(graph_node->output_id()), - graph_node->flags()); - - ET_CHECK_OR_RETURN_ERROR( - status == xnn_status_success, - Internal, - "Failed to create hardswish node %i with code: %s", - node->debug_handle(), - xnn_status_to_string(status)); - - return Error::Ok; -} - -/* -Define serialized leaky relu node into the subgraph, using the remapped ids -to map the serialized ids, to the new ids generated when defining the -tensor value -*/ -Error defineLeakyReLUNode( - xnn_subgraph_t subgraph_ptr, - const std::unordered_map& remapped_ids, - const NodePtr node, - const fb_xnnpack::XNNGraph* graph) noexcept { - MAYBE_UNUSED(graph); - - auto graph_node = node->xnode_union_as_XNNLeakyReLU(); - - xnn_status status = xnn_define_leaky_relu( - subgraph_ptr, - graph_node->negative_slope(), - remapped_ids.at(graph_node->input_id()), - remapped_ids.at(graph_node->output_id()), - graph_node->flags()); - - ET_CHECK_OR_RETURN_ERROR( - status == xnn_status_success, - Internal, - "Failed to create leaky relu node %i with code: %s", - node->debug_handle(), - xnn_status_to_string(status)); - - return Error::Ok; -} - -/* -Define serialized maximum node into the subgraph, using the remapped ids -to map the serialized ids, to the new ids generated when defining the -tensor value -*/ -Error defineMaximumNode( - xnn_subgraph_t subgraph_ptr, - const std::unordered_map& remapped_ids, - const NodePtr node, - const fb_xnnpack::XNNGraph* graph) noexcept { - MAYBE_UNUSED(graph); - - auto graph_node = node->xnode_union_as_XNNMaximum(); - - xnn_status status = xnn_define_maximum2( - subgraph_ptr, - remapped_ids.at(graph_node->input1_id()), - remapped_ids.at(graph_node->input2_id()), - remapped_ids.at(graph_node->output_id()), - graph_node->flags()); - - ET_CHECK_OR_RETURN_ERROR( - status == xnn_status_success, - Internal, - "Failed to create maximum node %i with code: %s", - node->debug_handle(), - xnn_status_to_string(status)); - - return Error::Ok; -} - -/* -Define Negate node into subgraph, using the remapped ids to map the -serialized ids, to the new ids generated when defining the tensor value -*/ -Error defineNegateNode( - xnn_subgraph_t subgraph_ptr, - const std::unordered_map& remapped_ids, - const NodePtr node, - const fb_xnnpack::XNNGraph* graph) noexcept { - MAYBE_UNUSED(graph); - - auto graph_node = node->xnode_union_as_XNNNegate(); - - xnn_status status = xnn_define_negate( - subgraph_ptr, - remapped_ids.at(graph_node->input_id()), - remapped_ids.at(graph_node->output_id()), - graph_node->flags()); - - ET_CHECK_OR_RETURN_ERROR( - status == xnn_status_success, - Internal, - "Failed to create negate node %i with code: %s", - node->debug_handle(), - xnn_status_to_string(status)); - - return Error::Ok; -} - -/* -Defines square node into subgraph using the remapped ids to map the -serialized ids to the new ids generated when defining the tensor value -*/ -Error defineSquareNode( - xnn_subgraph_t subgraph_ptr, - const std::unordered_map& remapped_ids, - const NodePtr node, - const fb_xnnpack::XNNGraph* graph) noexcept { - MAYBE_UNUSED(graph); - - auto graph_node = node->xnode_union_as_XNNSquare(); - - xnn_status status = xnn_define_square( - subgraph_ptr, - remapped_ids.at(graph_node->input_id()), - remapped_ids.at(graph_node->output_id()), - graph_node->flags()); - - ET_CHECK_OR_RETURN_ERROR( - status == xnn_status_success, - Internal, - "Failed to create square node %i with code: %s", - node->debug_handle(), - xnn_status_to_string(status)); - - return Error::Ok; -} - -/* -Defines square node into subgraph using the remapped ids to map the -serialized ids to the new ids generated when defining the tensor value -*/ -Error defineELUNode( - xnn_subgraph_t subgraph_ptr, - const std::unordered_map& remapped_ids, - const NodePtr node, - const fb_xnnpack::XNNGraph* graph) noexcept { - MAYBE_UNUSED(graph); - - auto graph_node = node->xnode_union_as_XNNELU(); - - xnn_status status = xnn_define_elu( - subgraph_ptr, - graph_node->alpha(), - remapped_ids.at(graph_node->input_id()), - remapped_ids.at(graph_node->output_id()), - graph_node->flags()); - - ET_CHECK_OR_RETURN_ERROR( - status == xnn_status_success, - Internal, - "Failed to create ELU node %i with code: %s", - node->debug_handle(), - xnn_status_to_string(status)); - - return Error::Ok; -} - -/* -Defines absolute value node into subgraph using the remapped ids to map the -serialized ids to the new ids generated when defining the tensor value -*/ -Error defineAbsNode( - xnn_subgraph_t subgraph_ptr, - const std::unordered_map& remapped_ids, - const NodePtr node, - const fb_xnnpack::XNNGraph* graph) noexcept { - MAYBE_UNUSED(graph); - - auto graph_node = node->xnode_union_as_XNNAbs(); - - xnn_status status = xnn_define_abs( - subgraph_ptr, - remapped_ids.at(graph_node->input_id()), - remapped_ids.at(graph_node->output_id()), - graph_node->flags()); - - ET_CHECK_OR_RETURN_ERROR( - status == xnn_status_success, - Internal, - "Failed to create abs node %i with code: %s", - node->debug_handle(), - xnn_status_to_string(status)); - - return Error::Ok; -} - -/* -Defines serialized prelu node into the subgraph, -using the remapped ids to map the serialized ids, -to the new ids generated when defining the tensor value -*/ -Error definePReLUNode( - xnn_subgraph_t subgraph_ptr, - const std::unordered_map& remapped_ids, - const NodePtr node, - const fb_xnnpack::XNNGraph* graph) noexcept { - MAYBE_UNUSED(graph); - - auto graph_node = node->xnode_union_as_XNNPReLU(); - - xnn_status status = xnn_define_prelu( - subgraph_ptr, - remapped_ids.at(graph_node->input1_id()), - remapped_ids.at(graph_node->input2_id()), - remapped_ids.at(graph_node->output_id()), - graph_node->flags()); - - ET_CHECK_OR_RETURN_ERROR( - status == xnn_status_success, - Internal, - "Failed to create prelu node %i with code: %s", - node->debug_handle(), - xnn_status_to_string(status)); - - return Error::Ok; -} - -/* -Defines serialized concatenate2 node into the subgraph, -using the remapped ids to map the serialized ids, -to the new ids generated when defining the tensor value -*/ -Error defineConcatenate2Node( - xnn_subgraph_t subgraph_ptr, - const std::unordered_map& remapped_ids, - const NodePtr node, - const fb_xnnpack::XNNGraph* graph) noexcept { - MAYBE_UNUSED(graph); - - auto graph_node = node->xnode_union_as_XNNConcatenate2(); - - xnn_status status = xnn_define_concatenate2( - subgraph_ptr, - graph_node->axis(), - remapped_ids.at(graph_node->input1_id()), - remapped_ids.at(graph_node->input2_id()), - remapped_ids.at(graph_node->output_id()), - graph_node->flags()); - - ET_CHECK_OR_RETURN_ERROR( - status == xnn_status_success, - Internal, - "Failed to create cat2 node %i with code: %s", - node->debug_handle(), - xnn_status_to_string(status)); - - return Error::Ok; -} - -/* -Defines serialized concatenate3 node into the subgraph, -using the remapped ids to map the serialized ids, -to the new ids generated when defining the tensor value -*/ -Error defineConcatenate3Node( - xnn_subgraph_t subgraph_ptr, - const std::unordered_map& remapped_ids, - const NodePtr node, - const fb_xnnpack::XNNGraph* graph) noexcept { - MAYBE_UNUSED(graph); - - auto graph_node = node->xnode_union_as_XNNConcatenate3(); - - xnn_status status = xnn_define_concatenate3( - subgraph_ptr, - graph_node->axis(), - remapped_ids.at(graph_node->input1_id()), - remapped_ids.at(graph_node->input2_id()), - remapped_ids.at(graph_node->input3_id()), - remapped_ids.at(graph_node->output_id()), - graph_node->flags()); - - ET_CHECK_OR_RETURN_ERROR( - status == xnn_status_success, - Internal, - "Failed to create cat3 node %i with code: %s", - node->debug_handle(), - xnn_status_to_string(status)); - - return Error::Ok; -} - -/* -Defines serialized concatenate4 node into the subgraph, -using the remapped ids to map the serialized ids, -to the new ids generated when defining the tensor value -*/ -Error defineConcatenate4Node( - xnn_subgraph_t subgraph_ptr, - const std::unordered_map& remapped_ids, - const NodePtr node, - const fb_xnnpack::XNNGraph* graph) noexcept { - MAYBE_UNUSED(graph); - - auto graph_node = node->xnode_union_as_XNNConcatenate4(); - - xnn_status status = xnn_define_concatenate4( - subgraph_ptr, - graph_node->axis(), - remapped_ids.at(graph_node->input1_id()), - remapped_ids.at(graph_node->input2_id()), - remapped_ids.at(graph_node->input3_id()), - remapped_ids.at(graph_node->input4_id()), - remapped_ids.at(graph_node->output_id()), - graph_node->flags()); - - ET_CHECK_OR_RETURN_ERROR( - status == xnn_status_success, - Internal, - "Failed to create cat4 node %i with code: %s", + "Failed to create cat4 node %i with code: %s", node->debug_handle(), xnn_status_to_string(status)); @@ -2047,6 +1449,196 @@ Error defineNotImplementedNode( fb_xnnpack::EnumNameXNodeUnion(node->xnode_union_type())); } +// Generic helper function for unary operations +Error defineGenericUnaryNode( + xnn_subgraph_t subgraph_ptr, + const std::unordered_map& remapped_ids, + uint32_t input_id, + uint32_t output_id, + uint32_t flags, + xnn_unary_operator op_type, + const union xnn_unary_params* params, + fb_xnnpack::XNodeUnion node_type, + uint32_t debug_handle) noexcept { + xnn_status status = xnn_define_unary( + subgraph_ptr, + op_type, + params, + remapped_ids.at(input_id), + remapped_ids.at(output_id), + flags); + + ET_CHECK_OR_RETURN_ERROR( + status == xnn_status_success, + Internal, + "Failed to create %s node %i with code: %s", + fb_xnnpack::EnumNameXNodeUnion(node_type), + debug_handle, + xnn_status_to_string(status)); + + return Error::Ok; +} + +// Macro for unary operations with no parameters +#define _DEFINE_UNARY_NODE_NO_PARAMS(name, op_type) \ + Error define##name##Node( \ + xnn_subgraph_t subgraph_ptr, \ + const std::unordered_map& remapped_ids, \ + const NodePtr node, \ + const fb_xnnpack::XNNGraph* graph) noexcept { \ + MAYBE_UNUSED(graph); \ + auto graph_node = node->xnode_union_as_XNN##name(); \ + return defineGenericUnaryNode( \ + subgraph_ptr, \ + remapped_ids, \ + graph_node->input_id(), \ + graph_node->output_id(), \ + graph_node->flags(), \ + op_type, \ + nullptr, \ + node->xnode_union_type(), \ + node->debug_handle()); \ + } + +// Macro for unary operations with min/max parameters +#define _DEFINE_UNARY_NODE_WITH_MINMAX(name, op_type) \ + Error define##name##Node( \ + xnn_subgraph_t subgraph_ptr, \ + const std::unordered_map& remapped_ids, \ + const NodePtr node, \ + const fb_xnnpack::XNNGraph* graph) noexcept { \ + MAYBE_UNUSED(graph); \ + auto graph_node = node->xnode_union_as_XNN##name(); \ + std::pair min_max = getOutputMinMax(node); \ + union xnn_unary_params params = { \ + .clamp = {.min = min_max.first, .max = min_max.second}}; \ + return defineGenericUnaryNode( \ + subgraph_ptr, \ + remapped_ids, \ + graph_node->input_id(), \ + graph_node->output_id(), \ + graph_node->flags(), \ + op_type, \ + ¶ms, \ + node->xnode_union_type(), \ + node->debug_handle()); \ + } + +// Macro for unary operations with leaky_relu parameters +#define _DEFINE_UNARY_NODE_WITH_LEAKY_RELU(name) \ + Error define##name##Node( \ + xnn_subgraph_t subgraph_ptr, \ + const std::unordered_map& remapped_ids, \ + const NodePtr node, \ + const fb_xnnpack::XNNGraph* graph) noexcept { \ + MAYBE_UNUSED(graph); \ + auto graph_node = node->xnode_union_as_XNNLeakyReLU(); \ + union xnn_unary_params params = { \ + .leaky_relu = {.negative_slope = graph_node->negative_slope()}}; \ + return defineGenericUnaryNode( \ + subgraph_ptr, \ + remapped_ids, \ + graph_node->input_id(), \ + graph_node->output_id(), \ + graph_node->flags(), \ + xnn_unary_leaky_relu, \ + ¶ms, \ + node->xnode_union_type(), \ + node->debug_handle()); \ + } + +// Macro for unary operations with elu parameters +#define _DEFINE_UNARY_NODE_WITH_ELU(name) \ + Error define##name##Node( \ + xnn_subgraph_t subgraph_ptr, \ + const std::unordered_map& remapped_ids, \ + const NodePtr node, \ + const fb_xnnpack::XNNGraph* graph) noexcept { \ + MAYBE_UNUSED(graph); \ + auto graph_node = node->xnode_union_as_XNNELU(); \ + union xnn_unary_params params = {.elu = {.alpha = graph_node->alpha()}}; \ + return defineGenericUnaryNode( \ + subgraph_ptr, \ + remapped_ids, \ + graph_node->input_id(), \ + graph_node->output_id(), \ + graph_node->flags(), \ + xnn_unary_elu, \ + ¶ms, \ + node->xnode_union_type(), \ + node->debug_handle()); \ + } + +// Generic helper function for binary operations +Error defineGenericBinaryNode( + xnn_subgraph_t subgraph_ptr, + const std::unordered_map& remapped_ids, + const fb_xnnpack::_XNNNode2x1* graph_node, + xnn_binary_operator op_type, + const struct xnn_binary_params* params, + fb_xnnpack::XNodeUnion node_type, + uint32_t debug_handle) noexcept { + xnn_status status = xnn_define_binary( + subgraph_ptr, + op_type, + params, + remapped_ids.at(graph_node->input1_id()), + remapped_ids.at(graph_node->input2_id()), + remapped_ids.at(graph_node->output_id()), + graph_node->flags()); + + ET_CHECK_OR_RETURN_ERROR( + status == xnn_status_success, + Internal, + "Failed to create %s node %i with code: %s", + fb_xnnpack::EnumNameXNodeUnion(node_type), + debug_handle, + xnn_status_to_string(status)); + + return Error::Ok; +} + +// Macro for binary operations with min/max parameters +#define _DEFINE_BINARY_NODE_WITH_MINMAX(name, op_type) \ + Error define##name##Node( \ + xnn_subgraph_t subgraph_ptr, \ + const std::unordered_map& remapped_ids, \ + const NodePtr node, \ + const fb_xnnpack::XNNGraph* graph) noexcept { \ + MAYBE_UNUSED(graph); \ + auto graph_node = node->xnode_union_as_XNN##name(); \ + std::pair min_max = getOutputMinMax(node); \ + struct xnn_binary_params params = { \ + .output_min = min_max.first, .output_max = min_max.second}; \ + return defineGenericBinaryNode( \ + subgraph_ptr, \ + remapped_ids, \ + graph_node, \ + op_type, \ + ¶ms, \ + node->xnode_union_type(), \ + node->debug_handle()); \ + } + +// Macro for binary operations without parameters +#define _DEFINE_BINARY_NODE_NO_PARAMS(name, op_type) \ + Error define##name##Node( \ + xnn_subgraph_t subgraph_ptr, \ + const std::unordered_map& remapped_ids, \ + const NodePtr node, \ + const fb_xnnpack::XNNGraph* graph) noexcept { \ + MAYBE_UNUSED(graph); \ + auto graph_node = node->xnode_union_as_XNN##name(); \ + return defineGenericBinaryNode( \ + subgraph_ptr, \ + remapped_ids, \ + graph_node, \ + op_type, \ + nullptr, \ + node->xnode_union_type(), \ + node->debug_handle()); \ + } + /* Returns the pointer to the defineNode function that handles the given XNode type @@ -2055,43 +1647,80 @@ XNode type case fb_xnnpack::XNodeUnion::XNN##name: \ return &define##name##Node; +// Unary Ops with no params +_DEFINE_UNARY_NODE_NO_PARAMS(Sigmoid, xnn_unary_sigmoid) +_DEFINE_UNARY_NODE_NO_PARAMS(Floor, xnn_unary_floor) +_DEFINE_UNARY_NODE_NO_PARAMS(SquareRoot, xnn_unary_square_root) +_DEFINE_UNARY_NODE_NO_PARAMS( + ReciprocalSquareRoot, + xnn_unary_reciprocal_square_root) +_DEFINE_UNARY_NODE_NO_PARAMS(Ceiling, xnn_unary_ceiling) +_DEFINE_UNARY_NODE_NO_PARAMS(Gelu, xnn_unary_gelu) +_DEFINE_UNARY_NODE_NO_PARAMS(Hardswish, xnn_unary_hardswish) +_DEFINE_UNARY_NODE_NO_PARAMS(Log, xnn_unary_log) +_DEFINE_UNARY_NODE_NO_PARAMS(Negate, xnn_unary_negate) +_DEFINE_UNARY_NODE_NO_PARAMS(Square, xnn_unary_square) +_DEFINE_UNARY_NODE_NO_PARAMS(Abs, xnn_unary_abs) + +// Unary Ops with min/max params +_DEFINE_UNARY_NODE_WITH_MINMAX(Clamp, xnn_unary_clamp) + +// Unary Ops with specific params +_DEFINE_UNARY_NODE_WITH_LEAKY_RELU(LeakyReLU) +_DEFINE_UNARY_NODE_WITH_ELU(ELU) + +// Binary Ops with params +_DEFINE_BINARY_NODE_WITH_MINMAX(Add, xnn_binary_add) +_DEFINE_BINARY_NODE_WITH_MINMAX(Subtract, xnn_binary_subtract) +_DEFINE_BINARY_NODE_WITH_MINMAX(Multiply, xnn_binary_multiply) +_DEFINE_BINARY_NODE_WITH_MINMAX(Div, xnn_binary_divide) + +// Binary Ops without params +_DEFINE_BINARY_NODE_NO_PARAMS(Minimum, xnn_binary_minimum) +_DEFINE_BINARY_NODE_NO_PARAMS(Maximum, xnn_binary_maximum) + DefineNodeFunc getDefineNodeFunc(fb_xnnpack::XNodeUnion nodeType) { switch (nodeType) { + // Binary ops _DEFINE(Add) - _DEFINE(FullyConnected) + _DEFINE(Subtract) + _DEFINE(Multiply) + _DEFINE(Div) + _DEFINE(Minimum) + _DEFINE(Maximum) + + // Unary ops _DEFINE(Softmax) + _DEFINE(SquareRoot) + _DEFINE(ReciprocalSquareRoot) + _DEFINE(Ceiling) + _DEFINE(Gelu) + _DEFINE(Hardswish) + _DEFINE(Log) + _DEFINE(Negate) + _DEFINE(Square) + _DEFINE(Clamp) + _DEFINE(LeakyReLU) + _DEFINE(ELU) + _DEFINE(Abs) + _DEFINE(Floor) + _DEFINE(PReLU) _DEFINE(Sigmoid) + + // Others + _DEFINE(FullyConnected) _DEFINE(StaticTranspose) - _DEFINE(Clamp) _DEFINE(Conv2d) _DEFINE(ConvTranspose2d) - _DEFINE(Div) _DEFINE(StaticResizeBilinear2D) _DEFINE(StaticConstantPad) _DEFINE(AvgPooling2d) - _DEFINE(Minimum) _DEFINE(DepthwiseConv2d) _DEFINE(MaxPooling2d) - _DEFINE(Multiply) - _DEFINE(Subtract) - _DEFINE(Floor) _DEFINE(Convert) _DEFINE(GlobalAvgPooling2d) _DEFINE(StaticReshape) _DEFINE(ArgMaxPooling2d) - _DEFINE(SquareRoot) - _DEFINE(ReciprocalSquareRoot) - _DEFINE(Ceiling) - _DEFINE(Gelu) - _DEFINE(Hardswish) - _DEFINE(LeakyReLU) - _DEFINE(Log) - _DEFINE(Maximum) - _DEFINE(Negate) - _DEFINE(Square) - _DEFINE(ELU) - _DEFINE(Abs) - _DEFINE(PReLU) _DEFINE(Concatenate2) _DEFINE(Concatenate3) _DEFINE(Concatenate4)