Skip to content

Commit 20fb53b

Browse files
committed
Expose splitting and accumulation strategies in C++ interface
1 parent 2a19849 commit 20fb53b

File tree

1 file changed

+48
-57
lines changed

1 file changed

+48
-57
lines changed

include/gemmi.hpp

Lines changed: 48 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -20,24 +20,28 @@ template <> struct get_storage_format<double> {using storage_format = uint64_t;}
2020

2121
/* Everything is defined to use column-major. */
2222
enum class normalisationDimension {
23-
byRows,
24-
byCols
23+
byRows, // Matrix on the left of the product.
24+
byCols // Matrix on the right of the product.
2525
};
2626

27-
/* Splitting format. */
28-
enum class splitStrategy{
29-
undef,
27+
enum class splittingStrategy{
3028
roundToNearest,
3129
bitMasking
3230
};
3331

32+
enum class accumulationStrategy{
33+
floatingPoint,
34+
integer
35+
};
36+
3437
template <typename splitint_t, typename fp_t>
3538
struct MatrixSplit {
3639
size_t m;
3740
size_t n;
41+
splittingStrategy splitType;
3842
size_t numSplits;
43+
size_t bitsPerSlice;
3944
normalisationDimension dimension;
40-
splitStrategy splitType;
4145

4246
std::vector<fp_t> matrix;
4347
std::vector<splitint_t> memory;
@@ -46,25 +50,23 @@ struct MatrixSplit {
4650

4751
using uint_t = typename get_storage_format<fp_t>::storage_format;
4852

49-
MatrixSplit(const size_t m, const size_t n, const size_t numSplits,
50-
const normalisationDimension dimension,
51-
const std::vector<fp_t>& matrix,
52-
std::vector<splitint_t>& memory,
53-
std::vector<fp_t>& powersVector,
54-
std::vector<int>& scalingExponents) :
55-
m(m), n(n), numSplits(numSplits), dimension(dimension),
56-
splitType(splitStrategy::undef),
57-
matrix(matrix), memory(memory),
58-
powersVector(powersVector),
59-
scalingExponents(scalingExponents) {}
60-
61-
MatrixSplit(const size_t m, const size_t n, const size_t numSplits,
62-
const normalisationDimension dimension, const std::vector<fp_t>& matrix) :
63-
m(m), n(n), numSplits(numSplits), dimension(dimension), splitType(splitStrategy::undef),
64-
matrix(matrix) {
53+
MatrixSplit(const std::vector<fp_t>& matrix, const size_t m, const size_t n,
54+
const splittingStrategy splitType, size_t numSplits, size_t bitsPerSlice,
55+
const normalisationDimension dimension) :
56+
m(m), n(n), splitType(splitType), numSplits(numSplits), bitsPerSlice(bitsPerSlice),
57+
dimension(dimension), matrix(matrix) {
6558
this->memory.resize(m * n * numSplits);
6659
this->powersVector.resize(this->otherDimension());
6760
this->scalingExponents.resize(this->otherDimension());
61+
this->computeNormalisationVectors();
62+
switch (splitType) {
63+
case splittingStrategy::roundToNearest:
64+
this->computeSplitsWithRoundToNearest();
65+
break;
66+
case splittingStrategy::bitMasking:
67+
this->computeSplitsWithBitMasking();
68+
break;
69+
}
6870
}
6971

7072
// This is the dimension along which the inner product is computed.
@@ -121,27 +123,27 @@ struct MatrixSplit {
121123
* Integer products are accumulated in integer arithmetic along the diagonal, and in
122124
* floating-point arithmetic across diagonals.
123125
*/
124-
void computeSplitsWithRoundToNearest(const size_t bitsPerSlice) {
125-
this->splitType = splitStrategy::roundToNearest;
126+
void computeSplitsWithRoundToNearest() {
127+
this->splitType = splittingStrategy::roundToNearest;
126128
auto iStride = this->iStride();
127129
auto jStride = this->jStride();
128130
auto localMatrix = this->matrix;
129131
for (size_t slice = 0; slice < numSplits; slice++) {
130132
for (size_t i = 0; i < this->otherDimension(); i++) {
131-
fp_t sigma = ldexp(0.75, numFracBits<fp_t>() - bitsPerSlice * slice + 1 - bitsPerSlice) * powersVector[i];
133+
fp_t sigma = ldexp(0.75, numFracBits<fp_t>() - this->bitsPerSlice * slice + 1 - this->bitsPerSlice) * powersVector[i];
132134
for (size_t j = 0; j < this->innerProductDimension(); j++) {
133135
auto value = (localMatrix[i * iStride + j * jStride] + sigma);
134136
value -= sigma;
135137
localMatrix[i * iStride + j * jStride] -= value;
136-
value = value / powersVector[i] * ldexp(1.0, bitsPerSlice * slice + bitsPerSlice - 1);
138+
value = value / powersVector[i] * ldexp(1.0, this->bitsPerSlice * slice + this->bitsPerSlice - 1);
137139
this->memory[i * iStride + j * jStride + slice * this->matrix.size()] = value;
138140
}
139141
}
140142
}
141143
}
142144

143-
void computeSplitsWithBitMasking(const size_t bitsPerSlice) {
144-
this->splitType = splitStrategy::bitMasking;
145+
void computeSplitsWithBitMasking() {
146+
this->splitType = splittingStrategy::bitMasking;
145147
// Compute splits one row/column at a time.
146148
auto nunExpBits = numExpBits<fp_t>();
147149
auto nunFracBits = numFracBits<fp_t>();
@@ -164,16 +166,16 @@ struct MatrixSplit {
164166
}
165167

166168
// Create bitmask.
167-
const uint_t smallBitmask = (1 << bitsPerSlice) - 1;
169+
const uint_t smallBitmask = (1 << this->bitsPerSlice) - 1;
168170
// Perform the split.
169171
for (size_t j = 0; j < this->innerProductDimension(); j++) {
170-
int16_t shiftCounter = nunFracBits - bitsPerSlice;
172+
int16_t shiftCounter = nunFracBits - this->bitsPerSlice;
171173
int currentExponent;
172174
frexp(this->matrix[i * iStride + j * jStride], &currentExponent);
173175
int16_t exponentDifference = scalingExponents[i] - currentExponent;
174176
for (size_t slice = 0; slice < numSplits; slice++) {
175-
if (exponentDifference > (signed)bitsPerSlice) {
176-
exponentDifference -= bitsPerSlice;
177+
if (exponentDifference > (signed)this->bitsPerSlice) {
178+
exponentDifference -= this->bitsPerSlice;
177179
} else {
178180
shiftCounter += exponentDifference;
179181
exponentDifference = 0;
@@ -186,28 +188,14 @@ struct MatrixSplit {
186188
currentSlice << -shiftCounter;
187189
splitint_t value = (splitint_t)(current_split) * (sign[j] ? -1 : 1);
188190
this->memory[i * iStride + j * jStride + slice * this->matrix.size()] = value;
189-
shiftCounter -= bitsPerSlice;
191+
shiftCounter -= this->bitsPerSlice;
190192
}
191193
}
192194
}
193195
}
194196
}
195197
};
196198

197-
template <typename splitint_t, typename fp_t>
198-
MatrixSplit<splitint_t, fp_t> splitFloatToInt(const std::vector<fp_t> A,
199-
const size_t m, const size_t n,
200-
normalisationDimension dimension,
201-
const size_t numSplits,
202-
const size_t bitsPerSlice) {
203-
auto splits = MatrixSplit<splitint_t, fp_t>(m, n, numSplits, dimension, A);
204-
splits.computeNormalisationVectors();
205-
splits.computeSplitsWithRoundToNearest(bitsPerSlice);
206-
// splits.computeSplitsWithBitMasking(bitsPerSlice);
207-
208-
return splits;
209-
}
210-
211199
template <typename splitint_t, typename accumulator_t, typename fp_t>
212200
std::vector<fp_t> mergeIntToFloats(const MatrixSplit<splitint_t, fp_t> &A,
213201
const size_t bitsPerSlice) {
@@ -255,8 +243,8 @@ fp_t computeScalingConstantforUsingSplitStrategy(const MatrixSplit<splitint_t, f
255243
// When splitting with round-to-nearest, the first slice has bitsPerSlice - 1 bits, and we need
256244
// to account for this when scaling the final result.
257245
fp_t scalingConstant = 1.0;
258-
scalingConstant *= A.splitType == splitStrategy::roundToNearest ? 2.0 : 1.0;
259-
scalingConstant *= B.splitType == splitStrategy::roundToNearest ? 2.0 : 1.0;
246+
scalingConstant *= A.splitType == splittingStrategy::roundToNearest ? 2.0 : 1.0;
247+
scalingConstant *= B.splitType == splittingStrategy::roundToNearest ? 2.0 : 1.0;
260248
return scalingConstant;
261249
}
262250

@@ -351,7 +339,9 @@ std::vector<fp_t> computeProductsWithIntegerAccumulation(const MatrixSplit<split
351339
template <typename fp_t, typename splitint_t, typename accumulator_t>
352340
std::vector<fp_t> gemmi (const std::vector<fp_t> &A, const std::vector<fp_t> &B,
353341
const size_t m, const size_t p, const size_t n,
354-
const size_t numSplitsA, const size_t numSplitsB) {
342+
const size_t numSplitsA, const size_t numSplitsB,
343+
const splittingStrategy splitType = splittingStrategy::roundToNearest,
344+
const accumulationStrategy accType = accumulationStrategy::floatingPoint) {
355345

356346
const size_t bitsInAccumulator = std::numeric_limits<accumulator_t>::digits;
357347
const size_t bitsPerInteger = std::numeric_limits<splitint_t>::digits;
@@ -360,15 +350,16 @@ std::vector<fp_t> gemmi (const std::vector<fp_t> &A, const std::vector<fp_t> &B,
360350
const size_t bitsPerSlice = std::min(bitsPerInteger, static_cast<size_t>(alpha));
361351

362352
// TODO: The user should be able to select what splitting strategy to use.
363-
auto splitA = splitFloatToInt<splitint_t, fp_t>
364-
(A, m, p, normalisationDimension::byRows, numSplitsA, bitsPerSlice);
365-
366-
auto splitB = splitFloatToInt<splitint_t, fp_t>
367-
(B, p, n, normalisationDimension::byCols, numSplitsB, bitsPerSlice);
353+
auto splitA = MatrixSplit<splitint_t, fp_t>(A, m, p, splitType, numSplitsA, bitsPerSlice, normalisationDimension::byRows);
354+
auto splitB = MatrixSplit<splitint_t, fp_t>(B, p, n, splitType, numSplitsB, bitsPerSlice, normalisationDimension::byCols);
368355

369356
// TODO: The user should be able to select what accumulation strategy to use.
370-
// return computeProductsWithFloatingPointAccumulation<splitint_t, accumulator_t, fp_t>(splitA, splitB, bitsPerSlice);
371-
return computeProductsWithIntegerAccumulation<splitint_t, accumulator_t, fp_t>(splitA, splitB, bitsPerSlice);
357+
switch (accType) {
358+
case accumulationStrategy::floatingPoint:
359+
return computeProductsWithFloatingPointAccumulation<splitint_t, accumulator_t, fp_t>(splitA, splitB, bitsPerSlice);
360+
case accumulationStrategy::integer:
361+
return computeProductsWithIntegerAccumulation<splitint_t, accumulator_t, fp_t>(splitA, splitB, bitsPerSlice);
362+
}
372363
}
373364
template <typename fp_t, typename splitint_t, typename accumulator_t>
374365
std::vector<fp_t> gemmi (const std::vector<fp_t> &A, const std::vector<fp_t> &B,

0 commit comments

Comments
 (0)