@@ -589,18 +589,20 @@ lowerLdStShared(Location loc, MLIRContext *ctx, LinearLayout cvt,
589589 return unpackLLVector (loc, valsVec, rewriter);
590590 }
591591 };
592+ auto [laneId, warpId] = getLaneAndWarpId (rewriter, loc);
592593 return lowerLdSt (loc, ctx, cvt, valsArray, llvmElemTy, smemBase,
593- calcPaddedOffset, affineOffset, maskSpanAffineOffset,
594- rewriter, targetInfo, {}, emitLdSt);
594+ calcPaddedOffset, affineOffset, maskSpanAffineOffset, laneId,
595+ warpId, rewriter, targetInfo, {}, emitLdSt);
595596}
596597
597598SmallVector<Value> lowerLdSt (
598599 Location loc, MLIRContext *ctx, LinearLayout cvt,
599600 ArrayRef<Value> valsArray, // Input for store, output for load
600601 Type llvmElemTy, Value smemBase,
601602 std::function<Value(Value)> calcPaddedOffset, Value affineOffset,
602- uint64_t maskSpanAffineOffset, RewriterBase &rewriter,
603- const TargetInfoBase &targetInfo, std::optional<int> maybeMaxVecElems,
603+ uint64_t maskSpanAffineOffset, Value laneId, Value warpId,
604+ RewriterBase &rewriter, const TargetInfoBase &targetInfo,
605+ std::optional<int> maybeMaxVecElems,
604606 std::function<SmallVector<Value>(RewriterBase &, Location, ArrayRef<Value>,
605607 Value, int , VectorType)>
606608 lowerInst) {
@@ -646,7 +648,6 @@ SmallVector<Value> lowerLdSt(
646648 zerosLike (LinearLayout::identity1D (bitwidth / 8 , kReg , kOffset ));
647649 auto i8AddrLayout = i8Tile * addrLayout;
648650
649- auto [laneId, warpId] = getLaneAndWarpId (rewriter, loc);
650651 auto regBaseI8 =
651652 applyLinearLayout (
652653 loc, rewriter, i8AddrLayout,
@@ -1689,16 +1690,17 @@ void finalizeTensorAtomicResults(Operation *op, RankedTensorType tensorTy,
16891690 };
16901691
16911692 auto noPaddingOffset = [](Value v) { return v; };
1693+ auto [laneId, warpId] = getLaneAndWarpId (rewriter, loc);
16921694 lowerLdSt (loc, ctx, dstLayout, resultVals, valueElemTy, smemBase,
16931695 /* calcPaddedOffset=*/ noPaddingOffset, /* affineOffset=*/ b.i32_val (0 ),
1694- /* maskSpanAffineOffset=*/ 0 , rewriter, targetInfo,
1696+ /* maskSpanAffineOffset=*/ 0 , laneId, warpId, rewriter, targetInfo,
16951697 /* maybeMaxVecElems=*/ {}, emitSt);
16961698 b.barrier ();
16971699 resultVals = lowerLdSt (loc, ctx, dstLayout, resultVals, valueElemTy, smemBase,
16981700 /* calcPaddedOffset=*/ noPaddingOffset,
16991701 /* affineOffset=*/ b.i32_val (0 ),
1700- /* maskSpanAffineOffset=*/ 0 , rewriter, targetInfo ,
1701- /* maybeMaxVecElems=*/ {}, emitLd);
1702+ /* maskSpanAffineOffset=*/ 0 , laneId, warpId, rewriter ,
1703+ targetInfo, /* maybeMaxVecElems=*/ {}, emitLd);
17021704
17031705 // Create the result struct and replace the operation
17041706 Value resultStruct =
0 commit comments