@@ -212,9 +212,9 @@ void computeExactIntegerGEMM(const MatrixSplit<splitint_t, fp_t> &A,
212212 size_t iBlock, size_t jBlock) {
213213 for (size_t i = 0 ; i < A.m ; i++) {
214214 for (size_t j = 0 ; j < B.n ; j++) {
215- for (size_t k = 0 ; k < A.n ; k ++) {
216- C[i + j * A.m ] += A.memory [i + k * A.m + iBlock * A.m * A.n ] *
217- B.memory [k + j * B.m + jBlock * B.m * B.n ];
215+ for (size_t ell = 0 ; ell < A.n ; ell ++) {
216+ C[i + j * A.m ] += A.memory [i + ell * A.m + iBlock * A.m * A.n ] *
217+ B.memory [ell + j * B.m + jBlock * B.m * B.n ];
218218 }
219219 }
220220 }
@@ -312,13 +312,13 @@ std::vector<fp_t> computeProductsWithIntegerAccumulation(const MatrixSplit<split
312312}
313313
314314/* Compute matrix-matrix product C += A * B, where:
315- * + A is m x p
316- * + B is p x n
315+ * + A is m x k
316+ * + B is k x n
317317 * + C is m x n
318318 */
319319template <typename fp_t , typename splitint_t , typename accumulator_t >
320320std::vector<fp_t > gemmi (const std::vector<fp_t > &A, const std::vector<fp_t > &B,
321- const size_t m, const size_t p , const size_t n,
321+ const size_t m, const size_t k , const size_t n,
322322 const size_t numSplitsA, const size_t numSplitsB,
323323 const splittingStrategy splitType = splittingStrategy::roundToNearest,
324324 const multiplicationStrategy multType = multiplicationStrategy::reduced,
@@ -330,8 +330,8 @@ std::vector<fp_t> gemmi (const std::vector<fp_t> &A, const std::vector<fp_t> &B,
330330 const size_t alpha = std::floor ((bitsInAccumulator - log2 (n)) / 2 );
331331 const size_t bitsPerSlice = std::min (bitsPerInteger, static_cast <size_t >(alpha));
332332
333- auto splitA = MatrixSplit<splitint_t , fp_t >(A, m, p , splitType, numSplitsA, bitsPerSlice, normalisationDimension::byRows);
334- auto splitB = MatrixSplit<splitint_t , fp_t >(B, p , n, splitType, numSplitsB, bitsPerSlice, normalisationDimension::byCols);
333+ auto splitA = MatrixSplit<splitint_t , fp_t >(A, m, k , splitType, numSplitsA, bitsPerSlice, normalisationDimension::byRows);
334+ auto splitB = MatrixSplit<splitint_t , fp_t >(B, k , n, splitType, numSplitsB, bitsPerSlice, normalisationDimension::byCols);
335335
336336 size_t numDiagonals;
337337 switch (multType) {
@@ -361,6 +361,6 @@ std::vector<fp_t> gemmi (const std::vector<fp_t> &A, const std::vector<fp_t> &B,
361361
362362template <typename fp_t , typename splitint_t , typename accumulator_t >
363363std::vector<fp_t > gemmi (const std::vector<fp_t > &A, const std::vector<fp_t > &B,
364- const size_t m, const size_t p , const size_t n, const size_t numSplits) {
365- return gemmi <fp_t , splitint_t , accumulator_t > (A, B, m, p , n, numSplits, numSplits);
364+ const size_t m, const size_t k , const size_t n, const size_t numSplits) {
365+ return gemmi <fp_t , splitint_t , accumulator_t > (A, B, m, k , n, numSplits, numSplits);
366366}
0 commit comments