@@ -514,8 +514,11 @@ static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter,
514514// / Push an input operand. If it is a float type, nothing to do. If it is
515515// / an integer type, then we need to also push its signedness (1 for signed, 0
516516// / for unsigned) and we need to pack the input 16xi8 vector into a 4xi32
517- // / vector. We also need to convert bfloat inputs to i16 to account for the lack
518- // / of bfloat support in the WMMA intrinsics themselves.
517+ // / vector (or the 8xi8 vector into a 2xi32 one for gfx12+).
518+ // / We also need to convert bfloat inputs to i16 to account for the bfloat
519+ // / intrinsics having been defined before the AMD backend supported bfloat. We
520+ // / similarly need to pack 8-bit float types into integers as if they were i8
521+ // / (which they are for the backend's purposes).
519522static void wmmaPushInputOperand (ConversionPatternRewriter &rewriter,
520523 Location loc,
521524 const TypeConverter *typeConverter,
@@ -524,12 +527,16 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
524527 SmallVector<Value, 4 > &operands) {
525528 Type inputType = llvmInput.getType ();
526529 auto vectorType = dyn_cast<VectorType>(inputType);
530+ if (!vectorType) {
531+ operands.push_back (llvmInput);
532+ return ;
533+ }
527534 Type elemType = vectorType.getElementType ();
528535
529536 if (elemType.isBF16 ())
530537 llvmInput = rewriter.create <LLVM::BitcastOp>(
531538 loc, vectorType.clone (rewriter.getI16Type ()), llvmInput);
532- if (! elemType.isInteger ( 8 ) ) {
539+ if (elemType.getIntOrFloatBitWidth () > 8 ) {
533540 operands.push_back (llvmInput);
534541 return ;
535542 }
@@ -538,34 +545,43 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
538545 // for int8. This is because, in LLVM, fp8 type is converted to int8, so the
539546 // fp8/int8 information is lost during the conversion process.
540547 auto mlirInputType = cast<VectorType>(mlirInput.getType ());
541- bool isInputInt8 = mlirInputType.getElementType ().isInteger (8 );
542- if (isInputInt8 ) {
548+ bool isInputInteger = mlirInputType.getElementType ().isInteger ();
549+ if (isInputInteger ) {
543550 // if the element type is explicitly signed or unsigned, it overrides the isUnsigned flag
544551 bool localIsUnsigned = isUnsigned;
545- if (elemType.isUnsignedInteger (8 )) {
552+ if (elemType.isUnsignedInteger ()) {
546553 localIsUnsigned = true ;
547- } else if (elemType.isSignedInteger (8 )) {
554+ } else if (elemType.isSignedInteger ()) {
548555 localIsUnsigned = false ;
549556 }
550557 Value sign = createI1Constant (rewriter, loc, !localIsUnsigned);
551558 operands.push_back (sign);
552559 }
553560
554- int64_t numBytes = vectorType.getNumElements ();
561+ int64_t numBits =
562+ vectorType.getNumElements () * elemType.getIntOrFloatBitWidth ();
555563 Type i32 = rewriter.getI32Type ();
556- VectorType vectorType32bits = VectorType::get (numBytes * 8 / 32 , i32 );
557- auto llvmVectorType32bits = typeConverter->convertType (vectorType32bits);
558- Value result = rewriter.createOrFold <LLVM::BitcastOp>(
559- loc, llvmVectorType32bits, llvmInput);
560- operands.push_back (result);
564+ Type intrinsicInType = numBits <= 32
565+ ? (Type)rewriter.getIntegerType (numBits)
566+ : (Type)VectorType::get (numBits / 32 , i32 );
567+ auto llvmIntrinsicInType = typeConverter->convertType (intrinsicInType);
568+ Value castInput = rewriter.createOrFold <LLVM::BitcastOp>(
569+ loc, llvmIntrinsicInType, llvmInput);
570+ // The wave64-mode 16x16x16 intrinsics that take 4-bit integers only need
571+ // (256 / 64) * 4 = 16 bits of input (on gfx12+) but take i32 arguments.
572+ // Add in the zeros here.
573+ if (numBits < 32 )
574+ castInput = rewriter.create <LLVM::ZExtOp>(loc, i32 , castInput);
575+ operands.push_back (castInput);
561576}
562577
563578// / Push the output operand. For many cases this is only pushing the output in
564579// / the operand list. But when we have f16 -> f16 or bf16 -> bf16 intrinsics,
565580// / since the same number of VGPRs is used, we need to decide whether to store the
566581// / result in the upper 16 bits of the VGPRs or in the lower part. To store the
567582// / result in the lower 16 bits, set subwordOffset to 1, otherwise result will
568- // / be stored it in the upper part
583+ // / be stored in the upper part. The subwordOffset must not be set for gfx12,
584+ // / as the instructions have been changed to return fewer registers instead.
569585static void wmmaPushOutputOperand (ConversionPatternRewriter &rewriter,
570586 Location loc,
571587 const TypeConverter *typeConverter,
@@ -728,8 +744,10 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
728744static std::optional<StringRef> wmmaOpToIntrinsic (WMMAOp wmma,
729745 Chipset chipset) {
730746 auto sourceVectorType = dyn_cast<VectorType>(wmma.getSourceA ().getType ());
747+ auto sourceBVectorType = dyn_cast<VectorType>(wmma.getSourceB ().getType ());
731748 auto destVectorType = dyn_cast<VectorType>(wmma.getDestC ().getType ());
732749 auto elemSourceType = sourceVectorType.getElementType ();
750+ auto elemBSourceType = sourceBVectorType.getElementType ();
733751 auto elemDestType = destVectorType.getElementType ();
734752
735753 if (elemSourceType.isF16 () && elemDestType.isF32 ())
@@ -742,10 +760,33 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
742760 return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName ();
743761 if (elemSourceType.isInteger (8 ) && elemDestType.isInteger (32 ))
744762 return ROCDL::wmma_i32_16x16x16_iu8::getOperationName ();
745- if (isa<Float8E4M3FNType>(elemSourceType) && elemDestType.isF32 ())
746- return ROCDL::wmma_f32_16x16x16_fp8::getOperationName ();
747- if (isa<Float8E5M2Type>(elemSourceType) && elemDestType.isF32 ())
748- return ROCDL::wmma_f32_16x16x16_bf8::getOperationName ();
763+ if (chipset.majorVersion == 11 ) {
764+ if (elemSourceType.isInteger (4 ) && elemDestType.isInteger (32 ))
765+ return ROCDL::wmma_i32_16x16x16_iu4::getOperationName ();
766+ }
767+ if (chipset.majorVersion >= 12 ) {
768+ if (isa<Float8E4M3FNType>(elemSourceType) &&
769+ isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32 ())
770+ return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName ();
771+ if (isa<Float8E4M3FNType>(elemSourceType) &&
772+ isa<Float8E5M2Type>(elemBSourceType) && elemDestType.isF32 ())
773+ return ROCDL::wmma_f32_16x16x16_fp8_bf8::getOperationName ();
774+ if (isa<Float8E5M2Type>(elemSourceType) &&
775+ isa<Float8E5M2Type>(elemBSourceType) && elemDestType.isF32 ())
776+ return ROCDL::wmma_f32_16x16x16_bf8_bf8::getOperationName ();
777+ if (isa<Float8E5M2Type>(elemSourceType) &&
778+ isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32 ())
779+ return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName ();
780+ if (elemSourceType.isInteger (4 ) && elemDestType.isInteger (32 )) {
781+ bool isWave64 = destVectorType.getNumElements () == 4 ;
782+ // This is the ambiguous case. 8 inputs to the wave64 version means that
783+ // we want the 16x16x32 version, but for wave32 they mean the short form.
784+ bool has8Inputs = sourceVectorType.getNumElements () == 8 ;
785+ if ((isWave64 && has8Inputs) || (!isWave64 && !has8Inputs))
786+ return ROCDL::wmma_i32_16x16x32_iu4::getOperationName ();
787+ return ROCDL::wmma_i32_16x16x16_iu4::getOperationName ();
788+ }
789+ }
749790 return std::nullopt ;
750791}
751792
@@ -823,6 +864,9 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
823864 if (!maybeIntrinsic.has_value ())
824865 return op.emitOpError (" no intrinsic matching WMMA on the given chipset" );
825866
867+ if (chipset.majorVersion >= 12 && op.getSubwordOffset () != 0 )
868+ return op.emitOpError (" subwordOffset not supported on gfx12+" );
869+
826870 OperationState loweredOp (loc, *maybeIntrinsic);
827871 loweredOp.addTypes (rawOutType);
828872
0 commit comments