@@ -499,9 +499,7 @@ struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> {
499499// / and LLVM AMDGPU intrinsics convention.
500500// /
501501// / Specifically:
502- // / 1. If the element type is bfloat16, bitcast it to i16 unless rocdl intrinsic
503- // / allows bf16. Newer MFMAs support bf16 types on operand, check
504- // / IntrinsicsAMDGPU.td file for reference.
502+ // / 1. If the element type is bfloat16, bitcast it to i16.
505503// / 2. If instead we have a more than 64-bit quantity, use a <N / 4 x i32>
506504// / instead, which is what the f8f6f4 intrinsics use.
507505// / 3. If `input` is a vector of N <= 8 bytes, bitcast it to a (N * 8)-bit
@@ -511,11 +509,10 @@ struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> {
511509// / therefore 8-bit and smaller floats are represented as their corresponding
512510// / `iN` integers.
513511static Value convertMFMAVectorOperand (ConversionPatternRewriter &rewriter,
514- Location loc, Value input,
515- bool allowBf16 = true ) {
512+ Location loc, Value input) {
516513 Type inputType = input.getType ();
517514 if (auto vectorType = dyn_cast<VectorType>(inputType)) {
518- if (vectorType.getElementType ().isBF16 () && !allowBf16 )
515+ if (vectorType.getElementType ().isBF16 ())
519516 return rewriter.create <LLVM::BitcastOp>(
520517 loc, vectorType.clone (rewriter.getI16Type ()), input);
521518 if (vectorType.getElementType ().isInteger (8 ) &&
@@ -961,23 +958,12 @@ struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
961958
962959 StringRef intrinsicName =
963960 isScaled ? std::get<0 >(*maybeScaledIntrinsic) : *maybeIntrinsic;
964- // Determine if we can use bf16 in the intrinsic. Newer MFMAs in gfx950+
965- // allows bf16 as the input. For reference check IntrinsicsAMDGPU.td file.
966- bool allowBf16 = [&]() {
967- if (chipset < kGfx950 )
968- return false ;
969- if (isScaled)
970- return true ;
971- return intrinsicName.contains (" 16x16x32.bf16" ) ||
972- intrinsicName.contains (" 32x32x16.bf16" );
973- }();
974961 OperationState loweredOp (loc, intrinsicName);
975962 loweredOp.addTypes (intrinsicOutType);
976- loweredOp.addOperands ({convertMFMAVectorOperand (
977- rewriter, loc, adaptor.getSourceA (), allowBf16),
978- convertMFMAVectorOperand (
979- rewriter, loc, adaptor.getSourceB (), allowBf16),
980- adaptor.getDestC ()});
963+ loweredOp.addOperands (
964+ {convertMFMAVectorOperand (rewriter, loc, adaptor.getSourceA ()),
965+ convertMFMAVectorOperand (rewriter, loc, adaptor.getSourceB ()),
966+ adaptor.getDestC ()});
981967 if (isScaled) {
982968 Value zero = createI32Constant (rewriter, loc, 0 );
983969 auto [_scaledName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
@@ -1100,6 +1086,49 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
11001086 }
11011087};
11021088
1089+ struct TransposeLoadOpLowering
1090+ : public ConvertOpToLLVMPattern<TransposeLoadOp> {
1091+ TransposeLoadOpLowering (const LLVMTypeConverter &converter, Chipset chipset)
1092+ : ConvertOpToLLVMPattern<TransposeLoadOp>(converter), chipset(chipset) {}
1093+
1094+ Chipset chipset;
1095+
1096+ LogicalResult
1097+ matchAndRewrite (TransposeLoadOp op, TransposeLoadOpAdaptor adaptor,
1098+ ConversionPatternRewriter &rewriter) const override {
1099+ if (chipset < kGfx950 )
1100+ return op.emitOpError (" Non-gfx950 chipset not supported" );
1101+
1102+ Location loc = op.getLoc ();
1103+ auto srcMemRefType = cast<MemRefType>(op.getSrc ().getType ());
1104+ Value srcPtr =
1105+ getStridedElementPtr (rewriter, loc, srcMemRefType, adaptor.getSrc (),
1106+ (adaptor.getSrcIndices ()));
1107+ auto elementTypeSize = cast<VectorType>(op.getDst ().getType ())
1108+ .getElementType ()
1109+ .getIntOrFloatBitWidth ();
1110+
1111+ // TODO: support ds_read_tr16_b64 intrinsic.
1112+ switch (elementTypeSize) {
1113+ case 4 :
1114+ rewriter.replaceOpWithNewOp <ROCDL::ds_read_tr4_b64>(
1115+ op, op.getDst ().getType (), srcPtr);
1116+ break ;
1117+ case 8 :
1118+ rewriter.replaceOpWithNewOp <ROCDL::ds_read_tr8_b64>(
1119+ op, op.getDst ().getType (), srcPtr);
1120+ break ;
1121+ case 16 :
1122+ rewriter.replaceOpWithNewOp <ROCDL::ds_read_tr16_b64>(
1123+ op, op.getDst ().getType (), srcPtr);
1124+ break ;
1125+ default :
1126+ return op.emitOpError (" Unsupported element size for transpose load" );
1127+ }
1128+ return success ();
1129+ }
1130+ };
1131+
11031132struct GatherToLDSOpLowering : public ConvertOpToLLVMPattern <GatherToLDSOp> {
11041133 GatherToLDSOpLowering (const LLVMTypeConverter &converter, Chipset chipset)
11051134 : ConvertOpToLLVMPattern<GatherToLDSOp>(converter), chipset(chipset) {}
@@ -1749,7 +1778,7 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
17491778 MFMAOpLowering, ScaledMFMAOpLowering, WMMAOpLowering,
17501779 ExtPackedFp8OpLowering, ScaledExtPackedOpLowering,
17511780 PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering,
1752- PackedStochRoundFp8OpLowering, GatherToLDSOpLowering>(converter ,
1753- chipset);
1781+ PackedStochRoundFp8OpLowering, GatherToLDSOpLowering,
1782+ TransposeLoadOpLowering>(converter, chipset);
17541783 patterns.add <AMDGPUSwizzleBitModeLowering>(converter);
17551784}
0 commit comments