22#include " mexAdapter.hpp"
33#include " ../include/gemmi.hpp"
44
5+ /*
6+ */
57typedef struct {
68 splittingStrategy splitType;
79 multiplicationStrategy multType;
@@ -17,6 +19,7 @@ class MexFunction : public matlab::mex::Function {
1719 options = std::make_unique<algorithmOptions>();
1820 options->splitType = splittingStrategy::roundToNearest;
1921 options->accType = accumulationStrategy::integer;
22+ options->multType = multiplicationStrategy::reduced;
2023 }
2124
2225 // Validate input.
@@ -44,10 +47,10 @@ class MexFunction : public matlab::mex::Function {
4447
4548 if (outputs.size () == 2 ) {
4649 matlab::data::ArrayFactory factory;
47- matlab::data::StructArray S = factory.createStructArray ({1 , 1 }, {" split" , " acc" });
50+ matlab::data::StructArray S = factory.createStructArray ({1 , 1 }, {" split" , " acc" , " mult " });
4851 S[0 ][" split" ] = factory.createCharArray (options->splitType == splittingStrategy::roundToNearest ? " n" : " b" );
49- S[0 ][" mult" ] = factory.createCharArray (options->multType == multiplicationStrategy::full ? " f" : " r" );
5052 S[0 ][" acc" ] = factory.createCharArray (options->accType == accumulationStrategy::floatingPoint ? " f" : " i" );
53+ S[0 ][" mult" ] = factory.createCharArray (options->multType == multiplicationStrategy::full ? " f" : " r" );
5154 outputs[1 ] = std::move (S);
5255 }
5356 }
@@ -62,7 +65,7 @@ class MexFunction : public matlab::mex::Function {
6265 auto B_size = Bmatlab.getDimensions ();
6366
6467 auto C = gemmi<double , int8_t , int32_t >(A, B, A_size[0 ], A_size[1 ], B_size[1 ], numSplitsA, numSplitsB,
65- options->splitType , options->multType , options->accType );
68+ options->splitType , options->accType , options->multType );
6669
6770 matlab::data::ArrayFactory factory;
6871 return factory.createArray ({A_size[0 ], B_size[1 ]}, C.begin (), C.end ());;
@@ -77,7 +80,7 @@ class MexFunction : public matlab::mex::Function {
7780 auto B_size = Bmatlab.getDimensions ();
7881
7982 auto C = gemmi<float , int8_t , int32_t >(A, B, A_size[0 ], A_size[1 ], B_size[1 ], numSplitsA, numSplitsB,
80- options->splitType , options->multType , options->accType );
83+ options->splitType , options->accType , options->multType );
8184
8285 matlab::data::ArrayFactory factory;
8386 return factory.createArray ({A_size[0 ], B_size[1 ]}, C.begin (), C.end ());;
@@ -176,30 +179,30 @@ class MexFunction : public matlab::mex::Function {
176179 0 , std::vector<matlab::data::Array>({ factory.createScalar (" Specified 'split' is invalid." ) }));
177180 break ;
178181 }
179- } else if (std::string (field) == " mult " ) {
180- switch ((char )( data[0 ]) ) {
182+ } else if (std::string (field) == " acc " ) {
183+ switch ((char )data[0 ]) {
181184 case ' f' :
182- options->multType = multiplicationStrategy::full ;
185+ options->accType = accumulationStrategy::floatingPoint ;
183186 break ;
184- case ' r ' :
185- options->multType = multiplicationStrategy::reduced ;
187+ case ' i ' :
188+ options->accType = accumulationStrategy::integer ;
186189 break ;
187190 default :
188191 matlabPtr->feval (u" error" ,
189- 0 , std::vector<matlab::data::Array>({ factory.createScalar (" Specified 'mult ' is invalid." ) }));
192+ 0 , std::vector<matlab::data::Array>({ factory.createScalar (" Specified 'acc ' is invalid." ) }));
190193 break ;
191194 }
192- } else if (std::string (field) == " acc " ) {
193- switch ((char )data[0 ]) {
195+ } else if (std::string (field) == " mult " ) {
196+ switch ((char )( data[0 ]) ) {
194197 case ' f' :
195- options->accType = accumulationStrategy::floatingPoint ;
198+ options->multType = multiplicationStrategy::full ;
196199 break ;
197- case ' i ' :
198- options->accType = accumulationStrategy::integer ;
200+ case ' r ' :
201+ options->multType = multiplicationStrategy::reduced ;
199202 break ;
200203 default :
201204 matlabPtr->feval (u" error" ,
202- 0 , std::vector<matlab::data::Array>({ factory.createScalar (" Specified 'acc ' is invalid." ) }));
205+ 0 , std::vector<matlab::data::Array>({ factory.createScalar (" Specified 'mult ' is invalid." ) }));
203206 break ;
204207 }
205208 }
0 commit comments