@@ -271,11 +271,13 @@ struct DotOpMFMAConversionHelper {
     auto elemTyA = aTensorTy.getElementType();
     auto elemTyB = bTensorTy.getElementType();
 
+    const auto kDimOperandSize = aTensorTy.getShape().back();
+
     bool allowXF32 =
         op.getInputPrecision() == InputPrecision::TF32 && mfmaVersion == 3;
     StringRef mfmaInsnName;
-    auto maybeMfmaInsn = MfmaInsn::selectMfma(mDim, nDim, elemTyA, elemTyB,
-                                              mfmaVersion, allowXF32);
+    auto maybeMfmaInsn = MfmaInsn::selectMfma(
+        mDim, nDim, kDimOperandSize, elemTyA, elemTyB, mfmaVersion, allowXF32);
     if (failed(maybeMfmaInsn))
       llvm::report_fatal_error("No match found in MFMA database\n");
 
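Note on the `selectMfma` change above: the operand's K extent now participates in instruction selection. A minimal sketch of why that matters, assuming a database where several MFMA variants share M/N but differ in K; the `MfmaCandidate` and `pickByKDim` names are illustrative, not the actual `MfmaInsn` API:

```cpp
#include <cstdint>
#include <optional>
#include <vector>

// Illustrative stand-in for an MFMA database entry.
struct MfmaCandidate {
  int64_t mDim, nDim, kDim;
};

// Among candidates that already match M/N and element types, prefer the
// widest instruction K that still divides the operand's K dimension;
// without kDimOperandSize this tie cannot be broken.
std::optional<MfmaCandidate>
pickByKDim(const std::vector<MfmaCandidate> &matches,
           int64_t kDimOperandSize) {
  std::optional<MfmaCandidate> best;
  for (const auto &c : matches)
    if (kDimOperandSize % c.kDim == 0 && (!best || c.kDim > best->kDim))
      best = c;
  return best;
}
```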
@@ -290,8 +292,6 @@ struct DotOpMFMAConversionHelper {
     if (aTensorTy.getElementType().isF32() && allowXF32)
       kWidth *= 2;
 
-    auto rank = aTensorTy.getShape().size();
-    const auto kDimOperandSize = aTensorTy.getShape()[rank - 1];
     const auto kDimInstrSize = mfmaLayout.getInstrShapeForOperand(kWidth, 0)[1];
 
     auto repA = mfmaLayout.getRepForOperand(aTensorTy.getShape(), kWidth, 0);
@@ -309,12 +309,13 @@ struct DotOpMFMAConversionHelper {
     auto numRepB = repA[0];
     assert(repA[0] == repB[0]);
 
+    bool preserveBF16 = mfmaInsnName.contains(".bf16") && mfmaVersion >= 4;
     auto operandA = getValuesFromDotOperandLayoutStruct(
         loadedA, numRepB, numRepM, numRepK, kWidth, kBase,
-        aTensorTy.getElementType(), allowXF32);
+        aTensorTy.getElementType(), allowXF32, preserveBF16);
     auto operandB = getValuesFromDotOperandLayoutStruct(
         loadedB, numRepB, numRepN, numRepK, kWidth, kBase,
-        aTensorTy.getElementType(), allowXF32);
+        aTensorTy.getElementType(), allowXF32, preserveBF16);
 
     auto dstElemTy = dTensorTy.getElementType();
     auto fc = unpackLLElements(loc, loadedC, rewriter);
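The new `preserveBF16` flag keys off the selected instruction name and the MFMA version: older bf16 intrinsics (e.g. `rocdl.mfma.f32.32x32x8bf16.1k`, as the comment in `extractOperands` below notes) take `i16` payloads, while version 4 and newer bf16 instructions accept `bf16` operands directly. A minimal sketch of the resulting type decision, using an illustrative enum rather than MLIR types:

```cpp
// Illustrative element kinds; the real code works on mlir::Type.
enum class ElemKind { F16, BF16, F32, I8, I16 };

// bf16 operands are bitcast to i16 only when the selected MFMA intrinsic
// demands i16 payloads; with preserveBF16 they pass through unchanged.
ElemKind mfmaOperandElemKind(ElemKind in, bool preserveBF16) {
  if (in == ElemKind::BF16 && !preserveBF16)
    return ElemKind::I16;
  return in;
}
```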
@@ -379,19 +380,19 @@ struct DotOpMFMAConversionHelper {
   /// rawElems is a vector of kWidth elements. We need to prepare vector(s) of
   /// kBase elements for each mfma instruction
   SmallVector<Value> extractOperands(Value rawElems, int kWidth, int kBase,
-                                     Type type) const {
+                                     Type type, bool preserveBF16) const {
     auto b = TritonLLVMOpBuilder(loc, rewriter);
     int kpack = kWidth / kBase;
     SmallVector<Value> results;
     auto vecTy = vec_ty(type, kBase);
-    if (type.isBF16())
+    if (type.isBF16() && !preserveBF16)
       vecTy = vec_ty(i16_ty, kBase);
     for (int k = 0; k < kpack; ++k) {
       Value vec = b.undef(vecTy);
       for (int elemId = 0; elemId < kBase; ++elemId) {
         auto val =
             b.extract_element(type, rawElems, b.i32_val(elemId + k * kBase));
-        if (type.isBF16()) {
+        if (type.isBF16() && !preserveBF16) {
           // rocdl.mfma.f32.32x32x8bf16.1k calls for input of i16 type
           auto cast = b.bitcast(val, i16_ty);
           vec = b.insert_element(vecTy, vec, cast, b.i32_val(elemId));
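For reference, `extractOperands` repacks each loaded chunk of `kWidth` elements into `kpack = kWidth / kBase` vectors of `kBase` elements, one per issued MFMA. A plain-C++ sketch of that index arithmetic (the element type and the `repack` name are illustrative):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Split a kWidth-element chunk into kpack vectors of kBase elements;
// element elemId + k * kBase feeds the k-th MFMA of the chunk.
std::vector<std::vector<uint16_t>>
repack(const std::vector<uint16_t> &rawElems, int kWidth, int kBase) {
  assert(static_cast<int>(rawElems.size()) == kWidth && kWidth % kBase == 0);
  int kpack = kWidth / kBase;
  std::vector<std::vector<uint16_t>> results(kpack);
  for (int k = 0; k < kpack; ++k)
    for (int elemId = 0; elemId < kBase; ++elemId)
      results[k].push_back(rawElems[elemId + k * kBase]);
  return results;
}
```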
@@ -423,7 +424,7 @@ struct DotOpMFMAConversionHelper {
   virtual SmallVector<ValueTable>
   getValuesFromDotOperandLayoutStruct(Value value, int batch, int n0, int n1,
                                       int kWidth, int kBase, Type type,
-                                      bool allowXF32) const {
+                                      bool allowXF32, bool preserveBF16) const {
     auto tb = TritonLLVMOpBuilder(loc, rewriter);
     auto elems = unpackLLElements(loc, value, rewriter);
     int kpack = kWidth / kBase;
@@ -449,14 +450,18 @@ struct DotOpMFMAConversionHelper {
         } else {
           SmallVector<Value> vals;
           if (type.isF32() && allowXF32) {
-            vals = extractOperands(rawElems, kWidth, kBase, f32_ty);
+            vals = extractOperands(rawElems, kWidth, kBase, f32_ty,
+                                   preserveBF16);
           } else if (type.getIntOrFloatBitWidth() == 8) {
-            vals = extractOperands(rawElems, kWidth, kBase, i8_ty);
+            vals =
+                extractOperands(rawElems, kWidth, kBase, i8_ty, preserveBF16);
           } else if (type.isBF16()) {
-            vals = extractOperands(rawElems, kWidth, kBase, bf16_ty);
+            vals = extractOperands(rawElems, kWidth, kBase, bf16_ty,
+                                   preserveBF16);
           } else {
             assert(type.isF16() && "Unsupported data type");
-            vals = extractOperands(rawElems, kWidth, kBase, f16_ty);
+            vals = extractOperands(rawElems, kWidth, kBase, f16_ty,
+                                   preserveBF16);
           }
           for (int k = 0; k < kpack; ++k) {
             dotOpVals[k][{b, i, j}] = vals[k];
@@ -518,6 +523,8 @@ struct ScaledDotOpMFMAConversionHelper : DotOpMFMAConversionHelper {
     ScaleDotElemType aElemType = op.getLhsType();
     ScaleDotElemType bElemType = op.getRhsType();
 
+    const auto kDimOperandSize = aTensorTy.getShape().back();
+
     auto supportsTypes = [](ScaleDotElemType elemType) {
       return elemType == ScaleDotElemType::E2M1;
     };
@@ -529,7 +536,7 @@ struct ScaledDotOpMFMAConversionHelper : DotOpMFMAConversionHelper {
     auto ctx = op.getContext();
     constexpr bool allowXF32 = false;
     auto maybeMfmaInsn = MfmaInsn::selectMfma(
-        mDim, nDim, scaleDotElemTypeToMLIRType(ctx, aElemType),
+        mDim, nDim, kDimOperandSize, scaleDotElemTypeToMLIRType(ctx, aElemType),
         scaleDotElemTypeToMLIRType(ctx, bElemType), mfmaVersion, allowXF32);
     if (failed(maybeMfmaInsn))
       llvm::report_fatal_error("No match found in MFMA database\n");
@@ -544,8 +551,6 @@ struct ScaledDotOpMFMAConversionHelper : DotOpMFMAConversionHelper {
     auto aEncoding = cast<DotOperandEncodingAttr>(aTensorTy.getEncoding());
     auto bEncoding = cast<DotOperandEncodingAttr>(bTensorTy.getEncoding());
     int kWidth = aEncoding.getKWidth();
-    auto rank = aTensorTy.getShape().size();
-    const auto kDimOperandSize = aTensorTy.getShape()[rank - 1];
     const auto kDimInstrSize = mfmaLayout.getInstrShapeForOperand(kWidth, 0)[1];
 
     auto repA = mfmaLayout.getRepForOperand(aTensorTy.getShape(), kWidth, 0);
@@ -575,19 +580,19 @@ struct ScaledDotOpMFMAConversionHelper : DotOpMFMAConversionHelper {
 
     auto operandA = getValuesFromDotOperandLayoutStruct(
         loadedA, numRepB, numRepM, numRepK, kWidth, kBase,
-        aTensorTy.getElementType(), allowXF32);
+        aTensorTy.getElementType(), allowXF32, /*preserveBF16=*/false);
     auto operandB = getValuesFromDotOperandLayoutStruct(
         loadedB, numRepB, numRepN, numRepK, kWidth, kBase,
-        bTensorTy.getElementType(), allowXF32);
+        bTensorTy.getElementType(), allowXF32, /*preserveBF16=*/false);
 
     // Scales have the same replica distributions as their corresponding
     // operands.
     auto operandAScale = getValuesFromDotOperandLayoutStruct(
         loadedAScale, numRepB, numRepM, numRepK, scaleKWidth, scaleKBase,
-        aScaleTensorTy.getElementType(), allowXF32);
+        aScaleTensorTy.getElementType(), allowXF32, /*preserveBF16=*/false);
     auto operandBScale = getValuesFromDotOperandLayoutStruct(
         loadedBScale, numRepB, numRepN, numRepK, scaleKWidth, scaleKBase,
-        bScaleTensorTy.getElementType(), allowXF32);
+        bScaleTensorTy.getElementType(), allowXF32, /*preserveBF16=*/false);
 
     auto dstElemTy = dTensorTy.getElementType();
     auto fc = unpackLLElements(loc, loadedC, rewriter);
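In the scaled-dot path `preserveBF16` is hard-coded to `false`: the `supportsTypes` guard above restricts operands to E2M1 (fp4), so the bf16 re-typing question never arises at these call sites. A small sketch of that invariant, with a locally defined enum standing in for the real `ScaleDotElemType`:

```cpp
#include <cassert>

// Local stand-in for the real ScaleDotElemType enum.
enum class ScaleDotElemType { E2M1, E4M3, E5M2, BF16 };

// Only E2M1 survives the supportsTypes() check, so bf16 preservation is
// statically irrelevant in the scaled-dot conversion.
bool preserveBF16For(ScaleDotElemType t) {
  assert(t == ScaleDotElemType::E2M1 && "unsupported scaled-dot type");
  return false;
}
```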