@@ -542,18 +542,20 @@ lowerLdStShared(Location loc, MLIRContext *ctx, LinearLayout cvt,
542542 return unpackLLVector (loc, valsVec, rewriter);
543543 }
544544 };
545+ auto [laneId, warpId] = getLaneAndWarpId (rewriter, loc);
545546 return lowerLdSt (loc, ctx, cvt, valsArray, llvmElemTy, smemBase,
546- calcPaddedOffset, affineOffset, maskSpanAffineOffset,
547- rewriter, targetInfo, {}, emitLdSt);
547+ calcPaddedOffset, affineOffset, maskSpanAffineOffset, laneId,
548+ warpId, rewriter, targetInfo, {}, emitLdSt);
548549}
549550
550551SmallVector<Value> lowerLdSt (
551552 Location loc, MLIRContext *ctx, LinearLayout cvt,
552553 ArrayRef<Value> valsArray, // Input for store, output for load
553554 Type llvmElemTy, Value smemBase,
554555 std::function<Value(Value)> calcPaddedOffset, Value affineOffset,
555- uint64_t maskSpanAffineOffset, RewriterBase &rewriter,
556- const TargetInfoBase &targetInfo, std::optional<int> maybeMaxVecElems,
556+ uint64_t maskSpanAffineOffset, Value laneId, Value warpId,
557+ RewriterBase &rewriter, const TargetInfoBase &targetInfo,
558+ std::optional<int> maybeMaxVecElems,
557559 std::function<SmallVector<Value>(RewriterBase &, Location, ArrayRef<Value>,
558560 Value, int , VectorType)>
559561 lowerInst) {
@@ -599,7 +601,6 @@ SmallVector<Value> lowerLdSt(
599601 zerosLike (LinearLayout::identity1D (bitwidth / 8 , kReg , kOffset ));
600602 auto i8AddrLayout = i8Tile * addrLayout;
601603
602- auto [laneId, warpId] = getLaneAndWarpId (rewriter, loc);
603604 auto regBaseI8 =
604605 applyLinearLayout (
605606 loc, rewriter, i8AddrLayout,
@@ -2022,16 +2023,17 @@ void finalizeTensorAtomicResults(Operation *op, RankedTensorType tensorTy,
20222023 };
20232024
20242025 auto noPaddingOffset = [](Value v) { return v; };
2026+ auto [laneId, warpId] = getLaneAndWarpId (rewriter, loc);
20252027 lowerLdSt (loc, ctx, dstLayout, resultVals, valueElemTy, smemBase,
20262028 /* calcPaddedOffset=*/ noPaddingOffset, /* affineOffset=*/ b.i32_val (0 ),
2027- /* maskSpanAffineOffset=*/ 0 , rewriter, targetInfo,
2029+ /* maskSpanAffineOffset=*/ 0 , laneId, warpId, rewriter, targetInfo,
20282030 /* maybeMaxVecElems=*/ {}, emitSt);
20292031 b.barrier ();
20302032 resultVals = lowerLdSt (loc, ctx, dstLayout, resultVals, valueElemTy, smemBase,
20312033 /* calcPaddedOffset=*/ noPaddingOffset,
20322034 /* affineOffset=*/ b.i32_val (0 ),
2033- /* maskSpanAffineOffset=*/ 0 , rewriter, targetInfo ,
2034- /* maybeMaxVecElems=*/ {}, emitLd);
2035+ /* maskSpanAffineOffset=*/ 0 , laneId, warpId, rewriter ,
2036+ targetInfo, /* maybeMaxVecElems=*/ {}, emitLd);
20352037
20362038 // Create the result struct and replace the operation
20372039 Value resultStruct =
0 commit comments