@@ -913,60 +913,49 @@ struct GatherToLDSOpLowering
913913 LogicalResult
914914 matchAndRewrite (GatherToLDSOp op, GatherToLDSOpAdaptor adaptor,
915915 ConversionPatternRewriter &rewriter) const override {
916+ if (chipset < kGfx942 )
917+ return op.emitOpError (" chipset not supported" );
918+
916919 Location loc = op.getLoc ();
917920
918- auto elemType = cast<MemRefType>(op.getDst ().getType ()).getElementType ();
919- size_t elemSizeInBits = elemType.getIntOrFloatBitWidth ();
920- if (elemSizeInBits % 8 != 0 )
921- return op.emitOpError (" element size must be a multiple of 8" );
921+ auto srcMemRefType = cast<MemRefType>(op.getSrc ().getType ());
922+ auto dstMemRefType = cast<MemRefType>(op.getSrc ().getType ());
922923
923924 // TODO: instead of only transfering one element per thread, we could
924925 // augment it to transfer multiple elements per thread by issuing multiple
925926 // `global_load_lds` instructions.
926- auto loadWidth = elemSizeInBits / 8 ;
927-
928- if (chipset < kGfx942 )
929- return op.emitOpError (" chipset not supported" );
927+ size_t loadWidth;
928+ Type transferType = op.getTransferType ();
929+ if (auto transferVectorType = dyn_cast<VectorType>(transferType))
930+ loadWidth = transferVectorType.getNumElements () *
931+ transferVectorType.getElementTypeBitWidth () / 8 ;
932+ else
933+ loadWidth = transferType.getIntOrFloatBitWidth () / 8 ;
930934
931935 // Currently only 1, 2, and 4 byte loads are supported.
932- if (!( loadWidth == 1 || loadWidth == 2 || loadWidth == 4 ) )
936+ if (loadWidth != 1 && loadWidth != 2 && loadWidth != 4 )
933937 return op.emitOpError (" chipset unsupported element size" );
934938
935- // Return pair of {base pointer, linearized index}.
936- auto getBasePtrAndLinearizedIndex =
937- [&](Value memref, MemRefType memrefType,
938- ValueRange indices) -> std::optional<std::pair<Value, Value>> {
939- MemRefDescriptor memRefDescriptor (memref);
940- int64_t offset = 0 ;
941- SmallVector<int64_t , 5 > strides;
942- if (failed (memrefType.getStridesAndOffset (strides, offset)))
943- return {};
944- return std::make_pair (
945- memRefDescriptor.bufferPtr (rewriter, loc, *getTypeConverter (),
946- memrefType),
947- getLinearIndexI32 (rewriter, loc, memRefDescriptor, indices, strides));
939+ auto convertIndices =
940+ [&](ValueRange indices) -> SmallVector<Value, 4 > {
941+ SmallVector<Value, 4 > convertedIndices;
942+
943+ for (Value index : indices) {
944+ Type convertedType = getTypeConverter ()->convertType (index.getType ());
945+ auto convertedIndex = rewriter.create <LLVM::ConstantOp>(
946+ loc, convertedType,
947+ rewriter.getIntegerAttr (convertedType, 0 ));
948+ convertedIndices.push_back (convertedIndex);
949+ }
950+ return convertedIndices;
948951 };
949952
950- auto optSrcBuffer = getBasePtrAndLinearizedIndex (
951- adaptor.getSrc (), cast<MemRefType>(op.getSrc ().getType ()),
952- op.getSrcIndices ());
953- if (!optSrcBuffer)
954- return op.emitOpError (" failed to flatten source memref indices" );
955- auto optDstBuffer = getBasePtrAndLinearizedIndex (
956- adaptor.getDst (), cast<MemRefType>(op.getDst ().getType ()),
957- op.getDstIndices ());
958- if (!optDstBuffer)
959- return op.emitOpError (" failed to flatten destination memref indices" );
960-
961- Type srcPtrType = LLVM::LLVMPointerType::get (rewriter.getContext (), 1 );
962- Type dstPtrType = LLVM::LLVMPointerType::get (rewriter.getContext (), 3 );
963- Value srcPtr = rewriter.create <LLVM::GEPOp>(
964- loc, srcPtrType, elemType, optSrcBuffer->first ,
965- ArrayRef<Value>({optSrcBuffer->second }));
966-
967- Value dstPtr = rewriter.create <LLVM::GEPOp>(
968- loc, dstPtrType, elemType, optDstBuffer->first ,
969- ArrayRef<Value>({optDstBuffer->second }));
953+ Value srcPtr =
954+ getStridedElementPtr (loc, srcMemRefType, adaptor.getSrc (),
955+ convertIndices (op.getSrcIndices ()), rewriter);
956+ Value dstPtr =
957+ getStridedElementPtr (loc, dstMemRefType, adaptor.getDst (),
958+ convertIndices (op.getDstIndices ()), rewriter);
970959
971960 rewriter.replaceOpWithNewOp <ROCDL::GlobalLoadLDSOp>(
972961 op, srcPtr, dstPtr, createI32Constant (rewriter, loc, loadWidth),
0 commit comments