Skip to content

Commit eba69df

Browse files
committed
feat: MatMul kernel for Dense <- Dense, Sparse
1 parent b950ee2 commit eba69df

File tree

4 files changed

+85
-1
lines changed

4 files changed

+85
-1
lines changed

src/runtime/local/kernels/MatMul.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,8 @@ void MatMul<DenseMatrix<VT>, DenseMatrix<VT>, DenseMatrix<VT>>::apply(DenseMatri
254254
const auto nr2 = static_cast<int>(transb ? rhs->getNumCols() : rhs->getNumRows());
255255
const auto nc2 = static_cast<int>(transb ? rhs->getNumRows() : rhs->getNumCols());
256256
if (nc1 != nr2) {
257-
throw std::runtime_error("MatMul - #cols of lhs and #rows of rhs must be the same");
257+
throw std::runtime_error("MatMul - #cols of lhs and #rows of rhs must be the same, but got" +
258+
std::to_string(nc1) + "and" + std::to_string(nr2));
258259
}
259260
const VT alpha = 1.0f;
260261
const VT beta = 0.0f;

src/runtime/local/kernels/MatMul.h

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,56 @@ template <typename VT> struct MatMul<DenseMatrix<VT>, CSRMatrix<VT>, DenseMatrix
9393
}
9494
};
9595

96+
// ----------------------------------------------------------------------------
97+
// DenseMatrix <- DenseMatrix, CSRMatrix
98+
// ----------------------------------------------------------------------------
99+
100+
template <typename VT> struct MatMul<DenseMatrix<VT>, DenseMatrix<VT>, CSRMatrix<VT>> {
101+
static void apply(DenseMatrix<VT> *&res, const DenseMatrix<VT> *lhs, const CSRMatrix<VT> *rhs, bool transa,
102+
bool transb, DCTX(ctx)) {
103+
const size_t nr1 = lhs->getNumRows();
104+
const size_t nc1 = lhs->getNumCols();
105+
const size_t nr2 = rhs->getNumRows();
106+
const size_t nc2 = rhs->getNumCols();
107+
108+
if (nc1 != nr2) {
109+
throw std::runtime_error("MatMul - #cols of lhs and #rows of rhs must be the same");
110+
}
111+
// FIXME: transpose isn't supported atm
112+
113+
if (res == nullptr)
114+
res = DataObjectFactory::create<DenseMatrix<VT>>(nr1, nc2, /*zero=*/true);
115+
116+
const VT *valuesLhs = lhs->getValues();
117+
VT *valuesRes = res->getValues();
118+
119+
const size_t rowSkipLhs = lhs->getRowSkip();
120+
const size_t rowSkipRes = res->getRowSkip();
121+
122+
// For each row m of lhs
123+
for (size_t m = 0; m < nr1; m++) {
124+
const size_t rowIdxLhs = m * rowSkipLhs;
125+
const size_t rowIdxRes = m * rowSkipRes;
126+
127+
// For each row n of rhs
128+
for (size_t n = 0; n < nr2; n++) {
129+
const VT lhsVal = valuesLhs[rowIdxLhs + n];
130+
131+
// Get non-zeros in row k of rhs
132+
const size_t rowNumNonZeros = rhs->getNumNonZeros(n);
133+
const size_t *rowColIdxs = rhs->getColIdxs(n);
134+
const VT *rowValues = rhs->getValues(n);
135+
136+
// For each non-zero in row k of rhs
137+
for (size_t i = 0; i < rowNumNonZeros; i++) {
138+
const size_t c = rowColIdxs[i];
139+
valuesRes[rowIdxRes + c] += lhsVal * rowValues[i];
140+
}
141+
}
142+
}
143+
}
144+
};
145+
96146
// ----------------------------------------------------------------------------
97147
// Matrix <- Matrix, Matrix
98148
// ----------------------------------------------------------------------------

src/runtime/local/kernels/kernels.json

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3695,6 +3695,11 @@
36953695
["CSRMatrix", "double"],
36963696
["DenseMatrix", "double"]
36973697
],
3698+
[
3699+
["DenseMatrix", "double"],
3700+
["DenseMatrix", "double"],
3701+
["CSRMatrix", "double"]
3702+
],
36983703
[
36993704
["CSRMatrix", "double"],
37003705
["CSRMatrix", "double"],

test/runtime/local/kernels/MatMulTest.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,15 @@ void checkMatMul(const DT *lhs, const DT *rhs, const DT *exp, DCTX(dctx), bool t
4141
DataObjectFactory::destroy(res);
4242
}
4343

44+
template <class DTRes, class DTLhs, class DTRhs>
45+
void checkMatMulMixed(const DTLhs *lhs, const DTRhs *rhs, const DTRes *exp, DCTX(dctx), bool transa = false,
46+
bool transb = false) {
47+
DTRes *res = nullptr;
48+
matMul<DTRes, DTLhs, DTRhs>(res, lhs, rhs, transa, transb, dctx);
49+
CHECK(*res == *exp);
50+
DataObjectFactory::destroy(res);
51+
}
52+
4453
TEMPLATE_PRODUCT_TEST_CASE("MatMul", TAG_KERNELS, (CSRMatrix, DATA_TYPES), (VALUE_TYPES)) {
4554
auto dctx = setupContextAndLogger();
4655

@@ -143,6 +152,25 @@ TEMPLATE_PRODUCT_TEST_CASE("MatMul", TAG_KERNELS, (CSRMatrix, DATA_TYPES), (VALU
143152
DataObjectFactory::destroy(m0, m1, m2, m3, m4, m5, m6, v0, v1, v2, v3, v4, v5, v6, v7, v8);
144153
}
145154

155+
TEMPLATE_TEST_CASE("MatMul Dense x Sparse", TAG_KERNELS, VALUE_TYPES) {
156+
using VT = TestType;
157+
using LhsDT = DenseMatrix<VT>;
158+
using RhsDT = CSRMatrix<VT>;
159+
using ResDT = DenseMatrix<VT>;
160+
161+
auto dctx = setupContextAndLogger();
162+
163+
// clang-format off
164+
auto lhs = genGivenVals<LhsDT>(3, { 1, 0, 2, 3, 4, 0, });
165+
auto rhs = genGivenVals<RhsDT>(2, { 0, 5, 0, 0, 6, 0, });
166+
auto exp = genGivenVals<ResDT>(3, { 0, 5, 0, 0, 28, 0, 0, 20, 0, });
167+
// clang-format on
168+
169+
checkMatMulMixed<ResDT, LhsDT, RhsDT>(lhs, rhs, exp, dctx.get());
170+
171+
DataObjectFactory::destroy(lhs, rhs, exp);
172+
}
173+
146174
TEMPLATE_PRODUCT_TEST_CASE("MatMul Transposed", TAG_KERNELS, (DATA_TYPES), (VALUE_TYPES)) {
147175
using DT = TestType;
148176
auto dctx = setupContextAndLogger();

0 commit comments

Comments
 (0)