@@ -25,6 +25,7 @@
 #include "TritonAMDGPUTransforms/MfmaGroup.h"
 #include "Utility.h"
 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "llvm/ADT/TypeSwitch.h"
 
 using namespace mlir;
@@ -281,21 +282,8 @@ struct DotOpMFMAConversionHelper {
     auto aEncoding = cast<DotOperandEncodingAttr>(aTensorTy.getEncoding());
     auto bEncoding = cast<DotOperandEncodingAttr>(bTensorTy.getEncoding());
     int kWidth = aEncoding.getKWidth();
-    // If the kBase of the selected mfma instruction is larger than
-    // kWidth of the operand, it means the shape is large enough to
-    // use double rated mfma, but we (AccelerateAMDMatmul pass) choose
-    // to use single rated mfma.
-    if (kBase > kWidth) {
-      int kDimOperandSizeNew = 64 / mDim * kWidth;
-      maybeMfmaIntrinsic = MfmaIntrinsic::selectFor(
-          mfmaVersion, mDim, nDim, kDimOperandSizeNew, elemTyA, elemTyB,
-          /*withScale=*/false, allowXF32);
-      if (failed(maybeMfmaIntrinsic))
-        llvm::report_fatal_error("No match found in MFMA database\n");
-    }
 
     intrinsicName = maybeMfmaIntrinsic->name;
-    kBase = maybeMfmaIntrinsic->kBase;
 
     // If we are using XF32, the kWidth (and kBase) is double that of F32.
     if (aTensorTy.getElementType().isF32() && allowXF32)
@@ -335,6 +323,7 @@ struct DotOpMFMAConversionHelper {
     const int subBlocks =
         getNumSubmatrices(aTensorTy.getElementType(), mDim, nDim);
     auto elemsPerVec = mDim * nDim * subBlocks / warpSize;
+    int numVecInKBase = numRepK * kWidth / kBase;
 
     Value firstMfma;
     auto vecTy = vec_ty(dstElemTy, elemsPerVec);
@@ -350,18 +339,14 @@ struct DotOpMFMAConversionHelper {
                 tb.i32_val(v));
           }
           acc = zeroAuxiliarBlocks(subBlocks, acc);
-          for (int k = 0; k < numRepK; k++) {
-            for (int kPack = 0; kPack < kWidth / kBase; ++kPack) {
-              acc = mfmaLayout.getIsTransposed()
-                        ? generateMFMAOp(intrinsicName,
-                                         operandB[kPack][{b, n, k}],
-                                         operandA[kPack][{b, m, k}], acc)
-                        : generateMFMAOp(intrinsicName,
-                                         operandA[kPack][{b, m, k}],
-                                         operandB[kPack][{b, n, k}], acc);
-              if (!firstMfma)
-                firstMfma = acc;
-            }
+          for (int k = 0; k < numVecInKBase; k++) {
+            acc = mfmaLayout.getIsTransposed()
+                      ? generateMFMAOp(intrinsicName, operandB[{b, n, k}],
+                                       operandA[{b, m, k}], acc)
+                      : generateMFMAOp(intrinsicName, operandA[{b, m, k}],
+                                       operandB[{b, n, k}], acc);
+            if (!firstMfma)
+              firstMfma = acc;
           }
           acc = reduceSubBlocks(subBlocks, acc);
           adjustAccForSmallKDim(fc, acc, dstElemTy, b, m, n, numRepM, numRepN,
@@ -387,109 +372,120 @@ struct DotOpMFMAConversionHelper {
     return success();
   }
 
-  /// Extract vector from rawElems based on kWidth and kBase
-  /// rawElems is a vector of kWidth elements. We need to prepare vector(s) of
-  /// kBase elements for each mfma instruction
-  SmallVector<Value> extractOperands(Value rawElems, int kWidth, int kBase,
-                                     Type type, bool preserveBF16,
-                                     bool isConstantScale = false) const {
+  /// Process the elements in rawElems and prepare a vector for mfma input.
+  /// rawElems is a vector of kBase elements. Each element is of the raw
+  /// element type from the input. We need to prepare a vector of kBase
+  /// elements of appropriate element type required by mfma instructions.
+  Value prepareOperands(Value rawElems, int kBase, Type type, bool preserveBF16,
+                        bool isConstantScale = false) const {
     auto b = TritonLLVMOpBuilder(loc, rewriter);
-    int kpack = kWidth / kBase;
-    SmallVector<Value> results;
+    Value results;
+
+    // Construct a vector type of kBase elements with desired type
     auto vecTy = vec_ty(type, kBase);
     if (type.isBF16() && !preserveBF16)
       vecTy = vec_ty(i16_ty, kBase);
-    for (int k = 0; k < kpack; ++k) {
-      Value vec = b.undef(vecTy);
-      for (int elemId = 0; elemId < kBase; ++elemId) {
-        auto val =
-            b.extract_element(type, rawElems, b.i32_val(elemId + k * kBase));
-        if (type.isBF16() && !preserveBF16) {
-          // rocdl.mfma.f32.32x32x8bf16.1k calls for input of i16 type
-          auto cast = b.bitcast(val, i16_ty);
-          vec = b.insert_element(vecTy, vec, cast, b.i32_val(elemId));
-        } else {
-          vec = b.insert_element(vecTy, vec, val, b.i32_val(elemId));
-        }
+    Value vec = b.undef(vecTy);
+
+    // For each element in rawElems, extract the element as the desired type,
+    // bitcast it if needed, and insert it into vec.
+    for (int elemId = 0; elemId < kBase; ++elemId) {
+      auto val = b.extract_element(type, rawElems, b.i32_val(elemId));
+      if (type.isBF16() && !preserveBF16) {
+        // rocdl.mfma.f32.32x32x8bf16.1k calls for input of i16 type
+        auto cast = b.bitcast(val, i16_ty);
+        vec = b.insert_element(vecTy, vec, cast, b.i32_val(elemId));
+      } else {
+        vec = b.insert_element(vecTy, vec, val, b.i32_val(elemId));
       }
-      if (type.getIntOrFloatBitWidth() == 8) {
-        if (1 == kBase) {
-          // This is only for the scale operands of scaled mfma on CDNA4
-          if (isConstantScale) {
-            // If the scale is constant(created by arith::ConstantOp), it will
-            // be put in a sgpr instead of vgpr. In that case, instead of
-            // vgpr[7:0], the instruction reads sgpr[30:23] as the scale value.
-            // So we need to manually left shift the scale by 23 bits to meet
-            // the requirement.
-            results.push_back(b.shl(
-                i32_ty, b.zext(i32_ty, b.bitcast(vec, i8_ty)), b.i32_val(23)));
-          } else {
-            results.push_back(b.zext(i32_ty, b.bitcast(vec, i8_ty)));
-          }
+    }
+
+    // Now we have a vector of kBase elements of desired type.
+    // Then we need to prepare vec for results.
+    if (type.getIntOrFloatBitWidth() == 8) {
+      if (1 == kBase) {
+        // This is only for the scale operands of scaled mfma on CDNA4
+        if (isConstantScale) {
+          // If the scale is constant(created by arith::ConstantOp), it will
+          // be put in a sgpr instead of vgpr. In that case, instead of
+          // vgpr[7:0], the instruction reads sgpr[30:23] as the scale value.
+          // So we need to manually left shift the scale by 23 bits to meet
+          // the requirement.
+          results = b.shl(i32_ty, b.zext(i32_ty, b.bitcast(vec, i8_ty)),
+                          b.i32_val(23));
+        } else {
+          results = b.zext(i32_ty, b.bitcast(vec, i8_ty));
         }
-        if (4 == kBase)
-          // This is for int8 on pre-CDNA3 GPUs
-          results.push_back(b.bitcast(vec, i32_ty));
-        if (8 == kBase)
-          results.push_back(b.bitcast(vec, i64_ty));
-        if (16 == kBase)
-          // This is only for the operands of scaled mfma on CDNA4
-          results.push_back(b.bitcast(vec, vec_ty(i32_ty, 4)));
-        if (32 == kBase)
-          results.push_back(b.bitcast(vec, vec_ty(i32_ty, 8)));
-      } else {
-        results.push_back(vec);
       }
+      if (4 == kBase)
+        // This is for int8 on pre-CDNA3 GPUs
+        results = b.bitcast(vec, i32_ty);
+      if (8 == kBase)
+        results = b.bitcast(vec, i64_ty);
+      if (16 == kBase)
+        // This is only for the operands of scaled mfma on CDNA4
+        results = b.bitcast(vec, vec_ty(i32_ty, 4));
+      if (32 == kBase)
+        results = b.bitcast(vec, vec_ty(i32_ty, 8));
+    } else {
+      results = vec;
     }
     return results;
   }
 
   /// Converts dot operand structure to value table and converts types
   /// appropriate for mfma instructions
-  virtual SmallVector<ValueTable> getValuesFromDotOperandLayoutStruct(
-      Value value, int batch, int n0, int n1, int kWidth, int kBase, Type type,
-      bool allowXF32, bool preserveBF16, bool isConstantScale = false) const {
+  virtual ValueTable getValuesFromDotOperandLayoutStruct(
+      Value value, int batch, int nonKRep, int kRepInKWidth, int kWidth,
+      int kBase, Type type, bool allowXF32, bool preserveBF16,
+      bool isConstantScale = false) const {
     auto tb = TritonLLVMOpBuilder(loc, rewriter);
     auto elems = unpackLLElements(loc, value, rewriter);
-    int kpack = kWidth / kBase;
-    SmallVector<ValueTable> dotOpVals(kpack);
+    // number of kBase-element vectors
+    int numVecInKBase = kRepInKWidth * kWidth / kBase;
+    ValueTable dotOpVals;
+
+    SmallVector<int64_t> bounds = {batch, nonKRep, numVecInKBase, kBase};
+    SmallVector<int64_t> strides = computeStrides(bounds);
     for (int b = 0; b < batch; ++b) {
-      for (int i = 0; i < n0; i++) {
-        for (int j = 0; j < n1; j++) {
+      for (int nonK = 0; nonK < nonKRep; nonK++) {
+        for (int kBaseVec = 0; kBaseVec < numVecInKBase; kBaseVec++) {
+          // For each kBase-element vector
+
+          // Step 1: construct each kBase-element vector by
+          // - extracting kBase elements from elems and
+          // - putting them into a kBase-element vector, i.e. rawElems
           Type elemTy = typeConverter->convertType(type);
-          Type ty = vec_ty(elemTy, kWidth);
+          Type ty = vec_ty(elemTy, kBase);
           Value rawElems = tb.undef(ty);
-          for (int k = 0; k < kWidth; ++k) {
-            rawElems = tb.insert_element(
-                ty, rawElems,
-                elems[kWidth * n1 * n0 * b + kWidth * n1 * i + kWidth * j + k],
-                tb.i32_val(k));
+          for (int k = 0; k < kBase; ++k) {
+            auto index = linearize({b, nonK, kBaseVec, k}, strides);
+            rawElems =
+                tb.insert_element(ty, rawElems, elems[index], tb.i32_val(k));
           }
 
-          Value convertedElems;
+          // Step 2: process rawElems based on element type
+          // Note that for f32 input and XF32 is not allowed, nothing needs to
+          // be done and rawElems is inserted into the ValueTable directly
           if (type.isF32() && !allowXF32) {
-            for (int k = 0; k < kpack; ++k)
-              dotOpVals[k][{b, i, j}] =
-                  tb.extract_element(type, rawElems, tb.i32_val(k));
+            dotOpVals[{b, nonK, kBaseVec}] =
+                tb.extract_element(type, rawElems, tb.i32_val(0));
           } else {
-            SmallVector<Value> vals;
+            Value vals;
             if (type.isF32() && allowXF32) {
-              vals = extractOperands(rawElems, kWidth, kBase, f32_ty,
-                                     preserveBF16);
+              vals = prepareOperands(rawElems, kBase, f32_ty, preserveBF16);
             } else if (type.getIntOrFloatBitWidth() == 8) {
-              vals = extractOperands(rawElems, kWidth, kBase, i8_ty,
-                                     preserveBF16, isConstantScale);
+              vals = prepareOperands(rawElems, kBase, i8_ty, preserveBF16,
+                                     isConstantScale);
             } else if (type.isBF16()) {
-              vals = extractOperands(rawElems, kWidth, kBase, bf16_ty,
-                                     preserveBF16);
+              vals = prepareOperands(rawElems, kBase, bf16_ty, preserveBF16);
             } else {
               assert(type.isF16() && "Unsupported data type");
-              vals = extractOperands(rawElems, kWidth, kBase, f16_ty,
-                                     preserveBF16);
-            }
-            for (int k = 0; k < kpack; ++k) {
-              dotOpVals[k][{b, i, j}] = vals[k];
+              vals = prepareOperands(rawElems, kBase, f16_ty, preserveBF16);
             }
+
+            // Step 3: Insert the processed vals into the ValueTable
+            dotOpVals[{b, nonK, kBaseVec}] = vals;
           }
         }
       }
@@ -638,8 +634,8 @@ struct ScaledDotOpMFMAConversionHelper : DotOpMFMAConversionHelper {
 
     // Scales have the same replica distributions as their corresponding
     // operands.
-    SmallVector<ValueTable> operandAScale;
-    SmallVector<ValueTable> operandBScale;
+    ValueTable operandAScale;
+    ValueTable operandBScale;
     if (existBothScales) {
       auto aScaleTensorTy = cast<RankedTensorType>(aScale.getType());
       operandAScale = getValuesFromDotOperandLayoutStruct(
@@ -663,6 +659,7 @@ struct ScaledDotOpMFMAConversionHelper : DotOpMFMAConversionHelper {
     const int subBlocks =
         getNumSubmatrices(aTensorTy.getElementType(), mDim, nDim);
     auto elemsPerVec = mDim * nDim * subBlocks / warpSize;
+    int numVecInKBase = numRepK * aKWidth / aKBase;
 
     Value firstMfma;
     auto tb = TritonLLVMOpBuilder(loc, rewriter);
@@ -679,44 +676,36 @@ struct ScaledDotOpMFMAConversionHelper : DotOpMFMAConversionHelper {
                 tb.i32_val(v));
           }
           acc = zeroAuxiliarBlocks(subBlocks, acc);
-          for (int k = 0; k < numRepK; k++) {
-            for (int kPack = 0; kPack < aKWidth / aKBase; ++kPack) {
-              if (existBothScales) {
-                if (mfmaLayout.getIsTransposed()) {
-                  acc = generateScaledMFMAOp(intrinsicName,
-                                             operandB[kPack][{b, n, k}],
-                                             operandA[kPack][{b, m, k}], acc,
-                                             operandBScale[kPack][{b, n, k}],
-                                             operandAScale[kPack][{b, m, k}],
-                                             maybeMfmaIntrinsic->bElementType,
-                                             maybeMfmaIntrinsic->aElementType);
-                } else {
-                  acc = generateScaledMFMAOp(intrinsicName,
-                                             operandA[kPack][{b, m, k}],
-                                             operandB[kPack][{b, n, k}], acc,
-                                             operandAScale[kPack][{b, m, k}],
-                                             operandBScale[kPack][{b, n, k}],
-                                             maybeMfmaIntrinsic->aElementType,
-                                             maybeMfmaIntrinsic->bElementType);
-                }
+          for (int k = 0; k < numVecInKBase; k++) {
+            if (existBothScales) {
+              if (mfmaLayout.getIsTransposed()) {
+                acc = generateScaledMFMAOp(
+                    intrinsicName, operandB[{b, n, k}], operandA[{b, m, k}],
+                    acc, operandBScale[{b, n, k}], operandAScale[{b, m, k}],
+                    maybeMfmaIntrinsic->bElementType,
+                    maybeMfmaIntrinsic->aElementType);
+              } else {
+                acc = generateScaledMFMAOp(
+                    intrinsicName, operandA[{b, m, k}], operandB[{b, n, k}],
+                    acc, operandAScale[{b, m, k}], operandBScale[{b, n, k}],
+                    maybeMfmaIntrinsic->aElementType,
+                    maybeMfmaIntrinsic->bElementType);
+              }
+            } else {
+              if (mfmaLayout.getIsTransposed()) {
+                acc = generateScaledMFMAOp(intrinsicName, operandB[{b, n, k}],
+                                           operandA[{b, m, k}], acc,
+                                           maybeMfmaIntrinsic->bElementType,
+                                           maybeMfmaIntrinsic->aElementType);
              } else {
-                if (mfmaLayout.getIsTransposed()) {
-                  acc = generateScaledMFMAOp(intrinsicName,
-                                             operandB[kPack][{b, n, k}],
-                                             operandA[kPack][{b, m, k}], acc,
-                                             maybeMfmaIntrinsic->bElementType,
-                                             maybeMfmaIntrinsic->aElementType);
-                } else {
-                  acc = generateScaledMFMAOp(intrinsicName,
-                                             operandA[kPack][{b, m, k}],
-                                             operandB[kPack][{b, n, k}], acc,
-                                             maybeMfmaIntrinsic->aElementType,
-                                             maybeMfmaIntrinsic->bElementType);
-                }
+                acc = generateScaledMFMAOp(intrinsicName, operandA[{b, m, k}],
+                                           operandB[{b, n, k}], acc,
+                                           maybeMfmaIntrinsic->aElementType,
+                                           maybeMfmaIntrinsic->bElementType);
              }
-              if (!firstMfma)
-                firstMfma = acc;
            }
+            if (!firstMfma)
+              firstMfma = acc;
          }
           acc = reduceSubBlocks(subBlocks, acc);
           adjustAccForSmallKDim(fc, acc, dstElemTy, b, m, n, numRepM, numRepN,
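
Note on the indexing above (illustrative, not part of the commit): getValuesFromDotOperandLayoutStruct now folds the old kPack dimension into numVecInKBase and recovers the flat position in elems with mlir::computeStrides / mlir::linearize over {batch, nonKRep, numVecInKBase, kBase}. The standalone C++ sketch below re-implements those two helpers under hypothetical names (computeStridesSketch / linearizeSketch) and uses arbitrary example shapes, to show that the linearized index matches the explicit arithmetic in the removed code.

// Sketch: row-major strides and linearization over {batch, nonKRep,
// numVecInKBase, kBase}, mirroring what computeStrides()/linearize() compute.
#include <cassert>
#include <cstdint>
#include <vector>

// Row-major strides: strides[i] = product of bounds[i+1 .. end).
static std::vector<int64_t>
computeStridesSketch(const std::vector<int64_t> &bounds) {
  std::vector<int64_t> strides(bounds.size(), 1);
  for (int i = static_cast<int>(bounds.size()) - 2; i >= 0; --i)
    strides[i] = strides[i + 1] * bounds[i + 1];
  return strides;
}

// Dot product of a multi-index with the strides.
static int64_t linearizeSketch(const std::vector<int64_t> &index,
                               const std::vector<int64_t> &strides) {
  int64_t flat = 0;
  for (size_t i = 0; i < index.size(); ++i)
    flat += index[i] * strides[i];
  return flat;
}

int main() {
  // Hypothetical shapes: batch=1, nonKRep=2, kWidth=8, kBase=4, kRepInKWidth=2.
  const int64_t batch = 1, nonKRep = 2, kWidth = 8, kBase = 4, kRepInKWidth = 2;
  const int64_t numVecInKBase = kRepInKWidth * kWidth / kBase;

  std::vector<int64_t> bounds = {batch, nonKRep, numVecInKBase, kBase};
  std::vector<int64_t> strides = computeStridesSketch(bounds);

  // The linearized index equals the removed code's explicit arithmetic
  //   kWidth * n1 * n0 * b + kWidth * n1 * i + kWidth * j + k
  // once the (j, kPack) pair is folded into kBaseVec, since
  // kBase * kBaseVec covers both the kWidth * j and kPack * kBase terms.
  for (int64_t b = 0; b < batch; ++b)
    for (int64_t nonK = 0; nonK < nonKRep; ++nonK)
      for (int64_t kBaseVec = 0; kBaseVec < numVecInKBase; ++kBaseVec)
        for (int64_t k = 0; k < kBase; ++k) {
          int64_t flat = linearizeSketch({b, nonK, kBaseVec, k}, strides);
          int64_t old = kWidth * kRepInKWidth * nonKRep * b +
                        kWidth * kRepInKWidth * nonK + kBase * kBaseVec + k;
          assert(flat == old);
        }
  return 0;
}

Because numVecInKBase * kBase == kRepInKWidth * kWidth, the single flattened k loop in the MFMA emission visits exactly the same elements as the old nested (k, kPack) loops.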
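Similarly illustrative: the constant-scale branch in prepareOperands zero-extends the 8-bit scale and shifts it left by 23 so the hardware reads it from sgpr bits [30:23] rather than vgpr[7:0]. A minimal sketch of that bit placement (plain C++, hypothetical value, not part of the commit):

#include <cassert>
#include <cstdint>

int main() {
  // An 8-bit scale value as it would sit in vgpr[7:0].
  uint8_t scale = 0x7f;

  // Zero-extend to 32 bits and shift left by 23, mirroring
  //   b.shl(i32_ty, b.zext(i32_ty, b.bitcast(vec, i8_ty)), b.i32_val(23));
  // so the value now occupies bits [30:23].
  uint32_t packed = static_cast<uint32_t>(scale) << 23;

  // Reading bits [30:23] back recovers the original scale.
  assert(((packed >> 23) & 0xffu) == scale);
  return 0;
}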