44
55typedef struct {
66 splittingStrategy splitType;
7+ multiplicationStrategy multType;
78 accumulationStrategy accType;
89} algorithmOptions;
910static std::unique_ptr<algorithmOptions> options = nullptr ;
@@ -45,6 +46,7 @@ class MexFunction : public matlab::mex::Function {
4546 matlab::data::ArrayFactory factory;
4647 matlab::data::StructArray S = factory.createStructArray ({1 , 1 }, {" split" , " acc" });
4748 S[0 ][" split" ] = factory.createCharArray (options->splitType == splittingStrategy::roundToNearest ? " n" : " b" );
49+ S[0 ][" mult" ] = factory.createCharArray (options->multType == multiplicationStrategy::full ? " f" : " r" );
4850 S[0 ][" acc" ] = factory.createCharArray (options->accType == accumulationStrategy::floatingPoint ? " f" : " i" );
4951 outputs[1 ] = std::move (S);
5052 }
@@ -59,8 +61,8 @@ class MexFunction : public matlab::mex::Function {
5961 auto A_size = Amatlab.getDimensions ();
6062 auto B_size = Bmatlab.getDimensions ();
6163
62- auto C = gemmi<double , int8_t , int32_t >(A, B, A_size[0 ], A_size[1 ], B_size[1 ],
63- numSplitsA, numSplitsB, options->splitType , options->accType );
64+ auto C = gemmi<double , int8_t , int32_t >(A, B, A_size[0 ], A_size[1 ], B_size[1 ], numSplitsA, numSplitsB,
65+ options-> splitType , options->multType , options->accType );
6466
6567 matlab::data::ArrayFactory factory;
6668 return factory.createArray ({A_size[0 ], B_size[1 ]}, C.begin (), C.end ());;
@@ -74,8 +76,8 @@ class MexFunction : public matlab::mex::Function {
7476 auto A_size = Amatlab.getDimensions ();
7577 auto B_size = Bmatlab.getDimensions ();
7678
77- auto C = gemmi<float , int8_t , int32_t >(A, B, A_size[0 ], A_size[1 ], B_size[1 ],
78- numSplitsA, numSplitsB, options->splitType , options->accType );
79+ auto C = gemmi<float , int8_t , int32_t >(A, B, A_size[0 ], A_size[1 ], B_size[1 ], numSplitsA, numSplitsB,
80+ options-> splitType , options->multType , options->accType );
7981
8082 matlab::data::ArrayFactory factory;
8183 return factory.createArray ({A_size[0 ], B_size[1 ]}, C.begin (), C.end ());;
@@ -145,22 +147,22 @@ class MexFunction : public matlab::mex::Function {
145147 0 , std::vector<matlab::data::Array>({ factory.createScalar (" The fifth input must be a struct." ) }));
146148 }
147149 matlab::data::StructArray inStruct (inputs[4 ]);
148- if (inStruct.getNumberOfFields () > 2 ) {
150+ if (inStruct.getNumberOfFields () > 3 ) {
149151 matlabPtr->feval (u" error" ,
150- 0 , std::vector<matlab::data::Array>({ factory.createScalar (" The fifth input must have at most two fields." ) }));
152+ 0 , std::vector<matlab::data::Array>({ factory.createScalar (" The fifth input must have at most three fields." ) }));
151153 }
152154 auto fields = inStruct.getFieldNames ();
153155 std::vector<matlab::data::MATLABFieldIdentifier> fieldNames (fields.begin (), fields.end ());
154156 for (auto field : fieldNames) {
155- if (std::string (field) != " split" && std::string (field) != " acc" ) {
157+ if (std::string (field) != " split" && std::string (field) != " mult " && std::string (field) != " acc" ) {
156158 matlabPtr->feval (u" error" ,
157- 0 , std::vector<matlab::data::Array>({ factory.createScalar (" The fifth input's fields can only be named 'split' or 'acc'." ) }));
159+ 0 , std::vector<matlab::data::Array>({ factory.createScalar (" The fifth input's fields can only be named 'split', 'mult', or 'acc'." ) }));
158160 } else {
159161 if (inStruct[0 ][field].getNumberOfElements () != 1 || inStruct[0 ][field].getType () != matlab::data::ArrayType::CHAR) {
160162 matlabPtr->feval (u" error" ,
161- 0 , std::vector<matlab::data::Array>({ factory.createScalar (" The field of the struct should be single characters ." ) }));
163+ 0 , std::vector<matlab::data::Array>({ factory.createScalar (" Each field of the struct should be a single character ." ) }));
162164 }
163- const matlab::data::TypedArray <char16_t > data = inStruct[0 ][field];
165+ const matlab::data::TypedArrayRef <char16_t > data = inStruct[0 ][field];
164166 if (std::string (field) == " split" ) {
165167 switch ((char )data[0 ]) {
166168 case ' n' :
@@ -174,6 +176,19 @@ class MexFunction : public matlab::mex::Function {
174176 0 , std::vector<matlab::data::Array>({ factory.createScalar (" Specified 'split' is invalid." ) }));
175177 break ;
176178 }
179+ } else if (std::string (field) == " mult" ) {
180+ switch ((char )(data[0 ])) {
181+ case ' f' :
182+ options->multType = multiplicationStrategy::full;
183+ break ;
184+ case ' r' :
185+ options->multType = multiplicationStrategy::reduced;
186+ break ;
187+ default :
188+ matlabPtr->feval (u" error" ,
189+ 0 , std::vector<matlab::data::Array>({ factory.createScalar (" Specified 'mult' is invalid." ) }));
190+ break ;
191+ }
177192 } else if (std::string (field) == " acc" ) {
178193 switch ((char )data[0 ]) {
179194 case ' f' :
0 commit comments