@@ -326,21 +326,6 @@ struct DirectToLdsLoadConversionBase : public LoadStoreConversionBase {
326326 auto bitMask = b.lshr (warpMask, b.zext (rewriter.getI64Type (), selectLane));
327327 return b.trunc (i1_ty, bitMask);
328328 }
329-
330- // For direct-to-lds the order of the shared encoding decides the order we
331- // load elements from global memory. This function returns true if the fastest
332- // dim for the sharedEnc is contiguous for the global ptrs/offsets, i.e. the axis info of srcPtrOrOffset reports a contiguity > 1 along that dim. It returns false when that contiguity is 1 or less, or when the axis info's rank does not cover the fastest dim.
333- bool isFastedLoadDimContiguous (Value srcPtrOrOffset,
334- MemDescType sharedTy) const {
335- auto fastestDim = triton::gpu::getOrder (sharedTy)[0 ]; // first entry of getOrder(sharedTy)
336- AxisInfo *axisInfo = axisAnalysisPass.getAxisInfo (srcPtrOrOffset);
337-
338- // This can happen if axis analysis fails (e.g. lit tests). NOTE(review): axisInfo is dereferenced here without a null check; confirm getAxisInfo cannot return nullptr for srcPtrOrOffset.
339- if (axisInfo->getRank () <= fastestDim)
340- return false ;
341-
342- return axisInfo->getContiguity (fastestDim) > 1 ; // a contiguity of exactly 1 counts as not contiguous
343- }
344329};
345330
346331struct LoadOpConversion : public ConvertOpToLLVMPattern <triton::LoadOp>,
@@ -618,18 +603,9 @@ struct BufferLoadToLocalOpConversion
618603 // laneId + swizzleOffset will always stay inside the warp [0,
619604 // threadsPerWarp) because we only swizzle inside a warp
620605 Value swizzledLaneId = b.add (laneId, swizzleLaneOffset);
621- if (isFastedLoadDimContiguous (offset, cast<MemDescType>(dstTy))) {
622- // Because rows are contiguous and we only swizzle inside rows by
623- // swapping elements between lanes we can add laneOffset * vecSize to
624- // the offset to apply the swizzling
625- offsetIn = b.add (offsetIn, b.mul (swizzleLaneOffset,
626- b.i32_val (vecTy.getNumElements ())));
627- } else {
628- // If rows are not contiguous in memory we need to shuffle the
629- // pointers to apply the swizzling to the src pointers
630- offsetIn =
631- targetInfo.shuffleIdx (rewriter, loc, offsetIn, swizzledLaneId);
632- }
606+ // Shuffle based on swizzledLaneId to apply the swizzling
607+ offsetIn =
608+ targetInfo.shuffleIdx (rewriter, loc, offsetIn, swizzledLaneId);
633609
634610 if (mask) {
635611 pred =
@@ -763,17 +739,9 @@ struct AsyncCopyGlobalToLocalOpConversion
763739 Value swizzledLaneId =
764740 b.add (getLaneId (rewriter, loc), swizzledLaneOffsets[i]);
765741
766- if (isFastedLoadDimContiguous (op.getSrc (), cast<MemDescType>(dstTy))) {
767- // Because rows are contiguous and we only swizzle inside rows by
768- // swapping elements between lanes we can move the vecTy typed src
769- // pointer by laneOffset elements to apply the swizzling.
770- srcPtr =
771- b.gep (srcPtr.getType (), vecTy, srcPtr, swizzledLaneOffsets[i]);
772- } else {
773- // If rows are not contiguous in memory we need to shuffle the
774- // pointers to apply the swizzling to the src pointers
775- srcPtr = targetInfo.shuffleIdx (rewriter, loc, srcPtr, swizzledLaneId);
776- }
742+ // Shuffle based on swizzledLaneId to apply the swizzling
743+ srcPtr = targetInfo.shuffleIdx (rewriter, loc, srcPtr, swizzledLaneId);
744+
777745 if (!maskElements.empty ()) {
778746 pred =
779747 shuffleMask (rewriter, b, loc, targetInfo, swizzledLaneId, pred);
0 commit comments