@@ -1100,6 +1100,81 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
11001100 }
11011101};
11021102
1103+ struct TransposeLoadOpLowering
1104+ : public ConvertOpToLLVMPattern<TransposeLoadOp> {
1105+ TransposeLoadOpLowering (const LLVMTypeConverter &converter, Chipset chipset)
1106+ : ConvertOpToLLVMPattern<TransposeLoadOp>(converter), chipset(chipset) {}
1107+
1108+ Chipset chipset;
1109+
1110+ LogicalResult
1111+ matchAndRewrite (TransposeLoadOp op, TransposeLoadOpAdaptor adaptor,
1112+ ConversionPatternRewriter &rewriter) const override {
1113+ if (chipset != kGfx950 )
1114+ return op.emitOpError (" Non-gfx950 chipset not supported" );
1115+
1116+ Location loc = op.getLoc ();
1117+ auto srcMemRefType = cast<MemRefType>(op.getSrc ().getType ());
1118+
1119+ // Elements in subbyte memrefs are stored non-contiguously,
1120+ // reject if source is sub-byte memref. Use emulated memrefs instead.
1121+ size_t srcElementSize =
1122+ srcMemRefType.getElementType ().getIntOrFloatBitWidth ();
1123+ if (srcElementSize < 8 )
1124+ return op.emitOpError (" Expect source memref to have at least 8 bits "
1125+ " element size, got " )
1126+ << srcElementSize;
1127+
1128+ auto resultType = cast<VectorType>(op.getResult ().getType ());
1129+ Value srcPtr =
1130+ getStridedElementPtr (rewriter, loc, srcMemRefType, adaptor.getSrc (),
1131+ (adaptor.getSrcIndices ()));
1132+
1133+ size_t numElements = resultType.getNumElements ();
1134+ size_t elementTypeSize =
1135+ resultType.getElementType ().getIntOrFloatBitWidth ();
1136+
1137+ // ROCDL transpose load intrinsics return vectors of 32-bit integers, if
1138+ // the element size is smaller than 16 bits.
1139+ Type rocdlResultType = VectorType::get ((numElements * elementTypeSize) / 32 ,
1140+ rewriter.getIntegerType (32 ));
1141+ Type llvmResultType = typeConverter->convertType (resultType);
1142+
1143+ switch (elementTypeSize) {
1144+ case 4 : {
1145+ assert (numElements == 16 );
1146+ auto rocdlOp =
1147+ rewriter.create <ROCDL::ds_read_tr4_b64>(loc, rocdlResultType, srcPtr);
1148+ rewriter.replaceOpWithNewOp <LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
1149+ break ;
1150+ }
1151+ case 6 : {
1152+ assert (numElements == 16 );
1153+ auto rocdlOp =
1154+ rewriter.create <ROCDL::ds_read_tr6_b96>(loc, rocdlResultType, srcPtr);
1155+ rewriter.replaceOpWithNewOp <LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
1156+ break ;
1157+ }
1158+ case 8 : {
1159+ assert (numElements == 8 );
1160+ auto rocdlOp =
1161+ rewriter.create <ROCDL::ds_read_tr8_b64>(loc, rocdlResultType, srcPtr);
1162+ rewriter.replaceOpWithNewOp <LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
1163+ break ;
1164+ }
1165+ case 16 : {
1166+ assert (numElements == 4 );
1167+ rewriter.replaceOpWithNewOp <ROCDL::ds_read_tr16_b64>(op, llvmResultType,
1168+ srcPtr);
1169+ break ;
1170+ }
1171+ default :
1172+ return op.emitOpError (" Unsupported element size for transpose load" );
1173+ }
1174+ return success ();
1175+ }
1176+ };
1177+
11031178struct GatherToLDSOpLowering : public ConvertOpToLLVMPattern <GatherToLDSOp> {
11041179 GatherToLDSOpLowering (const LLVMTypeConverter &converter, Chipset chipset)
11051180 : ConvertOpToLLVMPattern<GatherToLDSOp>(converter), chipset(chipset) {}
@@ -1749,7 +1824,7 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
17491824 MFMAOpLowering, ScaledMFMAOpLowering, WMMAOpLowering,
17501825 ExtPackedFp8OpLowering, ScaledExtPackedOpLowering,
17511826 PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering,
1752- PackedStochRoundFp8OpLowering, GatherToLDSOpLowering>(converter ,
1753- chipset);
1827+ PackedStochRoundFp8OpLowering, GatherToLDSOpLowering,
1828+ TransposeLoadOpLowering>(converter, chipset);
17541829 patterns.add <AMDGPUSwizzleBitModeLowering>(converter);
17551830}
0 commit comments