22#include " mexAdapter.hpp"
33#include " ../include/gemmi.hpp"
44
5+ typedef struct {
6+ splittingStrategy splitType;
7+ accumulationStrategy accType;
8+ } algorithmOptions;
9+ static std::unique_ptr<algorithmOptions> options = nullptr ;
10+
511class MexFunction : public matlab ::mex::Function {
612public:
713 void operator ()(matlab::mex::ArgumentList outputs, matlab::mex::ArgumentList inputs) {
814
15+ if (options == nullptr ) {
16+ options = std::make_unique<algorithmOptions>();
17+ options->splitType = splittingStrategy::roundToNearest;
18+ options->accType = accumulationStrategy::integer;
19+ }
20+
921 // Validate input.
1022 validateInput (inputs);
1123
12- size_t numSplitsA = std::move (inputs[2 ][0 ]);
24+ size_t numSplitsA = std::move (inputs[2 ][0 ]);
1325 size_t numSplitsB = inputs.size () == 3 ? numSplitsA : std::move (inputs[3 ][0 ]);
1426
1527 if (inputs[0 ].getType () == matlab::data::ArrayType::DOUBLE &&
@@ -39,8 +51,8 @@ class MexFunction : public matlab::mex::Function {
3951 auto A_size = Amatlab.getDimensions ();
4052 auto B_size = Bmatlab.getDimensions ();
4153
42- auto C = gemmi<double , int8_t , int32_t >(A, B, A_size[0 ], A_size[1 ], B_size[1 ],
43- numSplitsA, numSplitsB);
54+ auto C = gemmi<double , int8_t , int32_t >(A, B, A_size[0 ], A_size[1 ], B_size[1 ],
55+ numSplitsA, numSplitsB, options-> splitType , options-> accType );
4456
4557 matlab::data::ArrayFactory factory;
4658 return factory.createArray ({A_size[0 ], B_size[1 ]}, C.begin (), C.end ());;
@@ -55,7 +67,7 @@ class MexFunction : public matlab::mex::Function {
5567 auto B_size = Bmatlab.getDimensions ();
5668
5769 auto C = gemmi<float , int8_t , int32_t >(A, B, A_size[0 ], A_size[1 ], B_size[1 ],
58- numSplitsA, numSplitsB);
70+ numSplitsA, numSplitsB, options-> splitType , options-> accType );
5971
6072 matlab::data::ArrayFactory factory;
6173 return factory.createArray ({A_size[0 ], B_size[1 ]}, C.begin (), C.end ());;
@@ -67,21 +79,9 @@ class MexFunction : public matlab::mex::Function {
6779
6880 size_t numArgs = inputs.size ();
6981
70- if ( numArgs < 3 || numArgs > 4 ) {
82+ if ( numArgs < 3 || numArgs > 5 ) {
7183 matlabPtr->feval (u" error" ,
72- 0 , std::vector<matlab::data::Array>({ factory.createScalar (" Three or four inputs expected." ) }));
73- }
74-
75- if (inputs[2 ].getNumberOfElements () != 1 || std::round ((double )inputs[2 ][0 ][0 ]) != (double )inputs[2 ][0 ][0 ]) {
76- matlabPtr->feval (u" error" ,
77- 0 , std::vector<matlab::data::Array>({ factory.createScalar (" The third input must be a scalar integer." ) }));
78- }
79-
80- if (numArgs == 4 ) {
81- if (inputs[3 ].getNumberOfElements () != 1 || std::round ((double )inputs[3 ][0 ][0 ]) != (double )inputs[3 ][0 ][0 ]) {
82- matlabPtr->feval (u" error" ,
83- 0 , std::vector<matlab::data::Array>({ factory.createScalar (" The third input must be a scalar integer." ) }));
84- }
84+ 0 , std::vector<matlab::data::Array>({ factory.createScalar (" Three to five inputs expected." ) }));
8585 }
8686
8787 auto A_size = inputs[0 ].getDimensions ();
@@ -113,5 +113,71 @@ class MexFunction : public matlab::mex::Function {
113113 matlabPtr->feval (u" error" ,
114114 0 , std::vector<matlab::data::Array>({ factory.createScalar (" Input matrices must have the same data type." ) }));
115115 }
116+
117+ if (inputs[2 ].getNumberOfElements () != 1 || std::round ((double )inputs[2 ][0 ][0 ]) != (double )inputs[2 ][0 ][0 ]) {
118+ matlabPtr->feval (u" error" ,
119+ 0 , std::vector<matlab::data::Array>({ factory.createScalar (" The third input must be a scalar integer." ) }));
120+ }
121+
122+ if (numArgs == 4 ) {
123+ if (inputs[3 ].getNumberOfElements () != 1 || std::round ((double )inputs[3 ][0 ][0 ]) != (double )inputs[3 ][0 ][0 ]) {
124+ matlabPtr->feval (u" error" ,
125+ 0 , std::vector<matlab::data::Array>({ factory.createScalar (" The fourth input must be a scalar integer." ) }));
126+ }
127+ }
128+
129+ if (numArgs == 5 ) {
130+ if (!inputs[4 ].isEmpty () && !(inputs[4 ].getType () == matlab::data::ArrayType::STRUCT)) {
131+ matlabPtr->feval (u" error" ,
132+ 0 , std::vector<matlab::data::Array>({ factory.createScalar (" The fifth input must be a struct." ) }));
133+ }
134+ matlab::data::StructArray inStruct (inputs[4 ]);
135+ if (inStruct.getNumberOfFields () > 2 ) {
136+ matlabPtr->feval (u" error" ,
137+ 0 , std::vector<matlab::data::Array>({ factory.createScalar (" The fifth input must have at most two fields." ) }));
138+ }
139+ auto fields = inStruct.getFieldNames ();
140+ std::vector<matlab::data::MATLABFieldIdentifier> fieldNames (fields.begin (), fields.end ());
141+ for (auto field : fieldNames) {
142+ if (std::string (field) != " split" && std::string (field) != " acc" ) {
143+ matlabPtr->feval (u" error" ,
144+ 0 , std::vector<matlab::data::Array>({ factory.createScalar (" The fifth input's fields can only be named 'split' or 'acc'." ) }));
145+ } else {
146+ if (inStruct[0 ][field].getNumberOfElements () != 1 || inStruct[0 ][field].getType () != matlab::data::ArrayType::CHAR) {
147+ matlabPtr->feval (u" error" ,
148+ 0 , std::vector<matlab::data::Array>({ factory.createScalar (" The field of the struct should be single characters." ) }));
149+ }
150+ const matlab::data::TypedArray<char16_t > data = inStruct[0 ][field];
151+ if (std::string (field) == " split" ) {
152+ switch (data[0 ]) {
153+ case ' n' :
154+ options->splitType = splittingStrategy::roundToNearest;
155+ break ;
156+ case ' b' :
157+ options->splitType = splittingStrategy::bitMasking;
158+ break ;
159+ default :
160+ matlabPtr->feval (u" error" ,
161+ 0 , std::vector<matlab::data::Array>({ factory.createScalar (" Specified 'split' is invalid." ) }));
162+ break ;
163+ }
164+ } else if (std::string (field) == " acc" ) {
165+ switch (data[0 ]) {
166+ case ' f' :
167+ options->accType = accumulationStrategy::floatingPoint;
168+ break ;
169+ case ' i' :
170+ options->accType = accumulationStrategy::integer;
171+ break ;
172+ default :
173+ matlabPtr->feval (u" error" ,
174+ 0 , std::vector<matlab::data::Array>({ factory.createScalar (" Specified 'acc' is invalid." ) }));
175+ break ;
176+ }
177+ }
178+ }
179+ }
180+ }
181+
116182 }
117183};
0 commit comments