Skip to content

Commit e468828

Browse files
committed
Let MEX interface return algorithm information
1 parent eb8548a commit e468828

File tree

2 files changed

+27
-12
lines changed

2 files changed

+27
-12
lines changed

mex/gemmi.cpp

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@ class MexFunction : public matlab::mex::Function {
1919
}
2020

2121
// Validate input.
22-
validateInput(inputs);
22+
validateInput(inputs, outputs);
2323

24-
size_t numSplitsA = std::move(inputs[2][0]);
24+
size_t numSplitsA = std::move(inputs[2][0]);
2525
size_t numSplitsB = inputs.size() == 3 ? numSplitsA : std::move(inputs[3][0]);
2626

2727
if (inputs[0].getType() == matlab::data::ArrayType::DOUBLE &&
@@ -40,6 +40,14 @@ class MexFunction : public matlab::mex::Function {
4040
matlabPtr->feval(u"error",
4141
0, std::vector<matlab::data::Array>({ factory.createScalar("Unsupported combination of data type.") }));
4242
}
43+
44+
if (outputs.size() == 2) {
45+
matlab::data::ArrayFactory factory;
46+
matlab::data::StructArray S = factory.createStructArray({1, 1}, {"split", "acc"});
47+
S[0]["split"] = factory.createCharArray(options->splitType == splittingStrategy::roundToNearest ? "n" : "b");
48+
S[0]["acc"] = factory.createCharArray(options->accType == accumulationStrategy::floatingPoint ? "f" : "i");
49+
outputs[1] = std::move(S);
50+
}
4351
}
4452

4553
private:
@@ -73,10 +81,15 @@ class MexFunction : public matlab::mex::Function {
7381
return factory.createArray({A_size[0], B_size[1]}, C.begin(), C.end());;
7482
}
7583

76-
void validateInput(matlab::mex::ArgumentList inputs) {
84+
void validateInput(matlab::mex::ArgumentList inputs, matlab::mex::ArgumentList outputs) {
7785
std::shared_ptr<matlab::engine::MATLABEngine> matlabPtr = getEngine();
7886
matlab::data::ArrayFactory factory;
7987

88+
if (outputs.size() < 1 || outputs.size() > 2) {
89+
matlabPtr->feval(u"error",
90+
0, std::vector<matlab::data::Array>({ factory.createScalar("This function requires one or two output arguments.") }));
91+
}
92+
8093
size_t numArgs = inputs.size();
8194

8295
if ( numArgs < 3 || numArgs > 5) {

mex/gemmi.m

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,22 @@
11
% GEMMI Compute matrix product using integer Ozaki scheme.
2-
% C = GEMMI(A,B,ASPLITS,BSPLITS,ALGORITHM) computes the product A*B
3-
% using the Ozaki scheme with ASPLITS and BSPLITS slices for the
4-
% matrices A and B, respectively. The ALGORITHM parameter must be
5-
% a struct, with the following fields currently supported.
2+
% [C, ALGOUT] = GEMMI(A,B,ASPLITS,BSPLITS,ALGIN) computes the matrix
3+
% C = A*B using the Ozaki scheme with ASPLITS and BSPLITS slices
4+
% for the matrices A and B, respectively. The ALGIN parameter
5+
% must be a struct, with the following fields currently supported.
66
% 'split' - selects the stragegy to be used to split A and B into
77
% slices. Possible values are 'b' for bitmasking and 'n
8-
% for round-to-nearest (default).
8+
% for round-to-nearest (default).
99
% 'acc' - selects how the exact integer matrix products are
1010
% accumulated. Possible values are 'f' for floating-point
1111
% arithmetic and 'i' for integer accumulation (default).
12+
% The output paramater ALGOUT is a struct with the same fields as
13+
% ALGIN, which contains the values used in the computation.
1214
%
13-
% C = GEMMI(A,B,ASPLITS,BSPLITS) uses the last ALGORITHM parameter
14-
% in the most recent call to GEMMI, or the default values if no
15-
% previous call was made.
15+
% [...] = GEMMI(A,B,ASPLITS,BSPLITS) uses the ALGIN parameter passed
16+
% the most recent call to GEMMI, or the default values if no previous
17+
% call was made.
1618
%
17-
% C = GEMMI(A,B,SPLITS) uses SPLITS slices for both A and B.
19+
% [...] = GEMMI(A,B,SPLITS) uses SPLITS slices for both A and B.
1820
%
1921
% The splits are stored as 8-bit signed integer, the dot products are
2022
% performed using 32-bit signed arithmetic, and the final accumulation

0 commit comments

Comments
 (0)