@@ -24,12 +24,20 @@ enum class normalisationDimension {
2424 byCols
2525};
2626
27+ /* Splitting format. */
28+ enum class splitStrategy {
29+ undef,
30+ roundToNearest,
31+ bitMasking
32+ };
33+
2734template <typename splitint_t , typename fp_t >
2835struct MatrixSplit {
2936 size_t m;
3037 size_t n;
3138 size_t numSplits;
3239 normalisationDimension dimension;
40+ splitStrategy splitType;
3341
3442 std::vector<fp_t > matrix;
3543 std::vector<splitint_t > memory;
@@ -45,13 +53,14 @@ struct MatrixSplit {
4553 std::vector<fp_t >& powersVector,
4654 std::vector<int >& scalingExponents) :
4755 m (m), n(n), numSplits(numSplits), dimension(dimension),
56+ splitType (splitStrategy::undef),
4857 matrix (matrix), memory(memory),
4958 powersVector (powersVector),
5059 scalingExponents (scalingExponents) {}
5160
5261 MatrixSplit (const size_t m, const size_t n, const size_t numSplits,
5362 const normalisationDimension dimension, const std::vector<fp_t >& matrix) :
54- m (m), n(n), numSplits(numSplits), dimension(dimension),
63+ m (m), n(n), numSplits(numSplits), dimension(dimension), splitType(splitStrategy::undef),
5564 matrix (matrix) {
5665 this ->memory .resize (m * n * numSplits);
5766 this ->powersVector .resize (this ->otherDimension ());
@@ -102,7 +111,37 @@ struct MatrixSplit {
102111 }
103112 }
104113
105- void computeSplitsWithTruncation (const size_t bitsPerSlice) {
114+ /* Split the matrix using round-to-nearest. This is an implementation of
115+ * Algorithm 8 in
116+ *
117+ * Uchino Y., Ozaki K., Imamura T. Performance enanchcement of the Ozaki
118+ * scheme on integer matrix multiplication unit. arXiv:2409.13313 [cs.DC]. 2024.
119+ * DOI: 10.48550/arXiv.2409.13313
120+ *
121+ * Integer products are accumulated in integer arithmetic along the diagonal, and in
122+ * floating-point arithmetic across diagonals.
123+ */
124+ void computeSplitsWithRoundToNearest (const size_t bitsPerSlice) {
125+ this ->splitType = splitStrategy::roundToNearest;
126+ auto iStride = this ->iStride ();
127+ auto jStride = this ->jStride ();
128+ auto localMatrix = this ->matrix ;
129+ for (size_t slice = 0 ; slice < numSplits; slice++) {
130+ for (size_t i = 0 ; i < this ->otherDimension (); i++) {
131+ fp_t sigma = ldexp (0.75 , numFracBits<fp_t >() - bitsPerSlice * slice + 1 - bitsPerSlice) * powersVector[i];
132+ for (size_t j = 0 ; j < this ->innerProductDimension (); j++) {
133+ auto value = (localMatrix[i * iStride + j * jStride] + sigma);
134+ value -= sigma;
135+ localMatrix[i * iStride + j * jStride] -= value;
136+ value = value / powersVector[i] * ldexp (1.0 , bitsPerSlice * slice + bitsPerSlice - 1 );
137+ this ->memory [i * iStride + j * jStride + slice * this ->matrix .size ()] = value;
138+ }
139+ }
140+ }
141+ }
142+
143+ void computeSplitsWithBitMasking (const size_t bitsPerSlice) {
144+ this ->splitType = splitStrategy::bitMasking;
106145 // Compute splits one row/column at a time.
107146 auto nunExpBits = numExpBits<fp_t >();
108147 auto nunFracBits = numFracBits<fp_t >();
@@ -163,14 +202,14 @@ MatrixSplit<splitint_t, fp_t> splitFloatToInt(const std::vector<fp_t> A,
163202 const size_t bitsPerSlice) {
164203 auto splits = MatrixSplit<splitint_t , fp_t >(m, n, numSplits, dimension, A);
165204 splits.computeNormalisationVectors ();
166- splits.computeSplitsWithTruncation (bitsPerSlice);
167- // splits.computeSplitsWithTruncation (bitsPerSlice);
205+ splits.computeSplitsWithRoundToNearest (bitsPerSlice);
206+ // splits.computeSplitsWithBitMasking (bitsPerSlice);
168207
169208 return splits;
170209}
171210
172211template <typename splitint_t , typename accumulator_t , typename fp_t >
173- std::vector<fp_t > mergeFloatfromInt (const MatrixSplit<splitint_t , fp_t > &A,
212+ std::vector<fp_t > mergeIntToFloats (const MatrixSplit<splitint_t , fp_t > &A,
174213 const size_t bitsPerSlice) {
175214 std::vector<fp_t > C (A.m * A.n , 0.0 );
176215
@@ -209,6 +248,18 @@ void computeExactIntegerGEMM(const MatrixSplit<splitint_t, fp_t> &A,
209248 }
210249}
211250
251+ /* Compute scaling constant for using the split strategy. */
252+ template <typename splitint_t , typename fp_t >
253+ fp_t computeScalingConstantforUsingSplitStrategy (const MatrixSplit<splitint_t , fp_t > &A,
254+ const MatrixSplit<splitint_t , fp_t > &B) {
255+ // When splitting with round-to-nearst, the first slice has bitsPerSlice - 1 bits, and we need
256+ // to account for this when scaling the final result.
257+ fp_t scalingConstant = 1.0 ;
258+ scalingConstant *= A.splitType == splitStrategy::roundToNearest ? 2.0 : 1.0 ;
259+ scalingConstant *= B.splitType == splitStrategy::roundToNearest ? 2.0 : 1.0 ;
260+ return scalingConstant;
261+ }
262+
212263/* Accumulate products using the technique in:
213264 *
214265 * Ootomo H., Ozaki K., Yokota R. DGEMM on integer matrix multiplication
@@ -224,6 +275,8 @@ std::vector<fp_t> computeProductsWithFloatingPointAccumulation(const MatrixSplit
224275
225276 std::vector<fp_t > C (A.m * B.n );
226277
278+ auto scalingConstant = computeScalingConstantforUsingSplitStrategy (A, B);
279+
227280 size_t numDiagonals = std::max (A.numSplits , B.numSplits ) - 1 ;
228281 for (size_t diagonal = 0 ; diagonal <= numDiagonals; diagonal++) {
229282 int Aindex = diagonal < A.numSplits - 1 ? diagonal : A.numSplits - 1 ;
@@ -234,7 +287,7 @@ std::vector<fp_t> computeProductsWithFloatingPointAccumulation(const MatrixSplit
234287 for (size_t i = 0 ; i < A.m ; i++) {
235288 for (size_t j = 0 ; j < B.n ; j++) {
236289 fp_t scaledSum = std::ldexp (accumulator[i + j * A.m ], -(Aindex + 1 + Bindex + 1 ) * bitsPerSlice);
237- fp_t scalingFactor = A.powersVector [i] * B.powersVector [j];
290+ fp_t scalingFactor = A.powersVector [i] * B.powersVector [j] * scalingConstant ;
238291 C[i + j * A.m ] += scaledSum * scalingFactor;
239292 }
240293 }
@@ -262,6 +315,8 @@ std::vector<fp_t> computeProductsWithIntegerAccumulation(const MatrixSplit<split
262315
263316 std::vector<fp_t > C (A.m * B.n );
264317
318+ auto scalingConstant = computeScalingConstantforUsingSplitStrategy (A, B);
319+
265320 // Here, I'm ignoring the products below the main anti-diagonal, as done in the original
266321 // paper.
267322 // NOTE: this is different from previous work, as I allow a different number of splits
@@ -279,7 +334,7 @@ std::vector<fp_t> computeProductsWithIntegerAccumulation(const MatrixSplit<split
279334 for (size_t i = 0 ; i < A.m ; i++) {
280335 for (size_t j = 0 ; j < B.n ; j++) {
281336 fp_t scaledSum = std::ldexp (accumulator[i + j * A.m ], -(diagonal + 2 ) * bitsPerSlice);
282- fp_t scalingFactor = A.powersVector [i] * B.powersVector [j];
337+ fp_t scalingFactor = A.powersVector [i] * B.powersVector [j] * scalingConstant ;
283338 C[i + j * A.m ] += scaledSum * scalingFactor;
284339 }
285340 }
@@ -304,17 +359,19 @@ std::vector<fp_t> gemmi (const std::vector<fp_t> &A, const std::vector<fp_t> &B,
304359 const size_t alpha = std::floor ((bitsInAccumulator - log2 (n)) / 2 );
305360 const size_t bitsPerSlice = std::min (bitsPerInteger, static_cast <size_t >(alpha));
306361
362+ // TODO: The user should be able to select what splitting strategy to use.
307363 auto splitA = splitFloatToInt<splitint_t , fp_t >
308364 (A, m, p, normalisationDimension::byRows, numSplitsA, bitsPerSlice);
309365
310366 auto splitB = splitFloatToInt<splitint_t , fp_t >
311367 (B, p, n, normalisationDimension::byCols, numSplitsB, bitsPerSlice);
312368
313- return computeProductsWithFloatingPointAccumulation<splitint_t , accumulator_t , fp_t >(splitA, splitB, bitsPerSlice);
314- // return computeProductsWithIntegerAccumulation<splitint_t, accumulator_t, fp_t>(splitA, splitB, bitsPerSlice);
369+ // TODO: The user should be able to select what accumulation strategy to use.
370+ // return computeProductsWithFloatingPointAccumulation<splitint_t, accumulator_t, fp_t>(splitA, splitB, bitsPerSlice);
371+ return computeProductsWithIntegerAccumulation<splitint_t , accumulator_t , fp_t >(splitA, splitB, bitsPerSlice);
315372}
316373template <typename fp_t , typename splitint_t , typename accumulator_t >
317374std::vector<fp_t > gemmi (const std::vector<fp_t > &A, const std::vector<fp_t > &B,
318375 const size_t m, const size_t p, const size_t n, const size_t numSplits) {
319376 return gemmi <fp_t , splitint_t , accumulator_t > (A, B, m, p, n, numSplits, numSplits);
320- }
377+ }
0 commit comments