@@ -903,7 +903,8 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
903903 }
904904};
905905
906- struct GlobalLoadLDSOpLowering : public ConvertOpToLLVMPattern <GlobalLoadLDSOp> {
906+ struct GlobalLoadLDSOpLowering
907+ : public ConvertOpToLLVMPattern<GlobalLoadLDSOp> {
907908 GlobalLoadLDSOpLowering (const LLVMTypeConverter &converter, Chipset chipset)
908909 : ConvertOpToLLVMPattern<GlobalLoadLDSOp>(converter), chipset(chipset) {}
909910
@@ -918,6 +919,10 @@ struct GlobalLoadLDSOpLowering : public ConvertOpToLLVMPattern<GlobalLoadLDSOp>
918919 size_t elemSizeInBits = elemType.getIntOrFloatBitWidth ();
919920 if (elemSizeInBits % 8 != 0 )
920921 return op.emitOpError (" element size must be a multiple of 8" );
922+
923+ // TODO: instead of only transferring one element per thread, we could
924+ // augment it to transfer multiple elements per thread by issuing multiple
925+ // `global_load_lds` instructions.
921926 auto loadWidth = elemSizeInBits / 8 ;
922927
923928 // TODO: add chipset support check
@@ -934,37 +939,41 @@ struct GlobalLoadLDSOpLowering : public ConvertOpToLLVMPattern<GlobalLoadLDSOp>
934939 Value memrefSrc = op.getSrc ();
935940 Value memrefDst = op.getDst ();
936941
937- // Collapse src memref with indices:
938- auto flattenIndex = [&](Value memref, MemRefType memrefType,
939- ValueRange indices) -> std::optional<Value> {
942+ // Collapse src memref with indices, returns the base pointer and linearized
943+ // index.
944+ auto flattenIndex =
945+ [&](Value memref, MemRefType memrefType,
946+ ValueRange indices) -> std::optional<std::pair<Value, Value>> {
940947 MemRefDescriptor memRefDescriptor (memref);
941948 int64_t offset = 0 ;
942949 SmallVector<int64_t , 5 > strides;
943950 if (failed (memrefType.getStridesAndOffset (strides, offset)))
944951 return {};
945- return getLinearIndexI32 (rewriter, loc, memRefDescriptor, indices,
946- strides);
952+ return std::make_pair (
953+ memRefDescriptor.bufferPtr (rewriter, loc, *getTypeConverter (),
954+ memrefType),
955+ getLinearIndexI32 (rewriter, loc, memRefDescriptor, indices, strides));
947956 };
948957
949958 // Source
950- auto optSrcIdx = flattenIndex (src, cast<MemRefType>(memrefSrc.getType ()),
951- op.getSrcIndices ());
952- if (!optSrcIdx )
959+ auto optSrcBuffer = flattenIndex (src, cast<MemRefType>(memrefSrc.getType ()),
960+ op.getSrcIndices ());
961+ if (!optSrcBuffer )
953962 return op.emitOpError (" failed to flatten source memref indices" );
954- auto optDstIdx = flattenIndex (dst, cast<MemRefType>(memrefDst.getType ()),
955- op.getDstIndices ());
956- if (!optDstIdx )
963+ auto optDstBuffer = flattenIndex (dst, cast<MemRefType>(memrefDst.getType ()),
964+ op.getDstIndices ());
965+ if (!optDstBuffer )
957966 return op.emitOpError (" failed to flatten destination memref indices" );
958967
959- Type srcPtrType =
960- LLVM::LLVMPointerType::get (rewriter.getContext (), 1 );
961- Type dstPtrType =
962- LLVM::LLVMPointerType::get (rewriter.getContext (), 3 );
968+ Type srcPtrType = LLVM::LLVMPointerType::get (rewriter.getContext (), 1 );
969+ Type dstPtrType = LLVM::LLVMPointerType::get (rewriter.getContext (), 3 );
963970 Value srcPtr = rewriter.create <LLVM::GEPOp>(
964- loc, srcPtrType, elemType, src, ArrayRef<Value>({*optSrcIdx}));
965-
971+ loc, srcPtrType, elemType, optSrcBuffer->first ,
972+ ArrayRef<Value>({optSrcBuffer->second }));
973+
966974 Value dstPtr = rewriter.create <LLVM::GEPOp>(
967- loc, dstPtrType, elemType, dst, ArrayRef<Value>({*optDstIdx}));
975+ loc, dstPtrType, elemType, optDstBuffer->first ,
976+ ArrayRef<Value>({optDstBuffer->second }));
968977
969978 rewriter.replaceOpWithNewOp <ROCDL::GlobalLoadLDSOp>(
970979 op, srcPtr, dstPtr, createI32Constant (rewriter, loc, loadWidth),
0 commit comments