@@ -446,21 +446,6 @@ struct DirectToLdsLoadConversionBase : public LoadStoreConversionBase {
     auto bitMask = b.lshr(warpMask, b.zext(rewriter.getI64Type(), selectLane));
     return b.trunc(i1_ty, bitMask);
   }
-
-  // For direct-to-lds the order of the shared encoding decides the order we
-  // load elements from global memory. This function returns true if the fastest
-  // dim for the sharedEnc is contiguous for the global ptrs/offsets
-  bool isFastedLoadDimContiguous(Value srcPtrOrOffset,
-                                 MemDescType sharedTy) const {
-    auto fastestDim = triton::gpu::getOrder(sharedTy)[0];
-    AxisInfo *axisInfo = axisAnalysisPass.getAxisInfo(srcPtrOrOffset);
-
-    // This can happen if axis analysis fails (e.g. lit tests).
-    if (axisInfo->getRank() <= fastestDim)
-      return false;
-
-    return axisInfo->getContiguity(fastestDim) > 1;
-  }
 };
 
 struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,
@@ -732,18 +717,9 @@ struct BufferLoadToLocalOpConversion
     // laneId + swizzleOffset will always stay inside the warp [0,
     // threadsPerWarp) because we only swizzle inside a warp
     Value swizzledLaneId = b.add(laneId, swizzleLaneOffset);
-    if (isFastedLoadDimContiguous(offset, cast<MemDescType>(dstTy))) {
-      // Because rows are contiguous and we only swizzle inside rows by
-      // swapping elements between lanes we can add laneOffset * vecSize to
-      // the offset to apply the swizzling
-      offsetIn = b.add(offsetIn, b.mul(swizzleLaneOffset,
-                                       b.i32_val(vecTy.getNumElements())));
-    } else {
-      // If rows are not contiguous in memory we need to shuffle the
-      // pointers to apply the swizzling to the src pointers
-      offsetIn =
-          targetInfo.shuffleIdx(rewriter, loc, offsetIn, swizzledLaneId);
-    }
+    // Shuffle based on swizzledLaneId to apply the swizzling
+    offsetIn =
+        targetInfo.shuffleIdx(rewriter, loc, offsetIn, swizzledLaneId);
 
     if (mask) {
       pred =
@@ -877,17 +853,9 @@ struct AsyncCopyGlobalToLocalOpConversion
       Value swizzledLaneId =
          b.add(getLaneId(rewriter, loc), swizzledLaneOffsets[i]);
 
-      if (isFastedLoadDimContiguous(op.getSrc(), cast<MemDescType>(dstTy))) {
-        // Because rows are contiguous and we only swizzle inside rows by
-        // swapping elements between lanes we can move the vecTy typed src
-        // pointer by laneOffset elements to apply the swizzling.
-        srcPtr =
-            b.gep(srcPtr.getType(), vecTy, srcPtr, swizzledLaneOffsets[i]);
-      } else {
-        // If rows are not contiguous in memory we need to shuffle the
-        // pointers to apply the swizzling to the src pointers
-        srcPtr = targetInfo.shuffleIdx(rewriter, loc, srcPtr, swizzledLaneId);
-      }
+      // Shuffle based on swizzledLaneId to apply the swizzling
+      srcPtr = targetInfo.shuffleIdx(rewriter, loc, srcPtr, swizzledLaneId);
+
       if (!maskElements.empty()) {
         pred =
             shuffleMask(rewriter, b, loc, targetInfo, swizzledLaneId, pred);