Skip to content

Commit 19e26c7

Browse files
committed
Add splitting strategy to nearest
1 parent 06df51c commit 19e26c7

File tree

1 file changed

+67
-10
lines changed

1 file changed

+67
-10
lines changed

include/gemmi.hpp

Lines changed: 67 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,20 @@ enum class normalisationDimension {
2424
byCols
2525
};
2626

27+
/* Splitting format. */
28+
enum class splitStrategy{
29+
undef,
30+
roundToNearest,
31+
bitMasking
32+
};
33+
2734
template <typename splitint_t, typename fp_t>
2835
struct MatrixSplit {
2936
size_t m;
3037
size_t n;
3138
size_t numSplits;
3239
normalisationDimension dimension;
40+
splitStrategy splitType;
3341

3442
std::vector<fp_t> matrix;
3543
std::vector<splitint_t> memory;
@@ -45,13 +53,14 @@ struct MatrixSplit {
4553
std::vector<fp_t>& powersVector,
4654
std::vector<int>& scalingExponents) :
4755
m(m), n(n), numSplits(numSplits), dimension(dimension),
56+
splitType(splitStrategy::undef),
4857
matrix(matrix), memory(memory),
4958
powersVector(powersVector),
5059
scalingExponents(scalingExponents) {}
5160

5261
MatrixSplit(const size_t m, const size_t n, const size_t numSplits,
5362
const normalisationDimension dimension, const std::vector<fp_t>& matrix) :
54-
m(m), n(n), numSplits(numSplits), dimension(dimension),
63+
m(m), n(n), numSplits(numSplits), dimension(dimension), splitType(splitStrategy::undef),
5564
matrix(matrix) {
5665
this->memory.resize(m * n * numSplits);
5766
this->powersVector.resize(this->otherDimension());
@@ -102,7 +111,37 @@ struct MatrixSplit {
102111
}
103112
}
104113

105-
void computeSplitsWithTruncation(const size_t bitsPerSlice) {
114+
/* Split the matrix using round-to-nearest. This is an implementation of
115+
* Algorithm 8 in
116+
*
117+
* Uchino Y., Ozaki K., Imamura T. Performance enanchcement of the Ozaki
118+
* scheme on integer matrix multiplication unit. arXiv:2409.13313 [cs.DC]. 2024.
119+
* DOI: 10.48550/arXiv.2409.13313
120+
*
121+
* Integer products are accumulated in integer arithmetic along the diagonal, and in
122+
* floating-point arithmetic across diagonals.
123+
*/
124+
void computeSplitsWithRoundToNearest(const size_t bitsPerSlice) {
125+
this->splitType = splitStrategy::roundToNearest;
126+
auto iStride = this->iStride();
127+
auto jStride = this->jStride();
128+
auto localMatrix = this->matrix;
129+
for (size_t slice = 0; slice < numSplits; slice++) {
130+
for (size_t i = 0; i < this->otherDimension(); i++) {
131+
fp_t sigma = ldexp(0.75, numFracBits<fp_t>() - bitsPerSlice * slice + 1 - bitsPerSlice) * powersVector[i];
132+
for (size_t j = 0; j < this->innerProductDimension(); j++) {
133+
auto value = (localMatrix[i * iStride + j * jStride] + sigma);
134+
value -= sigma;
135+
localMatrix[i * iStride + j * jStride] -= value;
136+
value = value / powersVector[i] * ldexp(1.0, bitsPerSlice * slice + bitsPerSlice - 1);
137+
this->memory[i * iStride + j * jStride + slice * this->matrix.size()] = value;
138+
}
139+
}
140+
}
141+
}
142+
143+
void computeSplitsWithBitMasking(const size_t bitsPerSlice) {
144+
this->splitType = splitStrategy::bitMasking;
106145
// Compute splits one row/column at a time.
107146
auto nunExpBits = numExpBits<fp_t>();
108147
auto nunFracBits = numFracBits<fp_t>();
@@ -163,14 +202,14 @@ MatrixSplit<splitint_t, fp_t> splitFloatToInt(const std::vector<fp_t> A,
163202
const size_t bitsPerSlice) {
164203
auto splits = MatrixSplit<splitint_t, fp_t>(m, n, numSplits, dimension, A);
165204
splits.computeNormalisationVectors();
166-
splits.computeSplitsWithTruncation(bitsPerSlice);
167-
//splits.computeSplitsWithTruncation(bitsPerSlice);
205+
splits.computeSplitsWithRoundToNearest(bitsPerSlice);
206+
// splits.computeSplitsWithBitMasking(bitsPerSlice);
168207

169208
return splits;
170209
}
171210

172211
template <typename splitint_t, typename accumulator_t, typename fp_t>
173-
std::vector<fp_t> mergeFloatfromInt(const MatrixSplit<splitint_t, fp_t> &A,
212+
std::vector<fp_t> mergeIntToFloats(const MatrixSplit<splitint_t, fp_t> &A,
174213
const size_t bitsPerSlice) {
175214
std::vector<fp_t> C (A.m * A.n, 0.0);
176215

@@ -209,6 +248,18 @@ void computeExactIntegerGEMM(const MatrixSplit<splitint_t, fp_t> &A,
209248
}
210249
}
211250

251+
/* Compute scaling constant for using the split strategy. */
252+
template <typename splitint_t, typename fp_t>
253+
fp_t computeScalingConstantforUsingSplitStrategy(const MatrixSplit<splitint_t, fp_t> &A,
254+
const MatrixSplit<splitint_t, fp_t> &B) {
255+
// When splitting with round-to-nearst, the first slice has bitsPerSlice - 1 bits, and we need
256+
// to account for this when scaling the final result.
257+
fp_t scalingConstant = 1.0;
258+
scalingConstant *= A.splitType == splitStrategy::roundToNearest ? 2.0 : 1.0;
259+
scalingConstant *= B.splitType == splitStrategy::roundToNearest ? 2.0 : 1.0;
260+
return scalingConstant;
261+
}
262+
212263
/* Accumulate products using the technique in:
213264
*
214265
* Ootomo H., Ozaki K., Yokota R. DGEMM on integer matrix multiplication
@@ -224,6 +275,8 @@ std::vector<fp_t> computeProductsWithFloatingPointAccumulation(const MatrixSplit
224275

225276
std::vector<fp_t > C (A.m * B.n);
226277

278+
auto scalingConstant = computeScalingConstantforUsingSplitStrategy(A, B);
279+
227280
size_t numDiagonals = std::max(A.numSplits, B.numSplits) - 1;
228281
for (size_t diagonal = 0; diagonal <= numDiagonals; diagonal++) {
229282
int Aindex = diagonal < A.numSplits - 1 ? diagonal : A.numSplits - 1;
@@ -234,7 +287,7 @@ std::vector<fp_t> computeProductsWithFloatingPointAccumulation(const MatrixSplit
234287
for (size_t i = 0; i < A.m; i++) {
235288
for (size_t j = 0; j < B.n; j++) {
236289
fp_t scaledSum = std::ldexp(accumulator[i + j * A.m], -(Aindex + 1 + Bindex + 1) * bitsPerSlice);
237-
fp_t scalingFactor = A.powersVector[i] * B.powersVector[j];
290+
fp_t scalingFactor = A.powersVector[i] * B.powersVector[j] * scalingConstant;
238291
C[i + j * A.m] += scaledSum * scalingFactor;
239292
}
240293
}
@@ -262,6 +315,8 @@ std::vector<fp_t> computeProductsWithIntegerAccumulation(const MatrixSplit<split
262315

263316
std::vector<fp_t > C (A.m * B.n);
264317

318+
auto scalingConstant = computeScalingConstantforUsingSplitStrategy(A, B);
319+
265320
// Here, I'm ignoring the products below the main anti-diagonal, as done in the original
266321
// paper.
267322
// NOTE: this is different from previous work, as I allow a different number of splits
@@ -279,7 +334,7 @@ std::vector<fp_t> computeProductsWithIntegerAccumulation(const MatrixSplit<split
279334
for (size_t i = 0; i < A.m; i++) {
280335
for (size_t j = 0; j < B.n; j++) {
281336
fp_t scaledSum = std::ldexp(accumulator[i + j * A.m], -(diagonal + 2) * bitsPerSlice);
282-
fp_t scalingFactor = A.powersVector[i] * B.powersVector[j];
337+
fp_t scalingFactor = A.powersVector[i] * B.powersVector[j] * scalingConstant;
283338
C[i + j * A.m] += scaledSum * scalingFactor;
284339
}
285340
}
@@ -304,17 +359,19 @@ std::vector<fp_t> gemmi (const std::vector<fp_t> &A, const std::vector<fp_t> &B,
304359
const size_t alpha = std::floor((bitsInAccumulator - log2(n)) / 2);
305360
const size_t bitsPerSlice = std::min(bitsPerInteger, static_cast<size_t>(alpha));
306361

362+
// TODO: The user should be able to select what splitting strategy to use.
307363
auto splitA = splitFloatToInt<splitint_t, fp_t>
308364
(A, m, p, normalisationDimension::byRows, numSplitsA, bitsPerSlice);
309365

310366
auto splitB = splitFloatToInt<splitint_t, fp_t>
311367
(B, p, n, normalisationDimension::byCols, numSplitsB, bitsPerSlice);
312368

313-
return computeProductsWithFloatingPointAccumulation<splitint_t, accumulator_t, fp_t>(splitA, splitB, bitsPerSlice);
314-
// return computeProductsWithIntegerAccumulation<splitint_t, accumulator_t, fp_t>(splitA, splitB, bitsPerSlice);
369+
// TODO: The user should be able to select what accumulation strategy to use.
370+
// return computeProductsWithFloatingPointAccumulation<splitint_t, accumulator_t, fp_t>(splitA, splitB, bitsPerSlice);
371+
return computeProductsWithIntegerAccumulation<splitint_t, accumulator_t, fp_t>(splitA, splitB, bitsPerSlice);
315372
}
316373
template <typename fp_t, typename splitint_t, typename accumulator_t>
317374
std::vector<fp_t> gemmi (const std::vector<fp_t> &A, const std::vector<fp_t> &B,
318375
const size_t m, const size_t p, const size_t n, const size_t numSplits) {
319376
return gemmi <fp_t, splitint_t, accumulator_t> (A, B, m, p, n, numSplits, numSplits);
320-
}
377+
}

0 commit comments

Comments
 (0)