@@ -499,7 +499,9 @@ struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> {
499499// / and LLVM AMDGPU intrinsics convention.
500500// /
501501// / Specifically:
502- // / 1. If the element type is bfloat16, bitcast it to i16.
502+ // / 1. If the element type is bfloat16, bitcast it to i16 unless the ROCDL
503+ // / intrinsic accepts bf16 operands. Newer MFMAs take bf16 directly; see
504+ // / IntrinsicsAMDGPU.td for reference.
503505// / 2. If instead we have a more than 64-bit quantity, use a <N / 4 x i32>
504506// / instead, which is what the f8f6f4 intrinsics use.
505507// / 3. If `input` is a vector of N <= 8 bytes, bitcast it to a (N * 8)-bit
@@ -509,10 +511,11 @@ struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> {
509511// / therefore 8-bit and smaller floats are represented as their corresponding
510512// / `iN` integers.
511513static Value convertMFMAVectorOperand (ConversionPatternRewriter &rewriter,
512- Location loc, Value input) {
514+ Location loc, Value input,
515+ bool allowBf16 = true ) {
513516 Type inputType = input.getType ();
514517 if (auto vectorType = dyn_cast<VectorType>(inputType)) {
515- if (vectorType.getElementType ().isBF16 ())
518+ if (vectorType.getElementType ().isBF16 () && !allowBf16 )
516519 return rewriter.create <LLVM::BitcastOp>(
517520 loc, vectorType.clone (rewriter.getI16Type ()), input);
518521 if (vectorType.getElementType ().isInteger (8 ) &&
@@ -958,12 +961,23 @@ struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
958961
959962 StringRef intrinsicName =
960963 isScaled ? std::get<0 >(*maybeScaledIntrinsic) : *maybeIntrinsic;
964+ // Determine whether the intrinsic can take bf16 operands. Newer MFMAs on
965+ // gfx950+ allow bf16 as the input type; see IntrinsicsAMDGPU.td for reference.
966+ bool allowBf16 = [&]() {
967+ if (chipset < kGfx950 )
968+ return false ;
969+ if (isScaled)
970+ return true ;
971+ return intrinsicName.contains (" 16x16x32.bf16" ) ||
972+ intrinsicName.contains (" 32x32x16.bf16" );
973+ }();
961974 OperationState loweredOp (loc, intrinsicName);
962975 loweredOp.addTypes (intrinsicOutType);
963- loweredOp.addOperands (
964- {convertMFMAVectorOperand (rewriter, loc, adaptor.getSourceA ()),
965- convertMFMAVectorOperand (rewriter, loc, adaptor.getSourceB ()),
966- adaptor.getDestC ()});
976+ loweredOp.addOperands ({convertMFMAVectorOperand (
977+ rewriter, loc, adaptor.getSourceA (), allowBf16),
978+ convertMFMAVectorOperand (
979+ rewriter, loc, adaptor.getSourceB (), allowBf16),
980+ adaptor.getDestC ()});
967981 if (isScaled) {
968982 Value zero = createI32Constant (rewriter, loc, 0 );
969983 auto [_scaledName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
0 commit comments