Skip to content

Commit a14e98e

Browse files
committed
Add support for multiple multiplication strategies
1 parent ed60e72 commit a14e98e

File tree

4 files changed

+90
-51
lines changed

4 files changed

+90
-51
lines changed

include/gemmi.hpp

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,16 +24,21 @@ enum class normalisationDimension {
2424
byCols // Matrix on the right of the product.
2525
};
2626

27-
enum class splittingStrategy{
27+
enum class splittingStrategy {
2828
roundToNearest,
2929
bitMasking
3030
};
3131

32-
enum class accumulationStrategy{
32+
enum class accumulationStrategy {
3333
floatingPoint,
3434
integer
3535
};
3636

37+
enum class multiplicationStrategy {
38+
full,
39+
reduced
40+
};
41+
3742
template <typename splitint_t, typename fp_t>
3843
struct MatrixSplit {
3944
size_t m;
@@ -263,14 +268,13 @@ fp_t computeScalingConstantforUsingSplittingStrategy(const MatrixSplit<splitint_
263268
template <typename splitint_t, typename accumulator_t, typename fp_t>
264269
std::vector<fp_t> computeProductsWithFloatingPointAccumulation(const MatrixSplit<splitint_t, fp_t> &A,
265270
const MatrixSplit<splitint_t, fp_t> &B,
266-
const size_t bitsPerSlice) {
271+
const size_t bitsPerSlice,
272+
const size_t numDiagonals) {
267273

268274
std::vector<fp_t > C (A.m * B.n);
269275

270276
auto scalingConstant = computeScalingConstantforUsingSplittingStrategy(A, B);
271277

272-
// Products below the main anti-diagonal are ignored.
273-
size_t numDiagonals = std::max(A.numSplits, B.numSplits) - 1;
274278
for (size_t diagonal = 0; diagonal <= numDiagonals; diagonal++) {
275279
int Aindex = diagonal < A.numSplits - 1 ? diagonal : A.numSplits - 1;
276280
size_t Bindex = diagonal > A.numSplits - 1 ? diagonal - A.numSplits + 1 : 0;
@@ -304,14 +308,13 @@ std::vector<fp_t> computeProductsWithFloatingPointAccumulation(const MatrixSplit
304308
template <typename splitint_t, typename accumulator_t, typename fp_t>
305309
std::vector<fp_t> computeProductsWithIntegerAccumulation(const MatrixSplit<splitint_t, fp_t> &A,
306310
const MatrixSplit<splitint_t, fp_t> &B,
307-
const size_t bitsPerSlice) {
311+
const size_t bitsPerSlice,
312+
const size_t numDiagonals) {
308313

309314
std::vector<fp_t > C (A.m * B.n);
310315

311316
auto scalingConstant = computeScalingConstantforUsingSplittingStrategy(A, B);
312317

313-
// Products below the main anti-diagonal are ignored.
314-
size_t numDiagonals = std::max(A.numSplits, B.numSplits) - 1;
315318
for (size_t diagonal = 0; diagonal <= numDiagonals; diagonal++) {
316319
int Aindex = diagonal < A.numSplits - 1 ? diagonal : A.numSplits - 1;
317320
size_t Bindex = diagonal > A.numSplits - 1 ? diagonal - A.numSplits + 1 : 0;
@@ -343,6 +346,7 @@ std::vector<fp_t> gemmi (const std::vector<fp_t> &A, const std::vector<fp_t> &B,
343346
const size_t m, const size_t p, const size_t n,
344347
const size_t numSplitsA, const size_t numSplitsB,
345348
const splittingStrategy splitType = splittingStrategy::roundToNearest,
349+
const multiplicationStrategy multType = multiplicationStrategy::reduced,
346350
const accumulationStrategy accType = accumulationStrategy::floatingPoint) {
347351

348352
const size_t bitsInAccumulator = std::numeric_limits<accumulator_t>::digits;
@@ -354,11 +358,26 @@ std::vector<fp_t> gemmi (const std::vector<fp_t> &A, const std::vector<fp_t> &B,
354358
auto splitA = MatrixSplit<splitint_t, fp_t>(A, m, p, splitType, numSplitsA, bitsPerSlice, normalisationDimension::byRows);
355359
auto splitB = MatrixSplit<splitint_t, fp_t>(B, p, n, splitType, numSplitsB, bitsPerSlice, normalisationDimension::byCols);
356360

361+
size_t numDiagonals;
362+
switch (multType) {
363+
case multiplicationStrategy::reduced:
364+
// Products below the main anti-diagonal are ignored.
365+
numDiagonals = std::max(splitA.numSplits, splitB.numSplits) - 1;
366+
break;
367+
case multiplicationStrategy::full:
368+
// All products are computed.
369+
numDiagonals = splitA.numSplits + splitB.numSplits - 1;
370+
break;
371+
default:
372+
std::cerr << "Unknown multiplication strategy requested.";
373+
exit(1);
374+
}
375+
357376
switch (accType) {
358377
case accumulationStrategy::floatingPoint:
359-
return computeProductsWithFloatingPointAccumulation<splitint_t, accumulator_t, fp_t>(splitA, splitB, bitsPerSlice);
378+
return computeProductsWithFloatingPointAccumulation<splitint_t, accumulator_t, fp_t>(splitA, splitB, bitsPerSlice, numDiagonals);
360379
case accumulationStrategy::integer:
361-
return computeProductsWithIntegerAccumulation<splitint_t, accumulator_t, fp_t>(splitA, splitB, bitsPerSlice);
380+
return computeProductsWithIntegerAccumulation<splitint_t, accumulator_t, fp_t>(splitA, splitB, bitsPerSlice, numDiagonals);
362381
default:
363382
std::cerr << "Unknown accumulation strategy requested.";
364383
exit(1);

mex/gemmi.cpp

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
typedef struct {
66
splittingStrategy splitType;
7+
multiplicationStrategy multType;
78
accumulationStrategy accType;
89
} algorithmOptions;
910
static std::unique_ptr<algorithmOptions> options = nullptr;
@@ -45,6 +46,7 @@ class MexFunction : public matlab::mex::Function {
4546
matlab::data::ArrayFactory factory;
4647
matlab::data::StructArray S = factory.createStructArray({1, 1}, {"split", "acc"});
4748
S[0]["split"] = factory.createCharArray(options->splitType == splittingStrategy::roundToNearest ? "n" : "b");
49+
S[0]["mult"] = factory.createCharArray(options->multType == multiplicationStrategy::full ? "f" : "r");
4850
S[0]["acc"] = factory.createCharArray(options->accType == accumulationStrategy::floatingPoint ? "f" : "i");
4951
outputs[1] = std::move(S);
5052
}
@@ -59,8 +61,8 @@ class MexFunction : public matlab::mex::Function {
5961
auto A_size = Amatlab.getDimensions();
6062
auto B_size = Bmatlab.getDimensions();
6163

62-
auto C = gemmi<double, int8_t, int32_t>(A, B, A_size[0], A_size[1], B_size[1],
63-
numSplitsA, numSplitsB, options->splitType, options->accType);
64+
auto C = gemmi<double, int8_t, int32_t>(A, B, A_size[0], A_size[1], B_size[1], numSplitsA, numSplitsB,
65+
options->splitType, options->multType, options->accType);
6466

6567
matlab::data::ArrayFactory factory;
6668
return factory.createArray({A_size[0], B_size[1]}, C.begin(), C.end());;
@@ -74,8 +76,8 @@ class MexFunction : public matlab::mex::Function {
7476
auto A_size = Amatlab.getDimensions();
7577
auto B_size = Bmatlab.getDimensions();
7678

77-
auto C = gemmi<float, int8_t, int32_t>(A, B, A_size[0], A_size[1], B_size[1],
78-
numSplitsA, numSplitsB, options->splitType, options->accType);
79+
auto C = gemmi<float, int8_t, int32_t>(A, B, A_size[0], A_size[1], B_size[1], numSplitsA, numSplitsB,
80+
options->splitType, options->multType, options->accType);
7981

8082
matlab::data::ArrayFactory factory;
8183
return factory.createArray({A_size[0], B_size[1]}, C.begin(), C.end());;
@@ -145,22 +147,22 @@ class MexFunction : public matlab::mex::Function {
145147
0, std::vector<matlab::data::Array>({ factory.createScalar("The fifth input must be a struct.") }));
146148
}
147149
matlab::data::StructArray inStruct(inputs[4]);
148-
if (inStruct.getNumberOfFields() > 2) {
150+
if (inStruct.getNumberOfFields() > 3) {
149151
matlabPtr->feval(u"error",
150-
0, std::vector<matlab::data::Array>({ factory.createScalar("The fifth input must have at most two fields.") }));
152+
0, std::vector<matlab::data::Array>({ factory.createScalar("The fifth input must have at most three fields.") }));
151153
}
152154
auto fields = inStruct.getFieldNames();
153155
std::vector<matlab::data::MATLABFieldIdentifier> fieldNames(fields.begin(), fields.end());
154156
for (auto field : fieldNames) {
155-
if (std::string(field) != "split" && std::string(field) != "acc") {
157+
if (std::string(field) != "split" && std::string(field) != "mult" && std::string(field) != "acc") {
156158
matlabPtr->feval(u"error",
157-
0, std::vector<matlab::data::Array>({ factory.createScalar("The fifth input's fields can only be named 'split' or 'acc'.") }));
159+
0, std::vector<matlab::data::Array>({ factory.createScalar("The fifth input's fields can only be named 'split', 'mult', or 'acc'.") }));
158160
} else {
159161
if (inStruct[0][field].getNumberOfElements() != 1 || inStruct[0][field].getType() != matlab::data::ArrayType::CHAR) {
160162
matlabPtr->feval(u"error",
161-
0, std::vector<matlab::data::Array>({ factory.createScalar("The field of the struct should be single characters.") }));
163+
0, std::vector<matlab::data::Array>({ factory.createScalar("Each field of the struct should be a single character.") }));
162164
}
163-
const matlab::data::TypedArray<char16_t> data = inStruct[0][field];
165+
const matlab::data::TypedArrayRef<char16_t> data = inStruct[0][field];
164166
if (std::string(field) == "split") {
165167
switch ((char)data[0]) {
166168
case 'n':
@@ -174,6 +176,19 @@ class MexFunction : public matlab::mex::Function {
174176
0, std::vector<matlab::data::Array>({ factory.createScalar("Specified 'split' is invalid.") }));
175177
break;
176178
}
179+
} else if (std::string(field) == "mult") {
180+
switch ((char)(data[0])) {
181+
case 'f':
182+
options->multType = multiplicationStrategy::full;
183+
break;
184+
case 'r':
185+
options->multType = multiplicationStrategy::reduced;
186+
break;
187+
default:
188+
matlabPtr->feval(u"error",
189+
0, std::vector<matlab::data::Array>({ factory.createScalar("Specified 'mult' is invalid.") }));
190+
break;
191+
}
177192
} else if (std::string(field) == "acc") {
178193
switch ((char)data[0]) {
179194
case 'f':

mex/gemmi.m

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,31 @@
11
% GEMMI Compute matrix product using integer Ozaki scheme.
22
% [C, ALGOUT] = GEMMI(A,B,ASPLITS,BSPLITS,ALGIN) computes the matrix
3-
% C = A*B using the Ozaki scheme with ASPLITS and BSPLITS slices
4-
% for the matrices A and B, respectively. The ALGIN parameter
3+
% C = A*B using the Ozaki scheme with ASPLITS and BSPLITS slices
4+
% for the matrices A and B, respectively. The ALGIN parameter
55
% must be a struct, with the following fields currently supported.
66
% 'split' - selects the strategy to be used to split A and B into
77
% slices. Possible values are 'b' for bitmasking and 'n'
8-
% for round-to-nearest (default).
8+
% for round-to-nearest (default).
9+
% 'mult' - selects how many integer multiplications the algorithm
10+
% will perform in order to compute the result. Possible
11+
% values are 'f' for all ASPLITS * BSPLITS products and 'r'
12+
% for a reduced number (default).
913
% 'acc' - selects how the exact integer matrix products are
1014
% accumulated. Possible values are 'f' for floating-point
1115
% arithmetic and 'i' for integer accumulation (default).
1216
% The output parameter ALGOUT is a struct with the same fields as
1317
% ALGIN, which contains the values used in the computation.
14-
%
18+
%
1519
% [...] = GEMMI(A,B,ASPLITS,BSPLITS) uses the ALGIN parameter passed
1620
% to the most recent call to GEMMI, or the default values if no previous
1721
% call was made.
1822
%
1923
% [...] = GEMMI(A,B,SPLITS) uses SPLITS slices for both A and B.
20-
%
24+
%
2125
% The splits are stored as 8-bit signed integers, the dot products are
2226
% performed using 32-bit signed arithmetic, and the final accumulation
2327
% uses either the same format as the matrices A and B (if 'acc' is 'f')
2428
% or 32-bit arithmetic (if 'acc' is 'i').
2529
%
2630
% The matrices A and B must be conformable, and multiplication by a
27-
% scalar is not supported.
31+
% scalar is not supported.

tests/tests.cpp

Lines changed: 26 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -13,31 +13,32 @@ template <> double tolerance<double>() {return 1e-15;}
1313

1414
template <typename fp_t>
1515
void runTest() {
16-
// Test different sizes
17-
for (auto splitType : {splittingStrategy::bitMasking,splittingStrategy::roundToNearest}) {
18-
for (auto accumulationType : {accumulationStrategy::floatingPoint, accumulationStrategy::integer}) {
19-
for (size_t numSplitA : { 1, 2, 10 }) {
20-
for (size_t numSplitB : { 1, 2, 10 }) {
21-
for (size_t m = 10; m <= 50; m += 10) {
22-
for (size_t p = 10; p <= 50; p += 10) {
23-
for (size_t n = 10; n <= 50; n += 10) {
24-
std::vector<fp_t> A(m * p);
25-
std::vector<fp_t> B(p * n);
26-
27-
// Initalize matrix with random values.
28-
std::default_random_engine generator(std::random_device{}());
29-
std::uniform_real_distribution<fp_t> distribution(-100000.0, 100000.0);
30-
for (auto & element : A)
31-
element = numSplitA < 10 ? ldexp(1.0, 2 * numSplitA) - 1 : distribution(generator);
32-
for (auto & element : B)
33-
element = numSplitB < 10 ? ldexp(1.0, 2 * numSplitB) - 1 : distribution(generator);
34-
35-
auto C = gemmi<fp_t, int8_t, int32_t>(A, B, m, p, n, numSplitA, numSplitB, splitType, accumulationType);
36-
auto C_ref = reference_gemm(A, B, m, p, n);
37-
38-
double relative_error = frobenius_norm<fp_t, double>(C - C_ref) / frobenius_norm<fp_t, double>(C);
39-
40-
REQUIRE(relative_error < tolerance<fp_t>());
16+
for (auto splitType : {splittingStrategy::bitMasking, splittingStrategy::roundToNearest}) {
17+
for (auto multiplicationType : {multiplicationStrategy::reduced, multiplicationStrategy::full}) {
18+
for (auto accumulationType : {accumulationStrategy::floatingPoint, accumulationStrategy::integer}) {
19+
for (size_t numSplitA : { 1, 2, 10 }) {
20+
for (size_t numSplitB : { 1, 2, 10 }) {
21+
for (size_t m = 10; m <= 50; m += 10) {
22+
for (size_t p = 10; p <= 50; p += 10) {
23+
for (size_t n = 10; n <= 50; n += 10) {
24+
std::vector<fp_t> A(m * p);
25+
std::vector<fp_t> B(p * n);
26+
27+
// Initialize matrix with random values.
28+
std::default_random_engine generator(std::random_device{}());
29+
std::uniform_real_distribution<fp_t> distribution(-100000.0, 100000.0);
30+
for (auto & element : A)
31+
element = numSplitA < 10 ? ldexp(1.0, 2 * numSplitA) - 1 : distribution(generator);
32+
for (auto & element : B)
33+
element = numSplitB < 10 ? ldexp(1.0, 2 * numSplitB) - 1 : distribution(generator);
34+
35+
auto C = gemmi<fp_t, int8_t, int32_t>(A, B, m, p, n, numSplitA, numSplitB, splitType, multiplicationType, accumulationType);
36+
auto C_ref = reference_gemm(A, B, m, p, n);
37+
38+
double relative_error = frobenius_norm<fp_t, double>(C - C_ref) / frobenius_norm<fp_t, double>(C);
39+
40+
REQUIRE(relative_error < tolerance<fp_t>());
41+
}
4142
}
4243
}
4344
}

0 commit comments

Comments
 (0)