@@ -1504,6 +1504,35 @@ Error defineScaledDotProductAttentionNode(
15041504
15051505 return Error::Ok;
15061506}
1507+
1508+ /*
1509+ Defines batch matrix multiply node into the subgraph,
1510+ using the remapped ids to map the serialized ids,
1511+ to the new ids generated when defining the tensor value
1512+ */
1513+ Error defineBatchMatrixMultiplyNode (
1514+ xnn_subgraph_t subgraph_ptr,
1515+ const std::unordered_map<uint32_t , uint32_t >& remapped_ids,
1516+ const NodePtr node) noexcept {
1517+ auto graph_node = node->xnode_union_as_XNNBatchMatrixMultiply ();
1518+
1519+ xnn_status status = xnn_define_batch_matrix_multiply (
1520+ subgraph_ptr,
1521+ remapped_ids.at (graph_node->input1_id ()),
1522+ remapped_ids.at (graph_node->input2_id ()),
1523+ remapped_ids.at (graph_node->output_id ()),
1524+ graph_node->flags ());
1525+
1526+ ET_CHECK_OR_RETURN_ERROR (
1527+ status == xnn_status_success,
1528+ Internal,
1529+ " Failed to create BMM node %i with code: %s" ,
1530+ node->debug_handle (),
1531+ xnn_status_to_string (status));
1532+
1533+ return Error::Ok;
1534+ }
1535+
15071536/*
15081537Returns not Implemented Error code. This function is meant to be
15091538called when the compiler encountes a XNodeType from the flatbuffer
@@ -1566,6 +1595,7 @@ DefineNodeFunc getDefineNodeFunc(fb_xnnpack::XNodeUnion nodeType) {
15661595 _DEFINE (Concatenate4)
15671596 _DEFINE (StaticSlice)
15681597 _DEFINE (ScaledDotProductAttention)
1598+ _DEFINE (BatchMatrixMultiply)
15691599 case fb_xnnpack::XNodeUnion::NONE:
15701600 default : // Adding here as a catch all, just in case
15711601 return &defineNotImplementedNode;
0 commit comments