Skip to content

Commit 3426f64

Browse files
Support activation broadcasting in XNNPACK Matmul (microsoft#24908)
### Description 1. Support activation broadcasting in XNNPACK Matmul. 2. Fix a subtle bug when the activation input is 1-D. Per the existing gating logic, 1-D activations were allowed, but the batch size being passed through did not account for this. The batch size passed in was always `a->Shape()[0]`, which for a 1-D input is actually the reduction dimension (K). This is incorrect because, for a 1-D activation input, a `1` is prepended to the shape, meaning we should actually have passed in `1` for the batch. This passed the relevant test, but it likely would have written outside the bounds of the output buffer because of the non-unit batch being passed through. ### Motivation and Context Resolve microsoft#24107 --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 36fc8c8 commit 3426f64

File tree

2 files changed

+29
-5
lines changed

2 files changed

+29
-5
lines changed

onnxruntime/core/providers/xnnpack/math/matmul.cc

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,8 @@ bool MatMul::IsOnnxNodeSupported(const NodeUnit& node_unit, const GraphViewer& g
4141
break;
4242
}
4343

44-
if (A_shape == nullptr || A_shape->dim_size() > 2 ||
45-
(A_shape->dim_size() == 2 && A_shape->dim(1).dim_value() == 0) ||
46-
A_shape->dim(0).dim_value() == 0) {
44+
// A must at-least be 1-D
45+
if (A_shape == nullptr || A_shape->dim_size() < 1) {
4746
break;
4847
}
4948

@@ -162,10 +161,28 @@ Status MatMul::Compute(OpKernelContext* ctx) const {
162161
xnn_status status = xnn_status_success;
163162

164163
pthreadpool_t threadpool = GetThreadPool();
164+
165+
// If the input 'A' is 1-D, then it is prepended with 1 and hence batch will be 1
166+
size_t batch = 1;
167+
168+
const auto& a_dims = a->Shape();
169+
int64_t rank = a_dims.NumDimensions();
170+
171+
if (rank == 2) {
172+
batch = a_dims[0];
173+
} else if (rank > 2) {
174+
// Input 'A' is N-dimensional, the batch is made up of the product of the outermost dims
175+
// (excluding the actual inner reduction dim)
176+
177+
for (int64_t i = 0; i < rank - 1; ++i) {
178+
batch *= a_dims[i];
179+
}
180+
}
181+
165182
if (op_type_ == OpComputeType::op_compute_type_fp32) {
166-
status = xnn_reshape_fully_connected_nc_f32(op0_.get(), a->Shape()[0], threadpool);
183+
status = xnn_reshape_fully_connected_nc_f32(op0_.get(), batch, threadpool);
167184
} else if (op_type_ == OpComputeType::op_compute_type_fp16) {
168-
status = xnn_reshape_fully_connected_nc_f16(op0_.get(), a->Shape()[0], threadpool);
185+
status = xnn_reshape_fully_connected_nc_f16(op0_.get(), batch, threadpool);
169186
}
170187

171188
if (status != xnn_status_success) {

onnxruntime/test/providers/cpu/math/matmul_test.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,13 @@ std::vector<MatMulTestData<T>> GenerateTestCases() {
6161
{3, 2, 1, 2},
6262
real_expected_vals({2, 3, 6, 7, 6, 11, 26, 31, 10, 19, 46, 55})});
6363

64+
test_cases.push_back(
65+
{"test padding and broadcast A > B - no broadcast in B",
66+
{2, 2, 3, 2},
67+
{2, 1},
68+
{2, 2, 3, 1},
69+
real_expected_vals({1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23})});
70+
6471
test_cases.push_back(
6572
{"test padding and broadcast B > A",
6673
{2, 3, 2},

0 commit comments

Comments
 (0)