Skip to content

Commit e6b0036

Browse files
committed
Expose splitting and accumulation strategies in MEX interface
1 parent 8dddb38 commit e6b0036

File tree

1 file changed

+84
-18
lines changed

1 file changed

+84
-18
lines changed

mex/gemmi.cpp

Lines changed: 84 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,26 @@
22
#include "mexAdapter.hpp"
33
#include "../include/gemmi.hpp"
44

5+
typedef struct {
6+
splittingStrategy splitType;
7+
accumulationStrategy accType;
8+
} algorithmOptions;
9+
static std::unique_ptr<algorithmOptions> options = nullptr;
10+
511
class MexFunction : public matlab::mex::Function {
612
public:
713
void operator()(matlab::mex::ArgumentList outputs, matlab::mex::ArgumentList inputs) {
814

15+
if (options == nullptr) {
16+
options = std::make_unique<algorithmOptions>();
17+
options->splitType = splittingStrategy::roundToNearest;
18+
options->accType = accumulationStrategy::integer;
19+
}
20+
921
// Validate input.
1022
validateInput(inputs);
1123

12-
size_t numSplitsA = std::move(inputs[2][0]);
24+
size_t numSplitsA = std::move(inputs[2][0]);
1325
size_t numSplitsB = inputs.size() == 3 ? numSplitsA : std::move(inputs[3][0]);
1426

1527
if (inputs[0].getType() == matlab::data::ArrayType::DOUBLE &&
@@ -39,8 +51,8 @@ class MexFunction : public matlab::mex::Function {
3951
auto A_size = Amatlab.getDimensions();
4052
auto B_size = Bmatlab.getDimensions();
4153

42-
auto C = gemmi<double, int8_t, int32_t>(A, B, A_size[0], A_size[1], B_size[1],
43-
numSplitsA, numSplitsB);
54+
auto C = gemmi<double, int8_t, int32_t>(A, B, A_size[0], A_size[1], B_size[1],
55+
numSplitsA, numSplitsB, options->splitType, options->accType);
4456

4557
matlab::data::ArrayFactory factory;
4658
return factory.createArray({A_size[0], B_size[1]}, C.begin(), C.end());;
@@ -55,7 +67,7 @@ class MexFunction : public matlab::mex::Function {
5567
auto B_size = Bmatlab.getDimensions();
5668

5769
auto C = gemmi<float, int8_t, int32_t>(A, B, A_size[0], A_size[1], B_size[1],
58-
numSplitsA, numSplitsB);
70+
numSplitsA, numSplitsB, options->splitType, options->accType);
5971

6072
matlab::data::ArrayFactory factory;
6173
return factory.createArray({A_size[0], B_size[1]}, C.begin(), C.end());;
@@ -67,21 +79,9 @@ class MexFunction : public matlab::mex::Function {
6779

6880
size_t numArgs = inputs.size();
6981

70-
if ( numArgs < 3 || numArgs > 4) {
82+
if ( numArgs < 3 || numArgs > 5) {
7183
matlabPtr->feval(u"error",
72-
0, std::vector<matlab::data::Array>({ factory.createScalar("Three or four inputs expected.") }));
73-
}
74-
75-
if (inputs[2].getNumberOfElements() != 1 || std::round((double)inputs[2][0][0]) != (double)inputs[2][0][0]) {
76-
matlabPtr->feval(u"error",
77-
0, std::vector<matlab::data::Array>({ factory.createScalar("The third input must be a scalar integer.") }));
78-
}
79-
80-
if (numArgs == 4) {
81-
if (inputs[3].getNumberOfElements() != 1 || std::round((double)inputs[3][0][0]) != (double)inputs[3][0][0]) {
82-
matlabPtr->feval(u"error",
83-
0, std::vector<matlab::data::Array>({ factory.createScalar("The third input must be a scalar integer.") }));
84-
}
84+
0, std::vector<matlab::data::Array>({ factory.createScalar("Three to five inputs expected.") }));
8585
}
8686

8787
auto A_size = inputs[0].getDimensions();
@@ -113,5 +113,71 @@ class MexFunction : public matlab::mex::Function {
113113
matlabPtr->feval(u"error",
114114
0, std::vector<matlab::data::Array>({ factory.createScalar("Input matrices must have the same data type.") }));
115115
}
116+
117+
if (inputs[2].getNumberOfElements() != 1 || std::round((double)inputs[2][0][0]) != (double)inputs[2][0][0]) {
118+
matlabPtr->feval(u"error",
119+
0, std::vector<matlab::data::Array>({ factory.createScalar("The third input must be a scalar integer.") }));
120+
}
121+
122+
if (numArgs == 4) {
123+
if (inputs[3].getNumberOfElements() != 1 || std::round((double)inputs[3][0][0]) != (double)inputs[3][0][0]) {
124+
matlabPtr->feval(u"error",
125+
0, std::vector<matlab::data::Array>({ factory.createScalar("The fourth input must be a scalar integer.") }));
126+
}
127+
}
128+
129+
if (numArgs == 5) {
130+
if(!inputs[4].isEmpty() && !(inputs[4].getType() == matlab::data::ArrayType::STRUCT)) {
131+
matlabPtr->feval(u"error",
132+
0, std::vector<matlab::data::Array>({ factory.createScalar("The fifth input must be a struct.") }));
133+
}
134+
matlab::data::StructArray inStruct(inputs[4]);
135+
if (inStruct.getNumberOfFields() > 2) {
136+
matlabPtr->feval(u"error",
137+
0, std::vector<matlab::data::Array>({ factory.createScalar("The fifth input must have at most two fields.") }));
138+
}
139+
auto fields = inStruct.getFieldNames();
140+
std::vector<matlab::data::MATLABFieldIdentifier> fieldNames(fields.begin(), fields.end());
141+
for (auto field : fieldNames) {
142+
if (std::string(field) != "split" && std::string(field) != "acc") {
143+
matlabPtr->feval(u"error",
144+
0, std::vector<matlab::data::Array>({ factory.createScalar("The fifth input's fields can only be named 'split' or 'acc'.") }));
145+
} else {
146+
if (inStruct[0][field].getNumberOfElements() != 1 || inStruct[0][field].getType() != matlab::data::ArrayType::CHAR) {
147+
matlabPtr->feval(u"error",
148+
0, std::vector<matlab::data::Array>({ factory.createScalar("The field of the struct should be single characters.") }));
149+
}
150+
const matlab::data::TypedArray<char16_t> data = inStruct[0][field];
151+
if (std::string(field) == "split") {
152+
switch (data[0]) {
153+
case 'n':
154+
options->splitType = splittingStrategy::roundToNearest;
155+
break;
156+
case 'b':
157+
options->splitType = splittingStrategy::bitMasking;
158+
break;
159+
default:
160+
matlabPtr->feval(u"error",
161+
0, std::vector<matlab::data::Array>({ factory.createScalar("Specified 'split' is invalid.") }));
162+
break;
163+
}
164+
} else if (std::string(field) == "acc") {
165+
switch (data[0]) {
166+
case 'f':
167+
options->accType = accumulationStrategy::floatingPoint;
168+
break;
169+
case 'i':
170+
options->accType = accumulationStrategy::integer;
171+
break;
172+
default:
173+
matlabPtr->feval(u"error",
174+
0, std::vector<matlab::data::Array>({ factory.createScalar("Specified 'acc' is invalid.") }));
175+
break;
176+
}
177+
}
178+
}
179+
}
180+
}
181+
116182
}
117183
};

0 commit comments

Comments
 (0)