@@ -420,31 +420,90 @@ struct BufferLoadToLocalOpConversion
     if (llOther)
       otherElems = unpackLLElements(loc, llOther, rewriter);

-    // buffer_load into LDS does not support per lane offsets.
-    // We need to ensure that we write coalesced into shared memory.
     auto dstTy = op.getDest().getType();
-    if (!LLVM::AMD::canCoalesceWriteIntoSharedMemory(rewriter, ptrType, dstTy,
-                                                     vec)) {
+    auto sharedEnc = cast<SwizzledSharedEncodingAttr>(dstTy.getEncoding());
+
+    // buffer_load into LDS does not support per lane shared offsets. We need
+    // to ensure that we write coalesced into shared memory.
+    //
+    // For *non* swizzled shared encodings we check whether they result in
+    // coalesced writes and can then lower them directly to the intrinsics.
+    //
+    // For swizzled shared encodings we need to transfer the swizzling to the
+    // source pointers. For now this is done by swizzling the pointers between
+    // the lanes of a warp via permute. This only works if the swizzle pattern
+    // does not exchange elements between warps, which holds for all our
+    // swizzle patterns. A check is still performed so we do not silently
+    // produce wrong results if this precondition is invalidated in the future.
+
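+    // Illustrative sketch (hypothetical numbers, not tied to a specific
+    // layout): with perPhase = 1 and maxPhase = 4, row r is XOR-swizzled by
+    // (r % 4) in units of `vec`, so an element's slot moves to a different
+    // lane of the *same* warp but never into another warp.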
+    bool hasSwizzling = sharedEnc.getMaxPhase() != 1;
+
+    // Compute the blocked -> shared linear layout to check preconditions
+    auto shape = ptrType.getShape();
+    LinearLayout srcLayout =
+        triton::gpu::toLinearLayout(shape, ptrType.getEncoding());
+    LinearLayout sharedLayout =
+        triton::gpu::toLinearLayout(shape, dstTy.getEncoding());
+    LinearLayout srcToSharedLayout = srcLayout.invertAndCompose(sharedLayout);
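+    // Roughly: srcLayout maps (register, lane, warp, ...) indices to tensor
+    // coordinates and sharedLayout maps LDS offsets to tensor coordinates, so
+    // composing the former with the inverse of the latter yields a map from
+    // source hardware indices to LDS offsets, which is what the two checks
+    // below inspect.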
+
+    unsigned threadsPerWarp = lookupThreadsPerWarp(rewriter);
+    if (!hasSwizzling && !LLVM::AMD::canCoalesceWriteIntoSharedMemory(
+                             rewriter, srcToSharedLayout, threadsPerWarp)) {
+      return rewriter.notifyMatchFailure(
+          op, "does not write coalesced into LDS and is not swizzled");
+    }
+
+    if (hasSwizzling && !LLVM::AMD::doesSwizzleInsideWarp(
+                            rewriter, srcToSharedLayout, threadsPerWarp)) {
       return rewriter.notifyMatchFailure(op,
-                                         "does not write coalesced into LDS");
+                                         "does swizzle across warp boundaries");
     }

     auto resElemTy = getTypeConverter()->convertType(dstTy.getElementType());
     auto smemObj = mlir::LLVM::getSharedMemoryObjectFromStruct(
         loc, llDst, resElemTy, rewriter);

-    // First we determine the vector size per load and collect the
-    // shared addresses. This will only emit the address calculation and not the
-    // actual loads
+    auto emitSharedAddresses = [&](RankedTensorType srcTy, MemDescType dstTy,
+                                   SmallVector<Value> &shmemAddrs,
+                                   VectorType &vecTy) {
+      bool ok = emitTransferBetweenRegistersAndShared(
+          srcTy, dstTy, resElemTy, {}, smemObj, loc, rewriter, targetInfo,
+          [&](VectorType vecTy_, Value shmemAddr) {
+            vecTy = vecTy_;
+            shmemAddrs.push_back(shmemAddr);
+          });
+      assert(ok);
+    };
+
+    // Determine the vector size per load and collect the shared addresses.
+    // This only emits the address calculation, not the actual loads.
+    // For swizzled loads we obtain the non swizzled/coalesced shared
+    // addresses from a temporary non swizzled layout; those are used as the
+    // store addresses. Additionally, we compute the swizzled shared memory
+    // addresses, which are used to determine which lane holds the global
+    // pointer to the coalesced address.
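+    // Hypothetical example: if the swizzle sends lane 1's element to LDS
+    // slot 3, swizzledShmemAddr points lane 1 at slot 3 while
+    // coalescedShmemAddr still points it at slot 1. Because our XOR-based
+    // swizzles are involutions, lane 1 then fetches lane 3's global offset
+    // and vice versa (see the shuffle in the loop below).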
     VectorType vecTy;
-    SmallVector<Value> shmemAddrs;
-    bool ok = emitTransferBetweenRegistersAndShared(
-        ptrType, dstTy, resElemTy, {}, smemObj, loc, rewriter, targetInfo,
-        [&](VectorType vecTy_, Value shmemAddr) {
-          vecTy = vecTy_;
-          shmemAddrs.push_back(shmemAddr);
-        });
-    assert(ok);
+    SmallVector<Value> coalescedShmemAddr;
+    SmallVector<Value> swizzledShmemAddr;
+
+    if (!hasSwizzling) {
+      emitSharedAddresses(ptrType, dstTy, coalescedShmemAddr, vecTy);
+    } else {
+      emitSharedAddresses(ptrType, dstTy, swizzledShmemAddr, vecTy);
+      // Create the non swizzled/coalesced encoding
+      auto dstEnc = cast<SwizzledSharedEncodingAttr>(dstTy.getEncoding());
+      auto flatSharedEnc = SwizzledSharedEncodingAttr::get(
+          getContext(), dstEnc.getVec(), 1, 1, dstEnc.getOrder(),
+          dstEnc.getCTALayout());
+      auto flatDstTy =
+          MemDescType::get(dstTy.getShape(), dstTy.getElementType(),
+                           flatSharedEnc, dstTy.getMemorySpace());
+      VectorType coalescedVecTy;
+      emitSharedAddresses(ptrType, flatDstTy, coalescedShmemAddr,
+                          coalescedVecTy);
+      assert(coalescedVecTy == vecTy);
+    }
+    assert(vecTy.getNumElements() == vec);

     int vecBits = vecTy.getNumElements() * vecTy.getElementTypeBitWidth();
     if (!targetInfo.supportsDirectToLdsLoadBitWidth(vecBits)) {
@@ -462,17 +521,43 @@ struct BufferLoadToLocalOpConversion
     // based on the collected shared addresses and vector size
     Value rsrcDesc = bufferEmitter.createResourceDescriptor(llPtr, llStride);

-    for (int i = 0; i < shmemAddrs.size(); i++) {
+    for (int i = 0; i < coalescedShmemAddr.size(); i++) {
       auto srcIdx = i * vec;
       auto offsetIn = offsetElems[srcIdx];
-
       Value pred = mask ? maskElems[srcIdx] : b.true_val();
+
+      if (hasSwizzling) {
+        // Compute the laneOffset based on the difference in elements between
+        // the two shmem addresses. laneOffset will be negative for half the
+        // lanes because a smaller laneId might hold our global_ptr.
+        auto coalescedAddr = b.ptrtoint(i64_ty, coalescedShmemAddr[i]);
+        auto swizzledAddr = b.ptrtoint(i64_ty, swizzledShmemAddr[i]);
+        auto diff = b.trunc(i32_ty, b.sub(swizzledAddr, coalescedAddr));
+        Value laneOffset = b.sdiv(diff, vecBytesVal);
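+        // Worked example (illustrative): with 16-byte vectors, if this lane's
+        // swizzled address is 48 bytes past its coalesced address, then
+        // laneOffset = 48 / 16 = 3, i.e. the global offset we need is held by
+        // the lane three above us.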
+        // selectLane will always stay inside the warp [0, threadsPerWarp)
+        // because we only swizzle inside a warp
+        Value selectLane = b.add(getLaneId(rewriter, loc), laneOffset);
+
+        offsetIn = targetInfo.shuffleIdx(rewriter, loc, offsetIn, selectLane);
+
+        if (mask) {
+          // To swizzle the mask we can use ballot and then select the bit
+          // based on the lane id
+          auto warpMask =
+              targetInfo.ballot(rewriter, loc, rewriter.getI64Type(), pred);
+          // Extract the selectLane bit
+          auto bitMask =
+              b.lshr(warpMask, b.zext(rewriter.getI64Type(), selectLane));
+          pred = b.trunc(i1_ty, bitMask);
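+          // Illustrative example: if warpMask = ...0b0110 and selectLane = 2,
+          // then (warpMask >> 2) & 1 == 1, so this lane's load stays enabled.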
+        }
+      }
+
       bufferEmitter.emitLoadToLds(vecTy, vecBytesVal, rsrcDesc, offsetIn,
-                                  shmemAddrs[i], pred, op.getCache());
+                                  coalescedShmemAddr[i], pred, op.getCache());
       if (!otherElems.empty()) {
         Value storeVal = packElementRangeIntoVector(
             rewriter, this->getTypeConverter(), loc, vecTy, otherElems, srcIdx);
-        llStore(rewriter, loc, shmemAddrs[i], storeVal,
+        llStore(rewriter, loc, coalescedShmemAddr[i], storeVal,
                 b.icmp_ne(maskElems[srcIdx], b.true_val()), op.getCache());
       }
     }
@@ -534,11 +619,19 @@ struct AsyncCopyGlobalToLocalOpConversion
     auto maskElements = getMaskElemsAndUpdateVeclen(
         rewriter, loc, adaptor.getMask(), op.getMask(), maxVec);

+    auto shape = srcTy.getShape();
+    LinearLayout srcLayout =
+        triton::gpu::toLinearLayout(shape, srcTy.getEncoding());
+    LinearLayout sharedLayout =
+        triton::gpu::toLinearLayout(shape, dstTy.getEncoding());
+    LinearLayout srcToSharedLayout = srcLayout.invertAndCompose(sharedLayout);
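+    // As in BufferLoadToLocalOpConversion above, srcToSharedLayout maps
+    // source (register, lane, warp, ...) indices to LDS offsets so we can
+    // check that the writes land coalesced.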
+
     // global.load.lds does not support per lane offsets.
     // We need to ensure that we write coalesced into shared memory. This means
     // that the kLane dim needs to be contiguous based on the vector size.
-    if (!LLVM::AMD::canCoalesceWriteIntoSharedMemory(rewriter, srcTy, dstTy,
-                                                     maxVec)) {
+    unsigned threadsPerWarp = lookupThreadsPerWarp(rewriter);
+    if (!LLVM::AMD::canCoalesceWriteIntoSharedMemory(
+            rewriter, srcToSharedLayout, threadsPerWarp)) {
       return rewriter.notifyMatchFailure(op,
                                          "does not write coalesced into LDS");
     }