@@ -70,9 +70,6 @@ struct MatrixSplit {
7070 }
7171
7272 // This is the dimension along which the inner product is computed.
73- // This will be the number of columns for the matrix on the left of the
74- // product, which is normalised by rows, and the number of rows for the
75- // matrix on the right, which is normalised by columns.
7673 size_t innerProductDimension () {
7774 return (dimension == normalisationDimension::byRows) ? n : m;
7875 }
@@ -160,7 +157,7 @@ struct MatrixSplit {
160157 sign[j] = std::signbit (value); // Extract sign.
161158 tmp[j] &= (~(uint_t )(0 )) >> (nunExpBits + 1 ); // Remove exponent.
162159 // Restore implicit bit for normal numbers.
163- // TODO : NaNs and infs are currently not supported. .
160+ // NOTE : NaNs and infs are currently not supported.
164161 if (std::fpclassify (value) == FP_NORMAL)
165162 tmp[j] |= ((uint_t )1 << (nunFracBits - 1 ));
166163 }
@@ -238,10 +235,10 @@ void computeExactIntegerGEMM(const MatrixSplit<splitint_t, fp_t> &A,
238235
239236/* Compute scaling constant for using the split strategy. */
240237template <typename splitint_t , typename fp_t >
241- fp_t computeScalingConstantforUsingSplitStrategy (const MatrixSplit<splitint_t , fp_t > &A,
238+ fp_t computeScalingConstantforUsingSplittingStrategy (const MatrixSplit<splitint_t , fp_t > &A,
242239 const MatrixSplit<splitint_t , fp_t > &B) {
243- // When splitting with round-to-nearst , the first slice has bitsPerSlice - 1 bits, and we need
244- // to account for this when scaling the final result.
240+ // When splitting with round-to-nearest , the first slice has bitsPerSlice - 1 bits,
241+ // and we need to account for this when scaling the final result.
245242 fp_t scalingConstant = 1.0 ;
246243 scalingConstant *= A.splitType == splittingStrategy::roundToNearest ? 2.0 : 1.0 ;
247244 scalingConstant *= B.splitType == splittingStrategy::roundToNearest ? 2.0 : 1.0 ;
@@ -263,8 +260,9 @@ std::vector<fp_t> computeProductsWithFloatingPointAccumulation(const MatrixSplit
263260
264261 std::vector<fp_t > C (A.m * B.n );
265262
266- auto scalingConstant = computeScalingConstantforUsingSplitStrategy (A, B);
263+ auto scalingConstant = computeScalingConstantforUsingSplittingStrategy (A, B);
267264
265+ // Products below the main anti-diagonal are ignored.
268266 size_t numDiagonals = std::max (A.numSplits , B.numSplits ) - 1 ;
269267 for (size_t diagonal = 0 ; diagonal <= numDiagonals; diagonal++) {
270268 int Aindex = diagonal < A.numSplits - 1 ? diagonal : A.numSplits - 1 ;
@@ -303,12 +301,9 @@ std::vector<fp_t> computeProductsWithIntegerAccumulation(const MatrixSplit<split
303301
304302 std::vector<fp_t > C (A.m * B.n );
305303
306- auto scalingConstant = computeScalingConstantforUsingSplitStrategy (A, B);
304+ auto scalingConstant = computeScalingConstantforUsingSplittingStrategy (A, B);
307305
308- // Here, I'm ignoring the products below the main anti-diagonal, as done in the original
309- // paper.
310- // NOTE: this is different from previous work, as I allow a different number of splits
311- // for A and B.
306+ // Products below the main anti-diagonal are ignored.
312307 size_t numDiagonals = std::max (A.numSplits , B.numSplits ) - 1 ;
313308 for (size_t diagonal = 0 ; diagonal <= numDiagonals; diagonal++) {
314309 int Aindex = diagonal < A.numSplits - 1 ? diagonal : A.numSplits - 1 ;
@@ -349,11 +344,9 @@ std::vector<fp_t> gemmi (const std::vector<fp_t> &A, const std::vector<fp_t> &B,
349344 const size_t alpha = std::floor ((bitsInAccumulator - log2 (n)) / 2 );
350345 const size_t bitsPerSlice = std::min (bitsPerInteger, static_cast <size_t >(alpha));
351346
352- // TODO: The user should be able to select what splitting strategy to use.
353347 auto splitA = MatrixSplit<splitint_t , fp_t >(A, m, p, splitType, numSplitsA, bitsPerSlice, normalisationDimension::byRows);
354348 auto splitB = MatrixSplit<splitint_t , fp_t >(B, p, n, splitType, numSplitsB, bitsPerSlice, normalisationDimension::byCols);
355349
356- // TODO: The user should be able to select what accumulation strategy to use.
357350 switch (accType) {
358351 case accumulationStrategy::floatingPoint:
359352 return computeProductsWithFloatingPointAccumulation<splitint_t , accumulator_t , fp_t >(splitA, splitB, bitsPerSlice);
0 commit comments