@@ -283,8 +283,8 @@ struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
283283
284284    auto  srcMemrefType = cast<MemRefType>(op.getSrcMemref ().getType ());
285285    Value srcPtr =
286-         getStridedElementPtr (b.getLoc (), srcMemrefType, adaptor. getSrcMemref () ,
287-                              adaptor.getIndices (), rewriter );
286+         getStridedElementPtr (rewriter,  b.getLoc (), srcMemrefType,
287+                              adaptor.getSrcMemref (), adaptor. getIndices () );
288288    Value ldMatrixResult = b.create <NVVM::LdMatrixOp>(
289289        ldMatrixResultType, srcPtr,
290290        /* num=*/  op.getNumTiles (),
@@ -661,8 +661,8 @@ struct NVGPUAsyncCopyLowering
661661    Location loc = op.getLoc ();
662662    auto  dstMemrefType = cast<MemRefType>(op.getDst ().getType ());
663663    Value dstPtr =
664-         getStridedElementPtr (b.getLoc (), dstMemrefType, adaptor. getDst () ,
665-                              adaptor.getDstIndices (), rewriter );
664+         getStridedElementPtr (rewriter,  b.getLoc (), dstMemrefType,
665+                              adaptor.getDst (), adaptor. getDstIndices () );
666666    FailureOr<unsigned > dstAddressSpace =
667667        getTypeConverter ()->getMemRefAddressSpace (dstMemrefType);
668668    if  (failed (dstAddressSpace))
@@ -676,8 +676,9 @@ struct NVGPUAsyncCopyLowering
676676      return  rewriter.notifyMatchFailure (
677677          loc, " source memref address space not convertible to integer"  );
678678
679-     Value scrPtr = getStridedElementPtr (loc, srcMemrefType, adaptor.getSrc (),
680-                                         adaptor.getSrcIndices (), rewriter);
679+     Value scrPtr =
680+         getStridedElementPtr (rewriter, loc, srcMemrefType, adaptor.getSrc (),
681+                              adaptor.getSrcIndices ());
681682    //  Intrinsics takes a global pointer so we need an address space cast.
682683    auto  srcPointerGlobalType = LLVM::LLVMPointerType::get (
683684        op->getContext (), NVVM::NVVMMemorySpace::kGlobalMemorySpace );
@@ -814,7 +815,7 @@ struct MBarrierBasePattern : public ConvertOpToLLVMPattern<SourceOp> {
814815    MemRefType mbarrierMemrefType =
815816        nvgpu::getMBarrierMemrefType (rewriter.getContext (), mbarType);
816817    return  ConvertToLLVMPattern::getStridedElementPtr (
817-         b.getLoc (), mbarrierMemrefType, memrefDesc, {mbarId}, rewriter );
818+         rewriter,  b.getLoc (), mbarrierMemrefType, memrefDesc, {mbarId});
818819  }
819820};
820821
@@ -995,8 +996,8 @@ struct NVGPUTmaAsyncLoadOpLowering
995996                  ConversionPatternRewriter &rewriter) const  override  {
996997    ImplicitLocOpBuilder b (op->getLoc (), rewriter);
997998    auto  srcMemrefType = cast<MemRefType>(op.getDst ().getType ());
998-     Value dest = getStridedElementPtr (op->getLoc (), srcMemrefType,
999-                                       adaptor.getDst (), {}, rewriter );
999+     Value dest = getStridedElementPtr (rewriter,  op->getLoc (), srcMemrefType,
1000+                                       adaptor.getDst (), {});
10001001    Value barrier =
10011002        getMbarrierPtr (b, op.getBarriers ().getType (), adaptor.getBarriers (),
10021003                       adaptor.getMbarId (), rewriter);
@@ -1021,8 +1022,8 @@ struct NVGPUTmaAsyncStoreOpLowering
10211022                  ConversionPatternRewriter &rewriter) const  override  {
10221023    ImplicitLocOpBuilder b (op->getLoc (), rewriter);
10231024    auto  srcMemrefType = cast<MemRefType>(op.getSrc ().getType ());
1024-     Value dest = getStridedElementPtr (op->getLoc (), srcMemrefType,
1025-                                       adaptor.getSrc (), {}, rewriter );
1025+     Value dest = getStridedElementPtr (rewriter,  op->getLoc (), srcMemrefType,
1026+                                       adaptor.getSrc (), {});
10261027    SmallVector<Value> coords = adaptor.getCoordinates ();
10271028    for  (auto  [index, value] : llvm::enumerate (coords)) {
10281029      coords[index] = truncToI32 (b, value);
@@ -1083,8 +1084,8 @@ struct NVGPUGenerateWarpgroupDescriptorLowering
10831084    Value leadDim = makeConst (leadDimVal);
10841085
10851086    Value baseAddr = getStridedElementPtr (
1086-         op->getLoc (), cast<MemRefType>(op.getTensor ().getType ()),
1087-         adaptor.getTensor (), {}, rewriter );
1087+         rewriter,  op->getLoc (), cast<MemRefType>(op.getTensor ().getType ()),
1088+         adaptor.getTensor (), {});
10881089    Value basePtr = b.create <LLVM::PtrToIntOp>(ti64, baseAddr);
10891090    //  Just use 14 bits for base address
10901091    Value basePtr14bit = shiftRight (shiftLeft (basePtr, 46 ), 50 );
0 commit comments