@@ -239,21 +239,19 @@ struct DirectToLdsLoadConversionBase : public LoadStoreConversionBase {
     }
   }
 
-  // Emits the computation to get the lane index which holds the source
+  // Emits the computation to get the lane id offset which holds the source
   // pointers/offsets we need to store to shared memory
-  Value emitSwizzledLaneIndex(RewriterBase &rewriter, TritonLLVMOpBuilder &b,
-                              Location loc, Value coalescedShmem,
-                              Value swizzledShmem, Value vecBytes) const {
+  Value emitSwizzledLaneOffset(RewriterBase &rewriter, TritonLLVMOpBuilder &b,
+                               Location loc, Value coalescedShmem,
+                               Value swizzledShmem, Value vecBytes) const {
     // Compute the laneOffset based on the difference in elements between
     // the two shmem addresses. laneOffset will be negative for half the
     // lanes because a smaller laneId might hold our global_ptr.
     auto coalescedAddr = b.ptrtoint(i64_ty, coalescedShmem);
     auto swizzledAddr = b.ptrtoint(i64_ty, swizzledShmem);
     auto diff = b.trunc(i32_ty, b.sub(swizzledAddr, coalescedAddr));
     Value laneOffset = b.sdiv(diff, vecBytes);
-    // laneId + laneOffset will always stay inside the warp [0,
-    // threadsPerWarp) because we only swizzle inside a warp
-    return b.add(getLaneId(rewriter, loc), laneOffset);
+    return laneOffset;
   }
 
   // Swizzle the mask (1bit) based on selectLane via ballot
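For reference, a minimal standalone sketch of the arithmetic this helper emits, using hypothetical plain integer values for the two shared memory addresses and the per-lane vector width (the real code builds the equivalent LLVM IR through `TritonLLVMOpBuilder`):

```cpp
#include <cassert>
#include <cstdint>

// Illustrative only: scalar version of the laneOffset computation above.
// coalescedAddr/swizzledAddr stand in for one lane's two shmem addresses,
// vecBytes for the number of bytes the lane stores per vectorized access.
int32_t swizzledLaneOffset(int64_t coalescedAddr, int64_t swizzledAddr,
                           int32_t vecBytes) {
  // Byte distance between the swizzled and the coalesced slot; negative when
  // a lower laneId holds the data this lane needs.
  int32_t diffBytes = static_cast<int32_t>(swizzledAddr - coalescedAddr);
  return diffBytes / vecBytes; // offset expressed in whole vectors (lanes)
}

int main() {
  // Example: each lane stores 16 bytes; the swizzled slot is 2 vectors ahead.
  int32_t laneOffset = swizzledLaneOffset(0x1000, 0x1020, 16);
  assert(laneOffset == 2);
  // laneId + laneOffset stays inside [0, threadsPerWarp) because swizzling
  // only exchanges elements between lanes of the same warp.
  return 0;
}
```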
@@ -266,6 +264,21 @@ struct DirectToLdsLoadConversionBase : public LoadStoreConversionBase {
     auto bitMask = b.lshr(warpMask, b.zext(rewriter.getI64Type(), selectLane));
     return b.trunc(i1_ty, bitMask);
   }
+
+  // For direct-to-lds the order of the shared encoding decides the order we
+  // load elements from global memory. This function returns true if the fastest
+  // dim for the sharedEnc is contiguous for the global ptrs/offsets
+  bool isFastedLoadDimContiguous(Value srcPtrOrOffset,
+                                 MemDescType sharedTy) const {
+    auto fastestDim = triton::gpu::getOrder(sharedTy)[0];
+    AxisInfo *axisInfo = axisAnalysisPass.getAxisInfo(srcPtrOrOffset);
+
+    // This can happen if axis analysis fails (e.g. lit tests).
+    if (axisInfo->getRank() <= fastestDim)
+      return false;
+
+    return axisInfo->getContiguity(fastestDim) > 1;
+  }
 };
 
 struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,
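As a rough illustration of the property `isFastedLoadDimContiguous` checks (a hypothetical 2D row-major tensor, not the `AxisInfo` API): when the shared encoding's fastest dim is also contiguous in the global tensor, neighbouring lanes read neighbouring addresses, so the swizzle can later be folded into plain offset arithmetic; otherwise the per-lane pointers have to be exchanged with a shuffle.

```cpp
#include <cassert>
#include <cstdint>

// Illustrative only: element offset of (row, col) in a row-major 2D tensor
// with a hypothetical row stride, measured in elements.
int64_t offsetOf(int64_t row, int64_t col, int64_t rowStride) {
  return row * rowStride + col;
}

int main() {
  const int64_t rowStride = 64;
  // Fastest (shared-encoding) dim == the column dimension: neighbouring lanes
  // hold neighbouring columns, their offsets differ by exactly 1 element, so
  // the fastest dim is contiguous and swizzling can stay in offset arithmetic.
  assert(offsetOf(3, 11, rowStride) - offsetOf(3, 10, rowStride) == 1);
  // If the fastest dim were the row dimension instead, neighbouring lanes
  // would differ by rowStride elements; the fastest dim is not contiguous and
  // the swizzle must be applied by shuffling pointers between lanes.
  assert(offsetOf(4, 10, rowStride) - offsetOf(3, 10, rowStride) == rowStride);
  return 0;
}
```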
@@ -542,11 +555,26 @@ struct BufferLoadToLocalOpConversion
 
       if (hasSwizzling) {
         // Apply swizzling to the src offsets
-        Value swizzledLaneId =
-            emitSwizzledLaneIndex(rewriter, b, loc, coalescedShmemAddr[i],
-                                  swizzledShmemAddr[i], vecBytesVal);
-        offsetIn =
-            targetInfo.shuffleIdx(rewriter, loc, offsetIn, swizzledLaneId);
+        Value laneOffset =
+            emitSwizzledLaneOffset(rewriter, b, loc, coalescedShmemAddr[i],
+                                   swizzledShmemAddr[i], vecBytesVal);
+        // laneId + laneOffset will always stay inside the warp [0,
+        // threadsPerWarp) because we only swizzle inside a warp
+        Value swizzledLaneId = b.add(getLaneId(rewriter, loc), laneOffset);
+
+        if (isFastedLoadDimContiguous(offset, cast<MemDescType>(dstTy))) {
+          // Because rows are contiguous and we only swizzle inside rows by
+          // swapping elements between lanes we can add laneOffset * vecSize to
+          // the offset to apply the swizzling
+          offsetIn = b.add(
+              offsetIn, b.mul(laneOffset, b.i32_val(vecTy.getNumElements())));
+        } else {
+          // If rows are not contiguous in memory we need to shuffle the
+          // pointers to apply the swizzling to the src pointers
+          offsetIn =
+              targetInfo.shuffleIdx(rewriter, loc, offsetIn, swizzledLaneId);
+        }
+
         if (mask) {
           pred =
               shuffleMask(rewriter, b, loc, targetInfo, swizzledLaneId, pred);
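A small self-checking sketch of the equivalence the contiguous path above relies on (warp size, vector width, and the toy swizzle pattern are hypothetical; the real lowering emits either the `b.add`/`b.mul` form or the `shuffleIdx` form):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  const int threadsPerWarp = 64; // hypothetical warp size
  const int vecSize = 4;         // elements each lane loads per vector access
  const int64_t base = 1024;     // hypothetical base offset of the row

  // Contiguous row: lane i holds offset base + i * vecSize.
  std::vector<int64_t> offsetIn(threadsPerWarp);
  for (int lane = 0; lane < threadsPerWarp; ++lane)
    offsetIn[lane] = base + int64_t(lane) * vecSize;

  for (int lane = 0; lane < threadsPerWarp; ++lane) {
    int laneOffset = (lane ^ 1) - lane; // toy swizzle: swap even/odd lanes
    int swizzledLane = lane + laneOffset;
    // laneId + laneOffset stays inside the warp.
    assert(swizzledLane >= 0 && swizzledLane < threadsPerWarp);
    // Shuffle path: take the offset held by swizzledLane.
    int64_t shuffled = offsetIn[swizzledLane];
    // Arithmetic path: adjust this lane's own offset by laneOffset vectors.
    int64_t adjusted = offsetIn[lane] + int64_t(laneOffset) * vecSize;
    assert(shuffled == adjusted); // both agree when the row is contiguous
  }
  return 0;
}
```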
@@ -666,10 +694,23 @@ struct AsyncCopyGlobalToLocalOpConversion
 
       if (hasSwizzling) {
         // Apply swizzling to the src pointers
-        Value swizzledLaneId =
-            emitSwizzledLaneIndex(rewriter, b, loc, coalescedShmemAddr[i],
-                                  swizzledShmemAddr[i], vecBytesVal);
-        srcPtr = targetInfo.shuffleIdx(rewriter, loc, srcPtr, swizzledLaneId);
+        Value laneOffset =
+            emitSwizzledLaneOffset(rewriter, b, loc, coalescedShmemAddr[i],
+                                   swizzledShmemAddr[i], vecBytesVal);
+        // laneId + laneOffset will always stay inside the warp [0,
+        // threadsPerWarp) because we only swizzle inside a warp
+        Value swizzledLaneId = b.add(getLaneId(rewriter, loc), laneOffset);
+
+        if (isFastedLoadDimContiguous(op.getSrc(), cast<MemDescType>(dstTy))) {
+          // Because rows are contiguous and we only swizzle inside rows by
+          // swapping elements between lanes we can move the vecTy typed src
+          // pointer by laneOffset elements to apply the swizzling.
+          srcPtr = b.gep(srcPtr.getType(), vecTy, srcPtr, laneOffset);
+        } else {
+          // If rows are not contiguous in memory we need to shuffle the
+          // pointers to apply the swizzling to the src pointers
+          srcPtr = targetInfo.shuffleIdx(rewriter, loc, srcPtr, swizzledLaneId);
+        }
         if (!maskElements.empty()) {
           pred =
               shuffleMask(rewriter, b, loc, targetInfo, swizzledLaneId, pred);