@@ -20,24 +20,28 @@ template <> struct get_storage_format<double> {using storage_format = uint64_t;}
2020
2121/* Everything is defined to use column-major. */
2222enum class normalisationDimension {
23- byRows,
24- byCols
23+ byRows, // Matrix on the left of the prodcut.
24+ byCols // Matrix on the right of the product.
2525};
2626
27- /* Splitting format. */
28- enum class splitStrategy {
29- undef,
27+ enum class splittingStrategy {
3028 roundToNearest,
3129 bitMasking
3230};
3331
32+ enum class accumulationStrategy {
33+ floatingPoint,
34+ integer
35+ };
36+
3437template <typename splitint_t , typename fp_t >
3538struct MatrixSplit {
3639 size_t m;
3740 size_t n;
41+ splittingStrategy splitType;
3842 size_t numSplits;
43+ size_t bitsPerSlice;
3944 normalisationDimension dimension;
40- splitStrategy splitType;
4145
4246 std::vector<fp_t > matrix;
4347 std::vector<splitint_t > memory;
@@ -46,25 +50,23 @@ struct MatrixSplit {
4650
4751 using uint_t = typename get_storage_format<fp_t >::storage_format;
4852
49- MatrixSplit (const size_t m, const size_t n, const size_t numSplits,
50- const normalisationDimension dimension,
51- const std::vector<fp_t >& matrix,
52- std::vector<splitint_t >& memory,
53- std::vector<fp_t >& powersVector,
54- std::vector<int >& scalingExponents) :
55- m (m), n(n), numSplits(numSplits), dimension(dimension),
56- splitType (splitStrategy::undef),
57- matrix (matrix), memory(memory),
58- powersVector (powersVector),
59- scalingExponents (scalingExponents) {}
60-
61- MatrixSplit (const size_t m, const size_t n, const size_t numSplits,
62- const normalisationDimension dimension, const std::vector<fp_t >& matrix) :
63- m (m), n(n), numSplits(numSplits), dimension(dimension), splitType(splitStrategy::undef),
64- matrix (matrix) {
53+ MatrixSplit (const std::vector<fp_t >& matrix, const size_t m, const size_t n,
54+ const splittingStrategy splitType, size_t numSplits, size_t bitsPerSlice,
55+ const normalisationDimension dimension) :
56+ m (m), n(n), splitType(splitType), numSplits(numSplits), bitsPerSlice(bitsPerSlice),
57+ dimension (dimension), matrix(matrix) {
6558 this ->memory .resize (m * n * numSplits);
6659 this ->powersVector .resize (this ->otherDimension ());
6760 this ->scalingExponents .resize (this ->otherDimension ());
61+ this ->computeNormalisationVectors ();
62+ switch (splitType) {
63+ case splittingStrategy::roundToNearest:
64+ this ->computeSplitsWithRoundToNearest ();
65+ break ;
66+ case splittingStrategy::bitMasking:
67+ this ->computeSplitsWithBitMasking ();
68+ break ;
69+ }
6870 }
6971
7072 // This is the dimension alng which the inner product is computed.
@@ -121,27 +123,27 @@ struct MatrixSplit {
121123 * Integer products are accumulated in integer arithmetic along the diagonal, and in
122124 * floating-point arithmetic across diagonals.
123125 */
124- void computeSplitsWithRoundToNearest (const size_t bitsPerSlice ) {
125- this ->splitType = splitStrategy ::roundToNearest;
126+ void computeSplitsWithRoundToNearest () {
127+ this ->splitType = splittingStrategy ::roundToNearest;
126128 auto iStride = this ->iStride ();
127129 auto jStride = this ->jStride ();
128130 auto localMatrix = this ->matrix ;
129131 for (size_t slice = 0 ; slice < numSplits; slice++) {
130132 for (size_t i = 0 ; i < this ->otherDimension (); i++) {
131- fp_t sigma = ldexp (0.75 , numFracBits<fp_t >() - bitsPerSlice * slice + 1 - bitsPerSlice) * powersVector[i];
133+ fp_t sigma = ldexp (0.75 , numFracBits<fp_t >() - this -> bitsPerSlice * slice + 1 - this -> bitsPerSlice ) * powersVector[i];
132134 for (size_t j = 0 ; j < this ->innerProductDimension (); j++) {
133135 auto value = (localMatrix[i * iStride + j * jStride] + sigma);
134136 value -= sigma;
135137 localMatrix[i * iStride + j * jStride] -= value;
136- value = value / powersVector[i] * ldexp (1.0 , bitsPerSlice * slice + bitsPerSlice - 1 );
138+ value = value / powersVector[i] * ldexp (1.0 , this -> bitsPerSlice * slice + this -> bitsPerSlice - 1 );
137139 this ->memory [i * iStride + j * jStride + slice * this ->matrix .size ()] = value;
138140 }
139141 }
140142 }
141143 }
142144
143- void computeSplitsWithBitMasking (const size_t bitsPerSlice ) {
144- this ->splitType = splitStrategy ::bitMasking;
145+ void computeSplitsWithBitMasking () {
146+ this ->splitType = splittingStrategy ::bitMasking;
145147 // Compute splits one row/column at a time.
146148 auto nunExpBits = numExpBits<fp_t >();
147149 auto nunFracBits = numFracBits<fp_t >();
@@ -164,16 +166,16 @@ struct MatrixSplit {
164166 }
165167
166168 // Create bitmask.
167- const uint_t smallBitmask = (1 << bitsPerSlice) - 1 ;
169+ const uint_t smallBitmask = (1 << this -> bitsPerSlice ) - 1 ;
168170 // Perform the split.
169171 for (size_t j = 0 ; j < this ->innerProductDimension (); j++) {
170- int16_t shiftCounter = nunFracBits - bitsPerSlice;
172+ int16_t shiftCounter = nunFracBits - this -> bitsPerSlice ;
171173 int currentExponent;
172174 frexp (this ->matrix [i * iStride + j * jStride], ¤tExponent);
173175 int16_t exponentDifference = scalingExponents[i] - currentExponent;
174176 for (size_t slice = 0 ; slice < numSplits; slice++) {
175- if (exponentDifference > (signed )bitsPerSlice) {
176- exponentDifference -= bitsPerSlice;
177+ if (exponentDifference > (signed )this -> bitsPerSlice ) {
178+ exponentDifference -= this -> bitsPerSlice ;
177179 } else {
178180 shiftCounter += exponentDifference;
179181 exponentDifference = 0 ;
@@ -186,28 +188,14 @@ struct MatrixSplit {
186188 currentSlice << -shiftCounter;
187189 splitint_t value = (splitint_t )(current_split) * (sign[j] ? -1 : 1 );
188190 this ->memory [i * iStride + j * jStride + slice * this ->matrix .size ()] = value;
189- shiftCounter -= bitsPerSlice;
191+ shiftCounter -= this -> bitsPerSlice ;
190192 }
191193 }
192194 }
193195 }
194196 }
195197};
196198
197- template <typename splitint_t , typename fp_t >
198- MatrixSplit<splitint_t , fp_t > splitFloatToInt (const std::vector<fp_t > A,
199- const size_t m, const size_t n,
200- normalisationDimension dimension,
201- const size_t numSplits,
202- const size_t bitsPerSlice) {
203- auto splits = MatrixSplit<splitint_t , fp_t >(m, n, numSplits, dimension, A);
204- splits.computeNormalisationVectors ();
205- splits.computeSplitsWithRoundToNearest (bitsPerSlice);
206- // splits.computeSplitsWithBitMasking(bitsPerSlice);
207-
208- return splits;
209- }
210-
211199template <typename splitint_t , typename accumulator_t , typename fp_t >
212200std::vector<fp_t > mergeIntToFloats (const MatrixSplit<splitint_t , fp_t > &A,
213201 const size_t bitsPerSlice) {
@@ -255,8 +243,8 @@ fp_t computeScalingConstantforUsingSplitStrategy(const MatrixSplit<splitint_t, f
255243 // When splitting with round-to-nearst, the first slice has bitsPerSlice - 1 bits, and we need
256244 // to account for this when scaling the final result.
257245 fp_t scalingConstant = 1.0 ;
258- scalingConstant *= A.splitType == splitStrategy ::roundToNearest ? 2.0 : 1.0 ;
259- scalingConstant *= B.splitType == splitStrategy ::roundToNearest ? 2.0 : 1.0 ;
246+ scalingConstant *= A.splitType == splittingStrategy ::roundToNearest ? 2.0 : 1.0 ;
247+ scalingConstant *= B.splitType == splittingStrategy ::roundToNearest ? 2.0 : 1.0 ;
260248 return scalingConstant;
261249}
262250
@@ -351,7 +339,9 @@ std::vector<fp_t> computeProductsWithIntegerAccumulation(const MatrixSplit<split
351339template <typename fp_t , typename splitint_t , typename accumulator_t >
352340std::vector<fp_t > gemmi (const std::vector<fp_t > &A, const std::vector<fp_t > &B,
353341 const size_t m, const size_t p, const size_t n,
354- const size_t numSplitsA, const size_t numSplitsB) {
342+ const size_t numSplitsA, const size_t numSplitsB,
343+ const splittingStrategy splitType = splittingStrategy::roundToNearest,
344+ const accumulationStrategy accType = accumulationStrategy::floatingPoint) {
355345
356346 const size_t bitsInAccumulator = std::numeric_limits<accumulator_t >::digits;
357347 const size_t bitsPerInteger = std::numeric_limits<splitint_t >::digits;
@@ -360,15 +350,16 @@ std::vector<fp_t> gemmi (const std::vector<fp_t> &A, const std::vector<fp_t> &B,
360350 const size_t bitsPerSlice = std::min (bitsPerInteger, static_cast <size_t >(alpha));
361351
362352 // TODO: The user should be able to select what splitting strategy to use.
363- auto splitA = splitFloatToInt<splitint_t , fp_t >
364- (A, m, p, normalisationDimension::byRows, numSplitsA, bitsPerSlice);
365-
366- auto splitB = splitFloatToInt<splitint_t , fp_t >
367- (B, p, n, normalisationDimension::byCols, numSplitsB, bitsPerSlice);
353+ auto splitA = MatrixSplit<splitint_t , fp_t >(A, m, p, splitType, numSplitsA, bitsPerSlice, normalisationDimension::byRows);
354+ auto splitB = MatrixSplit<splitint_t , fp_t >(B, p, n, splitType, numSplitsB, bitsPerSlice, normalisationDimension::byCols);
368355
369356 // TODO: The user should be able to select what accumulation strategy to use.
370- // return computeProductsWithFloatingPointAccumulation<splitint_t, accumulator_t, fp_t>(splitA, splitB, bitsPerSlice);
371- return computeProductsWithIntegerAccumulation<splitint_t , accumulator_t , fp_t >(splitA, splitB, bitsPerSlice);
357+ switch (accType) {
358+ case accumulationStrategy::floatingPoint:
359+ return computeProductsWithFloatingPointAccumulation<splitint_t , accumulator_t , fp_t >(splitA, splitB, bitsPerSlice);
360+ case accumulationStrategy::integer:
361+ return computeProductsWithIntegerAccumulation<splitint_t , accumulator_t , fp_t >(splitA, splitB, bitsPerSlice);
362+ }
372363}
373364template <typename fp_t , typename splitint_t , typename accumulator_t >
374365std::vector<fp_t > gemmi (const std::vector<fp_t > &A, const std::vector<fp_t > &B,
0 commit comments