Commit a70d070 (parent: 593da70)

add fp32 bmm op

Differential Revision: D60153721
Pull Request resolved: #4604

9 files changed: +153 -0
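For context, the op being lowered, torch.bmm, is a batched matrix multiply with no broadcasting: both inputs must be 3-D, with equal batch sizes and matching inner dimensions. A quick shape sketch:

import torch

x = torch.randn(2, 3, 4)  # (batch, n, m)
y = torch.randn(2, 4, 6)  # (batch, m, p)
out = torch.bmm(x, y)     # (batch, n, p)
assert out.shape == (2, 3, 6)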

backends/xnnpack/operators/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -10,6 +10,7 @@
     op_add,
     op_addmm,
     op_avg_pooling2d,
+    op_bmm,
     op_cat,
     op_ceiling,
     op_clamp,
backends/xnnpack/operators/op_bmm.py (new file)

Lines changed: 54 additions & 0 deletions

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict

import torch
from executorch.backends.xnnpack.operators.node_visitor import (
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
    XNNBatchMatrixMultiply,
    XNNGraph,
    XNode,
)
from executorch.backends.xnnpack.utils.utils import get_input_node


@register_node_visitor
class BMMVisitor(NodeVisitor):
    target = "aten.bmm.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        xnn_graph: XNNGraph,
        vals_to_ids: Dict[torch.fx.Node, int],
        debug_handle: int,
    ) -> None:
        self.define_nodes_tensor_inputs_outputs(node, xnn_graph, vals_to_ids)

        # input1
        input1_id = vals_to_ids[get_input_node(node, 0)]

        # input2
        input2_id = vals_to_ids[get_input_node(node, 1)]

        # output
        output_id = vals_to_ids[node]

        ser_node = XNode(
            xnode_union=XNNBatchMatrixMultiply(
                input1_id=input1_id, input2_id=input2_id, output_id=output_id, flags=0
            ),
            debug_handle=debug_handle,
        )
        xnn_graph.xnodes.append(ser_node)
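For readers unfamiliar with the visitor plumbing: get_input_node(node, i) is, to a first approximation, a lookup of the i-th FX argument, so the two ids fetched above correspond to the bmm operands in order. A rough equivalent of the helper (an assumption for illustration, not its actual source):

def get_input_node(node: torch.fx.Node, input_index: int) -> torch.fx.Node:
    # hypothetical simplification: the i-th positional arg of the FX node
    return node.args[input_index]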

backends/xnnpack/partition/config/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -17,6 +17,7 @@
     AbsConfig,
     AddConfig,
     AvgPoolingConfig,
+    BMMConfig,
     CatConfig,
     CeilConfig,
     ClampConfig,
@@ -60,6 +61,7 @@
     AddmmConfig,
     AvgPoolingConfig,
     BatchNormConfig,
+    BMMConfig,
     CatConfig,
     CeilConfig,
     ConstantPadConfig,

backends/xnnpack/partition/config/generic_node_configs.py

Lines changed: 12 additions & 0 deletions

@@ -403,3 +403,15 @@ class SubConfig(GenericNodePartitionerConfig):

     def supported_precision_types(self) -> List[ConfigPrecisionType]:
         return [ConfigPrecisionType.FP32, ConfigPrecisionType.STATIC_QUANT]
+
+
+class BMMConfig(GenericNodePartitionerConfig):
+    """
+    Despite being a GEMM kernel, BMM can be partitioned like a single-node
+    partitioner because it does not perform any packing on the inputs being
+    matrix multiplied.
+    """
+
+    target_name = "bmm.default"
+
+    def supported_precision_types(self) -> List[ConfigPrecisionType]:
+        return [ConfigPrecisionType.FP32]
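With BMMConfig registered, a module containing only torch.bmm can be lowered end to end. A minimal sketch of that flow outside the test harness, assuming the standard ExecuTorch export entry points (the module and shapes here are illustrative):

import torch
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge_transform_and_lower


class BMM(torch.nn.Module):
    def forward(self, x, y):
        return torch.bmm(x, y)


exported = torch.export.export(BMM(), (torch.randn(2, 3, 4), torch.randn(2, 4, 6)))
# BMMConfig claims fp32 bmm nodes, so the whole graph should be delegated to XNNPACK.
edge = to_edge_transform_and_lower(exported, partitioner=[XnnpackPartitioner()])
program = edge.to_executorch()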

backends/xnnpack/runtime/XNNCompiler.cpp

Lines changed: 30 additions & 0 deletions

@@ -1504,6 +1504,35 @@ Error defineScaledDotProductAttentionNode(

   return Error::Ok;
 }
+
+/*
+Defines a batch matrix multiply node in the subgraph,
+using the remapped ids to map the serialized ids
+to the new ids generated when defining the tensor values.
+*/
+Error defineBatchMatrixMultiplyNode(
+    xnn_subgraph_t subgraph_ptr,
+    const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
+    const NodePtr node) noexcept {
+  auto graph_node = node->xnode_union_as_XNNBatchMatrixMultiply();
+
+  xnn_status status = xnn_define_batch_matrix_multiply(
+      subgraph_ptr,
+      remapped_ids.at(graph_node->input1_id()),
+      remapped_ids.at(graph_node->input2_id()),
+      remapped_ids.at(graph_node->output_id()),
+      graph_node->flags());
+
+  ET_CHECK_OR_RETURN_ERROR(
+      status == xnn_status_success,
+      Internal,
+      "Failed to create BMM node %i with code: %s",
+      node->debug_handle(),
+      xnn_status_to_string(status));
+
+  return Error::Ok;
+}
+
 /*
 Returns a Not Implemented error code. This function is meant to be
 called when the compiler encounters an XNodeType from the flatbuffer
@@ -1566,6 +1595,7 @@ DefineNodeFunc getDefineNodeFunc(fb_xnnpack::XNodeUnion nodeType) {
     _DEFINE(Concatenate4)
     _DEFINE(StaticSlice)
     _DEFINE(ScaledDotProductAttention)
+    _DEFINE(BatchMatrixMultiply)
     case fb_xnnpack::XNodeUnion::NONE:
     default: // Adding here as a catch all, just in case
       return &defineNotImplementedNode;

backends/xnnpack/serialization/runtime_schema.fbs

Lines changed: 1 addition & 0 deletions

@@ -134,6 +134,7 @@ union XNodeUnion {
   XNNConcatenate4: _XNNCat,
   XNNStaticSlice,
   XNNScaledDotProductAttention,
+  XNNBatchMatrixMultiply: _XNNNode2x1,
 }

 union XValueUnion {

backends/xnnpack/serialization/schema.fbs

Lines changed: 1 addition & 0 deletions

@@ -130,6 +130,7 @@ union XNodeUnion {
   XNNConcatenate4: _XNNCat,
   XNNStaticSlice,
   XNNScaledDotProductAttention,
+  XNNBatchMatrixMultiply: _XNNNode2x1,
 }

 union XValueUnion {

backends/xnnpack/serialization/xnnpack_graph_schema.py

Lines changed: 6 additions & 0 deletions

@@ -177,6 +177,11 @@ class XNNConcatenate4(XNNCat):
     pass


+@dataclass
+class XNNBatchMatrixMultiply(XNNNode2x1):
+    pass
+
+
 @dataclass
 class XNNStaticTranspose:
     num_dims: int
@@ -354,6 +359,7 @@ class XNNScaledDotProductAttention:
     XNNConcatenate4,
     XNNStaticSlice,
     XNNScaledDotProductAttention,
+    XNNBatchMatrixMultiply,
 ]
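XNNBatchMatrixMultiply adds no fields of its own; everything comes from XNNNode2x1. Judging from how BMMVisitor constructs it above (two input ids, one output id, and flags=0), the base dataclass presumably looks along these lines (a sketch inferred from the call site, not copied from the schema file):

from dataclasses import dataclass


@dataclass
class XNNNode2x1:
    input1_id: int
    input2_id: int
    output_id: int
    flags: int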
backends/xnnpack/test/ops/bmm.py

Lines changed: 46 additions & 0 deletions (new file)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch.backends.xnnpack.test.tester import Tester


class TestBMM(unittest.TestCase):
    class BMM(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x, y):
            return torch.bmm(x, y)

    def _test_bmm(self, inputs):
        (
            Tester(self.BMM(), inputs)
            .export()
            .check_count({"torch.ops.aten.bmm.default": 1})
            .to_edge_transform_and_lower()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .check_not(["executorch_exir_dialects_edge__ops_aten_bmm_default"])
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs()
        )

    def test_fp16_bmm(self):
        inputs = (
            torch.randn(2, 3, 4).to(torch.float16),
            torch.randn(2, 4, 6).to(torch.float16),
        )
        self._test_bmm(inputs)

    def test_fp32_bmm(self):
        inputs = (
            torch.randn(2, 3, 4),
            torch.randn(2, 4, 6),
        )
        self._test_bmm(inputs)
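Assuming the repository root is importable, these tests run under the standard unittest entry point:

python -m unittest backends.xnnpack.test.ops.bmm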
