@@ -1100,6 +1100,49 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
11001100 }
11011101};
11021102
1103+ struct TransposeLoadOpLowering
1104+ : public ConvertOpToLLVMPattern<TransposeLoadOp> {
1105+ TransposeLoadOpLowering (const LLVMTypeConverter &converter, Chipset chipset)
1106+ : ConvertOpToLLVMPattern<TransposeLoadOp>(converter), chipset(chipset) {}
1107+
1108+ Chipset chipset;
1109+
1110+ LogicalResult
1111+ matchAndRewrite (TransposeLoadOp op, TransposeLoadOpAdaptor adaptor,
1112+ ConversionPatternRewriter &rewriter) const override {
1113+ if (chipset < kGfx950 )
1114+ return op.emitOpError (" Non-gfx950 chipset not supported" );
1115+
1116+ Location loc = op.getLoc ();
1117+ auto srcMemRefType = cast<MemRefType>(op.getSrc ().getType ());
1118+ Value srcPtr =
1119+ getStridedElementPtr (rewriter, loc, srcMemRefType, adaptor.getSrc (),
1120+ (adaptor.getSrcIndices ()));
1121+ auto elementTypeSize = cast<VectorType>(op.getDst ().getType ())
1122+ .getElementType ()
1123+ .getIntOrFloatBitWidth ();
1124+
1125+ // TODO: support ds_read_tr16_b64 intrinsic.
1126+ switch (elementTypeSize) {
1127+ case 4 :
1128+ rewriter.replaceOpWithNewOp <ROCDL::ds_read_tr4_b64>(
1129+ op, op.getDst ().getType (), srcPtr);
1130+ break ;
1131+ case 8 :
1132+ rewriter.replaceOpWithNewOp <ROCDL::ds_read_tr8_b64>(
1133+ op, op.getDst ().getType (), srcPtr);
1134+ break ;
1135+ case 16 :
1136+ rewriter.replaceOpWithNewOp <ROCDL::ds_read_tr16_b64>(
1137+ op, op.getDst ().getType (), srcPtr);
1138+ break ;
1139+ default :
1140+ return op.emitOpError (" Unsupported element size for transpose load" );
1141+ }
1142+ return success ();
1143+ }
1144+ };
1145+
11031146struct GatherToLDSOpLowering : public ConvertOpToLLVMPattern <GatherToLDSOp> {
11041147 GatherToLDSOpLowering (const LLVMTypeConverter &converter, Chipset chipset)
11051148 : ConvertOpToLLVMPattern<GatherToLDSOp>(converter), chipset(chipset) {}
@@ -1749,7 +1792,7 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
17491792 MFMAOpLowering, ScaledMFMAOpLowering, WMMAOpLowering,
17501793 ExtPackedFp8OpLowering, ScaledExtPackedOpLowering,
17511794 PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering,
1752- PackedStochRoundFp8OpLowering, GatherToLDSOpLowering>(converter ,
1753- chipset);
1795+ PackedStochRoundFp8OpLowering, GatherToLDSOpLowering,
1796+ TransposeLoadOpLowering>(converter, chipset);
17541797 patterns.add <AMDGPUSwizzleBitModeLowering>(converter);
17551798}
0 commit comments