@@ -283,8 +283,8 @@ struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
283283
284284 auto srcMemrefType = cast<MemRefType>(op.getSrcMemref ().getType ());
285285 Value srcPtr =
286- getStridedElementPtr (b.getLoc (), srcMemrefType, adaptor. getSrcMemref () ,
287- adaptor.getIndices (), rewriter );
286+ getStridedElementPtr (rewriter, b.getLoc (), srcMemrefType,
287+ adaptor.getSrcMemref (), adaptor. getIndices () );
288288 Value ldMatrixResult = b.create <NVVM::LdMatrixOp>(
289289 ldMatrixResultType, srcPtr,
290290 /* num=*/ op.getNumTiles (),
@@ -661,8 +661,8 @@ struct NVGPUAsyncCopyLowering
661661 Location loc = op.getLoc ();
662662 auto dstMemrefType = cast<MemRefType>(op.getDst ().getType ());
663663 Value dstPtr =
664- getStridedElementPtr (b.getLoc (), dstMemrefType, adaptor. getDst () ,
665- adaptor.getDstIndices (), rewriter );
664+ getStridedElementPtr (rewriter, b.getLoc (), dstMemrefType,
665+ adaptor.getDst (), adaptor. getDstIndices () );
666666 FailureOr<unsigned > dstAddressSpace =
667667 getTypeConverter ()->getMemRefAddressSpace (dstMemrefType);
668668 if (failed (dstAddressSpace))
@@ -676,8 +676,9 @@ struct NVGPUAsyncCopyLowering
676676 return rewriter.notifyMatchFailure (
677677 loc, " source memref address space not convertible to integer" );
678678
679- Value scrPtr = getStridedElementPtr (loc, srcMemrefType, adaptor.getSrc (),
680- adaptor.getSrcIndices (), rewriter);
679+ Value scrPtr =
680+ getStridedElementPtr (rewriter, loc, srcMemrefType, adaptor.getSrc (),
681+ adaptor.getSrcIndices ());
681682 // Intrinsics takes a global pointer so we need an address space cast.
682683 auto srcPointerGlobalType = LLVM::LLVMPointerType::get (
683684 op->getContext (), NVVM::NVVMMemorySpace::kGlobalMemorySpace );
@@ -814,7 +815,7 @@ struct MBarrierBasePattern : public ConvertOpToLLVMPattern<SourceOp> {
814815 MemRefType mbarrierMemrefType =
815816 nvgpu::getMBarrierMemrefType (rewriter.getContext (), mbarType);
816817 return ConvertToLLVMPattern::getStridedElementPtr (
817- b.getLoc (), mbarrierMemrefType, memrefDesc, {mbarId}, rewriter );
818+ rewriter, b.getLoc (), mbarrierMemrefType, memrefDesc, {mbarId});
818819 }
819820};
820821
@@ -995,8 +996,8 @@ struct NVGPUTmaAsyncLoadOpLowering
995996 ConversionPatternRewriter &rewriter) const override {
996997 ImplicitLocOpBuilder b (op->getLoc (), rewriter);
997998 auto srcMemrefType = cast<MemRefType>(op.getDst ().getType ());
998- Value dest = getStridedElementPtr (op->getLoc (), srcMemrefType,
999- adaptor.getDst (), {}, rewriter );
999+ Value dest = getStridedElementPtr (rewriter, op->getLoc (), srcMemrefType,
1000+ adaptor.getDst (), {});
10001001 Value barrier =
10011002 getMbarrierPtr (b, op.getBarriers ().getType (), adaptor.getBarriers (),
10021003 adaptor.getMbarId (), rewriter);
@@ -1021,8 +1022,8 @@ struct NVGPUTmaAsyncStoreOpLowering
10211022 ConversionPatternRewriter &rewriter) const override {
10221023 ImplicitLocOpBuilder b (op->getLoc (), rewriter);
10231024 auto srcMemrefType = cast<MemRefType>(op.getSrc ().getType ());
1024- Value dest = getStridedElementPtr (op->getLoc (), srcMemrefType,
1025- adaptor.getSrc (), {}, rewriter );
1025+ Value dest = getStridedElementPtr (rewriter, op->getLoc (), srcMemrefType,
1026+ adaptor.getSrc (), {});
10261027 SmallVector<Value> coords = adaptor.getCoordinates ();
10271028 for (auto [index, value] : llvm::enumerate (coords)) {
10281029 coords[index] = truncToI32 (b, value);
@@ -1083,8 +1084,8 @@ struct NVGPUGenerateWarpgroupDescriptorLowering
10831084 Value leadDim = makeConst (leadDimVal);
10841085
10851086 Value baseAddr = getStridedElementPtr (
1086- op->getLoc (), cast<MemRefType>(op.getTensor ().getType ()),
1087- adaptor.getTensor (), {}, rewriter );
1087+ rewriter, op->getLoc (), cast<MemRefType>(op.getTensor ().getType ()),
1088+ adaptor.getTensor (), {});
10881089 Value basePtr = b.create <LLVM::PtrToIntOp>(ti64, baseAddr);
10891090 // Just use 14 bits for base address
10901091 Value basePtr14bit = shiftRight (shiftLeft (basePtr, 46 ), 50 );
0 commit comments