@@ -903,6 +903,78 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
903903 }
904904};
905905
906+ struct GlobalLoadLDSOpLowering : public ConvertOpToLLVMPattern <GlobalLoadLDSOp> {
907+ GlobalLoadLDSOpLowering (const LLVMTypeConverter &converter, Chipset chipset)
908+ : ConvertOpToLLVMPattern<GlobalLoadLDSOp>(converter), chipset(chipset) {}
909+
910+ Chipset chipset;
911+
912+ LogicalResult
913+ matchAndRewrite (GlobalLoadLDSOp op, GlobalLoadLDSOpAdaptor adaptor,
914+ ConversionPatternRewriter &rewriter) const override {
915+ Location loc = op.getLoc ();
916+
917+ auto elemType = cast<MemRefType>(op.getDst ().getType ()).getElementType ();
918+ size_t elemSizeInBits = elemType.getIntOrFloatBitWidth ();
919+ if (elemSizeInBits % 8 != 0 )
920+ return op.emitOpError (" element size must be a multiple of 8" );
921+ auto loadWidth = elemSizeInBits / 8 ;
922+
923+ // TODO: add chipset support check
924+ if (chipset.majorVersion >= 12 )
925+ return op.emitOpError (" TODO" );
926+
927+ // TODO: fold this into chipset check.
928+ // Currently only 1, 2, and 4 byte loads are supported.
929+ if (!(loadWidth == 1 || loadWidth == 2 || loadWidth == 4 ))
930+ return op.emitOpError (" unsupported element size" );
931+
932+ Value src = adaptor.getSrc ();
933+ Value dst = adaptor.getDst ();
934+ Value memrefSrc = op.getSrc ();
935+ Value memrefDst = op.getDst ();
936+
937+ // Collapse src memref with indices:
938+ auto flattenIndex = [&](Value memref, MemRefType memrefType,
939+ ValueRange indices) -> std::optional<Value> {
940+ MemRefDescriptor memRefDescriptor (memref);
941+ int64_t offset = 0 ;
942+ SmallVector<int64_t , 5 > strides;
943+ if (failed (memrefType.getStridesAndOffset (strides, offset)))
944+ return {};
945+ return getLinearIndexI32 (rewriter, loc, memRefDescriptor, indices,
946+ strides);
947+ };
948+
949+ // Source
950+ auto optSrcIdx = flattenIndex (src, cast<MemRefType>(memrefSrc.getType ()),
951+ op.getSrcIndices ());
952+ if (!optSrcIdx)
953+ return op.emitOpError (" failed to flatten source memref indices" );
954+ auto optDstIdx = flattenIndex (dst, cast<MemRefType>(memrefDst.getType ()),
955+ op.getDstIndices ());
956+ if (!optDstIdx)
957+ return op.emitOpError (" failed to flatten destination memref indices" );
958+
959+ Type srcPtrType =
960+ LLVM::LLVMPointerType::get (rewriter.getContext (), 1 );
961+ Type dstPtrType =
962+ LLVM::LLVMPointerType::get (rewriter.getContext (), 3 );
963+ Value srcPtr = rewriter.create <LLVM::GEPOp>(
964+ loc, srcPtrType, elemType, src, ArrayRef<Value>({*optSrcIdx}));
965+
966+ Value dstPtr = rewriter.create <LLVM::GEPOp>(
967+ loc, dstPtrType, elemType, dst, ArrayRef<Value>({*optDstIdx}));
968+
969+ rewriter.replaceOpWithNewOp <ROCDL::GlobalLoadLDSOp>(
970+ op, srcPtr, dstPtr, createI32Constant (rewriter, loc, loadWidth),
971+ createI32Constant (rewriter, loc, 0 ),
972+ createI32Constant (rewriter, loc, 0 ));
973+
974+ return success ();
975+ }
976+ };
977+
906978namespace {
907979struct ExtPackedFp8OpLowering final
908980 : public ConvertOpToLLVMPattern<ExtPackedFp8Op> {
@@ -1286,6 +1358,6 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
12861358 ROCDL::RawPtrBufferAtomicCmpSwap>,
12871359 AMDGPUDPPLowering, LDSBarrierOpLowering, SchedBarrierOpLowering,
12881360 MFMAOpLowering, WMMAOpLowering, ExtPackedFp8OpLowering,
1289- PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering>(converter ,
1290- chipset);
1361+ PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering,
1362+ GlobalLoadLDSOpLowering>(converter, chipset);
12911363}
0 commit comments