Skip to content

Commit 057225e

Browse files
committed
Swap accType and multType parameters in gemmi
1 parent 9768d0a commit 057225e

File tree

3 files changed

+24
-21
lines changed

3 files changed

+24
-21
lines changed

include/gemmi.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -321,8 +321,8 @@ std::vector<fp_t> gemmi (const std::vector<fp_t> &A, const std::vector<fp_t> &B,
321321
const size_t m, const size_t k, const size_t n,
322322
const size_t numSplitsA, const size_t numSplitsB,
323323
const splittingStrategy splitType = splittingStrategy::roundToNearest,
324-
const multiplicationStrategy multType = multiplicationStrategy::reduced,
325-
const accumulationStrategy accType = accumulationStrategy::floatingPoint) {
324+
const accumulationStrategy accType = accumulationStrategy::floatingPoint,
325+
const multiplicationStrategy multType = multiplicationStrategy::reduced) {
326326

327327
const size_t bitsInAccumulator = std::numeric_limits<accumulator_t>::digits;
328328
const size_t bitsPerInteger = std::numeric_limits<splitint_t>::digits;

mex/gemmi.cpp

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
#include "mexAdapter.hpp"
33
#include "../include/gemmi.hpp"
44

5+
/*
6+
*/
57
typedef struct {
68
splittingStrategy splitType;
79
multiplicationStrategy multType;
@@ -17,6 +19,7 @@ class MexFunction : public matlab::mex::Function {
1719
options = std::make_unique<algorithmOptions>();
1820
options->splitType = splittingStrategy::roundToNearest;
1921
options->accType = accumulationStrategy::integer;
22+
options->multType = multiplicationStrategy::reduced;
2023
}
2124

2225
// Validate input.
@@ -44,10 +47,10 @@ class MexFunction : public matlab::mex::Function {
4447

4548
if (outputs.size() == 2) {
4649
matlab::data::ArrayFactory factory;
47-
matlab::data::StructArray S = factory.createStructArray({1, 1}, {"split", "acc"});
50+
matlab::data::StructArray S = factory.createStructArray({1, 1}, {"split", "acc", "mult"});
4851
S[0]["split"] = factory.createCharArray(options->splitType == splittingStrategy::roundToNearest ? "n" : "b");
49-
S[0]["mult"] = factory.createCharArray(options->multType == multiplicationStrategy::full ? "f" : "r");
5052
S[0]["acc"] = factory.createCharArray(options->accType == accumulationStrategy::floatingPoint ? "f" : "i");
53+
S[0]["mult"] = factory.createCharArray(options->multType == multiplicationStrategy::full ? "f" : "r");
5154
outputs[1] = std::move(S);
5255
}
5356
}
@@ -62,7 +65,7 @@ class MexFunction : public matlab::mex::Function {
6265
auto B_size = Bmatlab.getDimensions();
6366

6467
auto C = gemmi<double, int8_t, int32_t>(A, B, A_size[0], A_size[1], B_size[1], numSplitsA, numSplitsB,
65-
options->splitType, options->multType, options->accType);
68+
options->splitType, options->accType, options->multType);
6669

6770
matlab::data::ArrayFactory factory;
6871
return factory.createArray({A_size[0], B_size[1]}, C.begin(), C.end());;
@@ -77,7 +80,7 @@ class MexFunction : public matlab::mex::Function {
7780
auto B_size = Bmatlab.getDimensions();
7881

7982
auto C = gemmi<float, int8_t, int32_t>(A, B, A_size[0], A_size[1], B_size[1], numSplitsA, numSplitsB,
80-
options->splitType, options->multType, options->accType);
83+
options->splitType, options->accType, options->multType);
8184

8285
matlab::data::ArrayFactory factory;
8386
return factory.createArray({A_size[0], B_size[1]}, C.begin(), C.end());;
@@ -176,30 +179,30 @@ class MexFunction : public matlab::mex::Function {
176179
0, std::vector<matlab::data::Array>({ factory.createScalar("Specified 'split' is invalid.") }));
177180
break;
178181
}
179-
} else if (std::string(field) == "mult") {
180-
switch ((char)(data[0])) {
182+
} else if (std::string(field) == "acc") {
183+
switch ((char)data[0]) {
181184
case 'f':
182-
options->multType = multiplicationStrategy::full;
185+
options->accType = accumulationStrategy::floatingPoint;
183186
break;
184-
case 'r':
185-
options->multType = multiplicationStrategy::reduced;
187+
case 'i':
188+
options->accType = accumulationStrategy::integer;
186189
break;
187190
default:
188191
matlabPtr->feval(u"error",
189-
0, std::vector<matlab::data::Array>({ factory.createScalar("Specified 'mult' is invalid.") }));
192+
0, std::vector<matlab::data::Array>({ factory.createScalar("Specified 'acc' is invalid.") }));
190193
break;
191194
}
192-
} else if (std::string(field) == "acc") {
193-
switch ((char)data[0]) {
195+
} else if (std::string(field) == "mult") {
196+
switch ((char)(data[0])) {
194197
case 'f':
195-
options->accType = accumulationStrategy::floatingPoint;
198+
options->multType = multiplicationStrategy::full;
196199
break;
197-
case 'i':
198-
options->accType = accumulationStrategy::integer;
200+
case 'r':
201+
options->multType = multiplicationStrategy::reduced;
199202
break;
200203
default:
201204
matlabPtr->feval(u"error",
202-
0, std::vector<matlab::data::Array>({ factory.createScalar("Specified 'acc' is invalid.") }));
205+
0, std::vector<matlab::data::Array>({ factory.createScalar("Specified 'mult' is invalid.") }));
203206
break;
204207
}
205208
}

tests/tests.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@ template <> double tolerance<double>() {return 1e-15;}
1414
template <typename fp_t>
1515
void runTest() {
1616
for (auto splitType : {splittingStrategy::bitMasking, splittingStrategy::roundToNearest}) {
17-
for (auto multiplicationType : {multiplicationStrategy::reduced, multiplicationStrategy::full}) {
18-
for (auto accumulationType : {accumulationStrategy::floatingPoint, accumulationStrategy::integer}) {
17+
for (auto accumulationType : {accumulationStrategy::floatingPoint, accumulationStrategy::integer}) {
18+
for (auto multiplicationType : {multiplicationStrategy::reduced, multiplicationStrategy::full}) {
1919
for (size_t numSplitA : { 1, 2, 10 }) {
2020
for (size_t numSplitB : { 1, 2, 10 }) {
2121
for (size_t m = 10; m <= 50; m += 10) {
@@ -32,7 +32,7 @@ void runTest() {
3232
for (auto & element : B)
3333
element = numSplitB < 10 ? ldexp(1.0, 2 * numSplitB) - 1 : distribution(generator);
3434

35-
auto C = gemmi<fp_t, int8_t, int32_t>(A, B, m, k, n, numSplitA, numSplitB, splitType, multiplicationType, accumulationType);
35+
auto C = gemmi<fp_t, int8_t, int32_t>(A, B, m, k, n, numSplitA, numSplitB, splitType, accumulationType, multiplicationType);
3636
auto C_ref = reference_gemm(A, B, m, k, n);
3737

3838
double relative_error = frobenius_norm<fp_t, double>(C - C_ref) / frobenius_norm<fp_t, double>(C);

0 commit comments

Comments
 (0)