@@ -169,10 +169,139 @@ emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
   return ret;
 }
 
+namespace {
+
+Value getSmemVecAddr(RankedTensorType registerTy,
+                     triton::gpu::MemDescType sharedTy, Type elemLlvmTy,
+                     Location loc, RewriterBase &rewriter,
+                     const LinearLayout &regToSharedLayout, Value regId,
+                     Value laneId, Value warpId,
+                     const SharedMemoryObject &smemObj) {
+  MLIRContext *ctx = rewriter.getContext();
+  StringAttr kBlock = str_attr("block");
+  StringAttr kRegister = str_attr("register");
+  StringAttr kLane = str_attr("lane");
+  StringAttr kWarp = str_attr("warp");
+  auto shape = sharedTy.getShape();
+  auto rank = shape.size();
+  auto allocShape = sharedTy.getAllocShape();
+  auto sharedEnc =
+      dyn_cast<triton::gpu::SharedEncodingAttr>(sharedTy.getEncoding());
+
+  auto smemBase = smemObj.getBase();
+  auto sharedOrder = triton::gpu::getOrder(sharedTy.getEncoding());
+  auto smemOffsets = smemObj.getOffsets();
+  auto smemStrides = smemObj.getStrides();
+  Value smemOffset;
+  // When loading or storing to shared memory, we consider two cases for
+  // performance reasons:
+  //
+  // 1. Non-swizzled shared memory.
+  // 2. Swizzled shared memory.
+  //
+  // Consider lowering `ttg.local_load %a`. In the first case, we can
+  // directly construct a linear layout using `%a`'s shape and shared memory
+  // encoding, irrespective of `%a`'s rank or whether it represents a slice
+  // of a larger tensor.
+  //
+  // This method does not always apply to swizzled shared memory. Key
+  // properties of swizzling in Triton are:
+  //
+  // - Swizzling applies only to tensors with rank >= 2.
+  // - It is restricted to the last two dimensions of the tensor.
+  // - These last two dimensions are always treated as the most "minor".
+  //
+  // An important edge case arises when `%a` results from
+  // `%a = ttg.subview %b`, where `%b` is swizzled (and so is `%a`). Here,
+  // constructing a layout and determining shared memory offsets from `%a`'s
+  // shape alone is incorrect, because swizzling depends on the original
+  // shape of `%b`, which differs from `%a`'s shape. As a result, some
+  // locations may fall outside `%a`'s contiguous view of memory: an element
+  // `[i (row_idx), j (col_idx)]` in `%a` might map to `[i, j']` after
+  // swizzling, where `j'` lies outside `%a`'s shape but still within `%b`'s
+  // shape. (A standalone sketch of this edge case follows the function.)
+  //
+  // Case 2 (see comments below) provides a more general solution covering
+  // all swizzled shared memory scenarios, including this edge case.
+  if (/*no swizzling*/ sharedEnc.getMaxPhase() == 1 ||
+      /*swizzling but same shape*/ shape == allocShape ||
+      /*swizzling and rank-reduced and rank >= 2*/
+      (shape == allocShape.take_back(rank) && rank >= 2)) { // Case 1
+    // Get the address to load/store. The multi-dim address is (offsetX1, ...,
+    // offsetXN, block), where the offsets appear in minor-to-major order, and
+    // we drop_end to drop block, which we know from above will be 0.
+    smemOffsets = llvm::to_vector(llvm::drop_end(llvm::make_second_range(
+        applyLinearLayout(loc, rewriter, regToSharedLayout,
+                          {{kRegister, regId},
+                           {kLane, laneId},
+                           {kWarp, warpId},
+                           {kBlock, i32_val(0)}}))));
+    // Reorder strides according to `order`. This way they match the
+    // multi-dimensional offsets in regToSharedLayout.
+    smemOffset = dot(rewriter, loc, smemOffsets,
+                     applyPermutation(smemStrides, sharedOrder));
+  } else { // Case 2 -> rank-reduced swizzling
+    assert(rank >= 2 && "Swizzling only applies to tensors with rank >= 2");
+    // We define both tensor offsets and shared memory offsets:
+    //
+    // - Tensor offsets: relative offsets within a given tensor.
+    // - Shared memory offsets: absolute offsets within the shared memory.
+    //
+    // In Triton, the shared memory layout provides an invertible, one-to-one
+    // mapping between tensor offsets and shared memory offsets. The `base`
+    // field of any shared memory object represents both the shared memory
+    // offset and the tensor offset relative to the original tensor at
+    // allocation, prior to any subview operations.
+    //
+    // To determine the shared memory offset of a specific register when
+    // dealing with swizzled and sliced tensors, we:
+    //
+    // 1. Retrieve the original tensor's `invertAllocSharedLayout`, which
+    //    maps the allocated tensor's offsets back to shared memory offsets.
+    // 2. Reconstruct the register's offsets in the allocated tensor by
+    //    summing:
+    //    - the shared memory offsets of the current view's base, and
+    //    - the relative tensor offsets of the register.
+    //
+    // This ensures that "absolute" tensor offsets can be mapped to the
+    // correct shared memory addresses using `invertAllocSharedLayout`. (A
+    // scalar sketch of this arithmetic follows the function.)
+    std::optional<LinearLayout> regLayout =
+        triton::gpu::toLinearLayout(shape, registerTy.getEncoding());
+    auto allocSharedLayout = triton::gpu::toLinearLayout(
+        allocShape.take_back(rank), sharedTy.getEncoding(),
+        elemLlvmTy.getIntOrFloatBitWidth());
+    assert(allocSharedLayout.has_value() &&
+           "Failed to convert layout to linear layout");
+    auto invertAllocSharedLayout = allocSharedLayout->invert();
+    auto multiDimTensorOffsets =
+        llvm::to_vector(applyLinearLayout(loc, rewriter, *regLayout,
+                                          {{kRegister, regId},
+                                           {kLane, laneId},
+                                           {kWarp, warpId},
+                                           {kBlock, i32_val(0)}}));
+    for (auto i = 0; i < rank; i++) {
+      multiDimTensorOffsets[i].second =
+          add(multiDimTensorOffsets[i].second, smemOffsets[i]);
+    }
+    smemOffset = applyLinearLayout(loc, rewriter, invertAllocSharedLayout,
+                                   multiDimTensorOffsets)[0]
+                     .second;
+    Value baseToAllocBaseDist = dot(rewriter, loc, smemOffsets, smemStrides);
+    smemOffset = sub(smemOffset, baseToAllocBaseDist);
+  }
+  auto ptrTy = smemBase.getType();
+  auto vecAddr = gep(ptrTy, elemLlvmTy, smemBase, smemOffset);
+  vecAddr.setInbounds(true);
+  return vecAddr;
+}
+
+} // namespace
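
To make the swizzled-subview edge case above concrete, here is a minimal
standalone sketch. The XOR-based vec/phase swizzle below is a simplified
assumption in the spirit of Triton's SharedEncodingAttr, not the exact
formula, and all parameter values are hypothetical:

    #include <cassert>
    #include <cstdio>

    // Toy swizzle: XOR the vectorized column index with a row-dependent
    // phase, as Triton-style shared encodings do (simplified).
    int swizzleCol(int row, int col, int vec, int perPhase, int maxPhase) {
      int phase = (row / perPhase) % maxPhase;
      return ((col / vec) ^ phase) * vec + (col % vec);
    }

    int main() {
      // %b is allocated 16x16 and swizzled; %a = ttg.subview %b is 16x8.
      // Swizzling is defined by %b's (alloc) shape, so an element of %a can
      // land outside %a's 8-wide view while staying inside %b.
      int subviewCols = 8, allocCols = 16;
      int row = 4, col = 0; // element [4, 0] of %a
      int j = swizzleCol(row, col, /*vec=*/2, /*perPhase=*/1, /*maxPhase=*/8);
      // phase = 4, so j = ((0 / 2) ^ 4) * 2 + 0 = 8: outside %a's view.
      assert(j >= subviewCols && j < allocCols);
      std::printf("[%d, %d] swizzles to column %d\n", row, col, j);
      return 0;
    }

This is why Case 1's direct layout construction from %a's shape is unsafe
here: the swizzled offset is only meaningful relative to %b's allocation
shape.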
+
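A scalar sketch of the Case 2 arithmetic may also help. The linear layouts
are stood in by plain integer functions; the strides, offsets, and row-major
unswizzled alloc layout are all hypothetical simplifications:

    #include <array>
    #include <cstdio>

    // Stand-in for invertAllocSharedLayout: maps absolute tensor offsets in
    // the allocated (16-column, row-major) tensor to shared memory offsets.
    // Unswizzled here so the mapping is easy to follow; the real code
    // handles swizzled layouts too.
    constexpr int kAllocCols = 16;

    int invertAllocSharedLayout(std::array<int, 2> tensorOffsets) {
      return tensorOffsets[0] * kAllocCols + tensorOffsets[1];
    }

    int main() {
      std::array<int, 2> viewBase = {2, 4}; // smemObj.getOffsets() of view
      std::array<int, 2> regOffs = {1, 3};  // register offsets within view
      // Step 2 of the comment: absolute tensor offsets = view base offsets
      // plus the register's relative tensor offsets.
      std::array<int, 2> absOffs = {viewBase[0] + regOffs[0],
                                    viewBase[1] + regOffs[1]};
      // Step 1: map absolute tensor offsets to a shared memory offset.
      int smemOffset = invertAllocSharedLayout(absOffs); // 3 * 16 + 7 = 55
      // smemBase already points at the view's base, so subtract the
      // base-to-alloc-base distance, i.e. dot(offsets, strides).
      int baseToAllocBaseDist = viewBase[0] * kAllocCols + viewBase[1]; // 36
      smemOffset -= baseToAllocBaseDist;
      std::printf("offset from the view's base: %d\n", smemOffset); // 19
      return 0;
    }
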
 bool emitTransferBetweenRegistersAndShared(
     RankedTensorType registerTy, triton::gpu::MemDescType sharedTy,
-    Type elemLlvmTy, std::optional<int32_t> maxVecElems, Value shmemBase,
-    ArrayRef<Value> shmemStrides, Location loc, RewriterBase &rewriter,
+    Type elemLlvmTy, std::optional<int32_t> maxVecElems,
+    const SharedMemoryObject &smemObj, Location loc, RewriterBase &rewriter,
     const TargetInfoBase &target,
     std::function<void(VectorType, Value /*shmemAddr*/)> perVectorCallback) {
   MLIRContext *ctx = rewriter.getContext();
@@ -230,28 +359,12 @@ bool emitTransferBetweenRegistersAndShared(
 
   int numElems = regToSharedLayout->getInDimSize(kRegister);
   auto vecTy = vec_ty(elemLlvmTy, vecElems);
-  auto ptrTy = shmemBase.getType();
   Value zero = i32_val(0);
   SmallVector<Value> ret;
   for (int i = 0; i < numElems / vecElems; i++) {
-    // Get the address to load/store. The multi-dim address is (offsetX1, ...,
-    // offsetXN, block), where the offsets appear in minor-to-major order, and
-    // we drop_end to drop block, which we know from above will be 0.
-    auto multiDimShmemOffset =
-        llvm::to_vector(llvm::drop_end(llvm::make_second_range(
-            applyLinearLayout(loc, rewriter, *regToSharedLayout,
-                              {{kRegister, i32_val(i * vecElems)},
-                               {kLane, laneId},
-                               {kWarp, warpId},
-                               {kBlock, zero}}))));
-
-    // Reorder strides according to `order`. This way they match the
-    // multi-dimensional offsets in regToSharedLayout.
-    auto sharedOrder = triton::gpu::getOrder(sharedTy.getEncoding());
-    Value shmemOffset = dot(rewriter, loc, multiDimShmemOffset,
-                            applyPermutation(shmemStrides, sharedOrder));
-    auto vecAddr = gep(ptrTy, elemLlvmTy, shmemBase, shmemOffset);
-    vecAddr.setInbounds(true);
+    auto vecAddr = getSmemVecAddr(
+        registerTy, sharedTy, elemLlvmTy, loc, rewriter, *regToSharedLayout,
+        i32_val(i * vecElems), laneId, warpId, smemObj);
 
     perVectorCallback(vecTy, vecAddr);
   }
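
For context on the new parameter threaded through these signatures:
SharedMemoryObject bundles the base pointer with per-dimension strides and
offsets, which is why the separate (base, strides) arguments collapse into a
single smemObj. Roughly, as a sketch of the accessors used here (the member
layout is an assumption; only the accessors are taken from this diff, and
this is not the actual declaration):

    #include "mlir/IR/Value.h"
    #include "llvm/ADT/ArrayRef.h"
    #include "llvm/ADT/SmallVector.h"

    // Sketch of SharedMemoryObject as consumed by the code above.
    struct SharedMemoryObjectSketch {
      mlir::Value base; // pointer to the view's first element
      llvm::SmallVector<mlir::Value> strides; // strides of the allocation
      llvm::SmallVector<mlir::Value> offsets; // view offsets from alloc base
      mlir::Value getBase() const { return base; }
      llvm::ArrayRef<mlir::Value> getStrides() const { return strides; }
      llvm::ArrayRef<mlir::Value> getOffsets() const { return offsets; }
    };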
@@ -261,14 +374,13 @@ bool emitTransferBetweenRegistersAndShared(
 SmallVector<Value> loadSharedToDistributed(RankedTensorType dstTy,
                                            triton::gpu::MemDescType srcTy,
                                            Type elemLlvmTy,
-                                           SharedMemoryObject smemObj,
+                                           const SharedMemoryObject &smemObj,
                                            Location loc, RewriterBase &rewriter,
                                            const TargetInfoBase &target) {
   SmallVector<Value> ret;
   bool success = emitTransferBetweenRegistersAndShared(
-      dstTy, srcTy, elemLlvmTy, /*maxVecElems=*/std::nullopt, smemObj.getBase(),
-      smemObj.getStrides(), loc, rewriter, target,
-      [&](VectorType vecTy, Value vecAddr) {
+      dstTy, srcTy, elemLlvmTy, /*maxVecElems=*/std::nullopt, smemObj, loc,
+      rewriter, target, [&](VectorType vecTy, Value vecAddr) {
         auto vecVal = load(vecTy, vecAddr);
         vecVal.setAlignment(vecTy.getNumElements() *
                             elemLlvmTy.getIntOrFloatBitWidth() / 8);
@@ -285,14 +397,14 @@ SmallVector<Value> loadSharedToDistributed(RankedTensorType dstTy,
 
 void storeDistributedToShared(triton::gpu::MemDescType dstTy,
                               RankedTensorType srcTy, Type elemLlvmTy,
-                              ArrayRef<Value> srcVals, Value smemBase,
-                              ArrayRef<Value> dstStrides, Location loc,
+                              ArrayRef<Value> srcVals,
+                              const SharedMemoryObject &smemObj, Location loc,
                               RewriterBase &rewriter,
                               const TargetInfoBase &target,
                               std::pair<size_t, Type> *const llvmOpCount) {
   bool success = emitTransferBetweenRegistersAndShared(
-      srcTy, dstTy, elemLlvmTy, /*maxVecElems=*/std::nullopt, smemBase,
-      dstStrides, loc, rewriter, target, [&](VectorType vecTy, Value vecAddr) {
+      srcTy, dstTy, elemLlvmTy, /*maxVecElems=*/std::nullopt, smemObj, loc,
+      rewriter, target, [&](VectorType vecTy, Value vecAddr) {
         ArrayRef<Value> vals = srcVals.take_front(vecTy.getNumElements());
         srcVals = srcVals.drop_front(vecTy.getNumElements());
 