Skip to content

Commit d47a676

Browse files
committed
Make notation consistent with manuscript
1 parent cda65ec commit d47a676

File tree

2 files changed

+15
-15
lines changed

2 files changed

+15
-15
lines changed

include/gemmi.hpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -212,9 +212,9 @@ void computeExactIntegerGEMM(const MatrixSplit<splitint_t, fp_t> &A,
212212
size_t iBlock, size_t jBlock) {
213213
for (size_t i = 0; i < A.m; i++) {
214214
for (size_t j = 0; j < B.n; j++) {
215-
for (size_t k = 0; k < A.n; k++) {
216-
C[i + j * A.m] += A.memory[i + k * A.m + iBlock * A.m * A.n] *
217-
B.memory[k + j * B.m + jBlock * B.m * B.n];
215+
for (size_t ell = 0; ell < A.n; ell++) {
216+
C[i + j * A.m] += A.memory[i + ell * A.m + iBlock * A.m * A.n] *
217+
B.memory[ell + j * B.m + jBlock * B.m * B.n];
218218
}
219219
}
220220
}
@@ -312,13 +312,13 @@ std::vector<fp_t> computeProductsWithIntegerAccumulation(const MatrixSplit<split
312312
}
313313

314314
/* Compute matrix vector product C += A * B, where:
315-
* + A is m x p
316-
* + B is p x n
315+
* + A is m x k
316+
* + B is k x n
317317
* + C is m x n
318318
*/
319319
template <typename fp_t, typename splitint_t, typename accumulator_t>
320320
std::vector<fp_t> gemmi (const std::vector<fp_t> &A, const std::vector<fp_t> &B,
321-
const size_t m, const size_t p, const size_t n,
321+
const size_t m, const size_t k, const size_t n,
322322
const size_t numSplitsA, const size_t numSplitsB,
323323
const splittingStrategy splitType = splittingStrategy::roundToNearest,
324324
const multiplicationStrategy multType = multiplicationStrategy::reduced,
@@ -330,8 +330,8 @@ std::vector<fp_t> gemmi (const std::vector<fp_t> &A, const std::vector<fp_t> &B,
330330
const size_t alpha = std::floor((bitsInAccumulator - log2(n)) / 2);
331331
const size_t bitsPerSlice = std::min(bitsPerInteger, static_cast<size_t>(alpha));
332332

333-
auto splitA = MatrixSplit<splitint_t, fp_t>(A, m, p, splitType, numSplitsA, bitsPerSlice, normalisationDimension::byRows);
334-
auto splitB = MatrixSplit<splitint_t, fp_t>(B, p, n, splitType, numSplitsB, bitsPerSlice, normalisationDimension::byCols);
333+
auto splitA = MatrixSplit<splitint_t, fp_t>(A, m, k, splitType, numSplitsA, bitsPerSlice, normalisationDimension::byRows);
334+
auto splitB = MatrixSplit<splitint_t, fp_t>(B, k, n, splitType, numSplitsB, bitsPerSlice, normalisationDimension::byCols);
335335

336336
size_t numDiagonals;
337337
switch (multType) {
@@ -361,6 +361,6 @@ std::vector<fp_t> gemmi (const std::vector<fp_t> &A, const std::vector<fp_t> &B,
361361

362362
template <typename fp_t, typename splitint_t, typename accumulator_t>
363363
std::vector<fp_t> gemmi (const std::vector<fp_t> &A, const std::vector<fp_t> &B,
364-
const size_t m, const size_t p, const size_t n, const size_t numSplits) {
365-
return gemmi <fp_t, splitint_t, accumulator_t> (A, B, m, p, n, numSplits, numSplits);
364+
const size_t m, const size_t k, const size_t n, const size_t numSplits) {
365+
return gemmi <fp_t, splitint_t, accumulator_t> (A, B, m, k, n, numSplits, numSplits);
366366
}

tests/tests.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,10 @@ void runTest() {
1919
for (size_t numSplitA : { 1, 2, 10 }) {
2020
for (size_t numSplitB : { 1, 2, 10 }) {
2121
for (size_t m = 10; m <= 50; m += 10) {
22-
for (size_t p = 10; p <= 50; p += 10) {
22+
for (size_t k = 10; k <= 50; k += 10) {
2323
for (size_t n = 10; n <= 50; n += 10) {
24-
std::vector<fp_t> A(m * p);
25-
std::vector<fp_t> B(p * n);
24+
std::vector<fp_t> A(m * k);
25+
std::vector<fp_t> B(k * n);
2626

2727
// Initalize matrix with random values.
2828
std::default_random_engine generator(std::random_device{}());
@@ -32,8 +32,8 @@ void runTest() {
3232
for (auto & element : B)
3333
element = numSplitB < 10 ? ldexp(1.0, 2 * numSplitB) - 1 : distribution(generator);
3434

35-
auto C = gemmi<fp_t, int8_t, int32_t>(A, B, m, p, n, numSplitA, numSplitB, splitType, multiplicationType, accumulationType);
36-
auto C_ref = reference_gemm(A, B, m, p, n);
35+
auto C = gemmi<fp_t, int8_t, int32_t>(A, B, m, k, n, numSplitA, numSplitB, splitType, multiplicationType, accumulationType);
36+
auto C_ref = reference_gemm(A, B, m, k, n);
3737

3838
double relative_error = frobenius_norm<fp_t, double>(C - C_ref) / frobenius_norm<fp_t, double>(C);
3939

0 commit comments

Comments
 (0)