@@ -296,6 +296,207 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
                          DecomposedWarpConversion decomposed,
                          OpAdaptor adaptor,
                          ConversionPatternRewriter &rewriter) const;
+
+  SmallVector<Value>
+  transferWithinBlockImpl(ArrayRef<Value> inVals, ConvertLayoutOp op,
+                          const LinearLayout &srcLayout,
+                          const LinearLayout &dstLayout,
+                          ConversionPatternRewriter &rewriter) const {
+    MLIRContext *ctx = op.getContext();
+    auto loc = op.getLoc();
+    auto b = TritonLLVMOpBuilder(loc, rewriter);
+
+    StringAttr kRegister = str_attr("register");
+    StringAttr kLane = str_attr("lane");
+    StringAttr kWarp = str_attr("warp");
+    StringAttr kBlock = str_attr("block");
+    StringAttr kOffset = str_attr("offset");
+    StringAttr kIteration = str_attr("iteration");
+
+    auto [laneId, warpId] = getLaneAndWarpId(rewriter, loc);
+
+    auto scratchConfig =
+        getScratchConfigForCvt(op.getSrc().getType(), op.getType());
+    auto tensorShapePerCTA = convertType<unsigned, int64_t>(getShapePerCTA(
+        op.getSrc().getType().getEncoding(), op.getType().getShape()));
+    // Input dims: [offset, iteration, block]
+    // Output dims: dimN-1, dimN-2, ..., dim0, where N is obtained from
+    // repShape
+    LinearLayout sharedLayout = chooseShemLayoutForRegToRegConversion(
+        ctx, tensorShapePerCTA, scratchConfig.repShape, scratchConfig.order);
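+    // For example (hypothetical shapes): for a 2-D tensor with
+    // scratchConfig.repShape = {32, 32}, `sharedLayout` maps
+    // (offset, iteration, block) to (dim1, dim0) coordinates; `iteration`
+    // counts how many round trips through the scratch buffer are needed
+    // when the buffer is smaller than the tensor.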
+
+    // Layout for the store from registers to shared memory.
+    //
+    // Note: If two threads in the same warp write to the same shmem offset,
+    // the hardware resolves that without a stall or a bank conflict.
+    // Therefore we don't need to avoid duplicate writes.
+    // Input dims: [reg, lane, warp]
+    // Output dims: [offset, iteration]
+    bool isStMatrix = targetInfo.canUseStMatrix(
+        op.getSrc().getType(), scratchConfig.repShape,
+        scratchConfig.paddedRepShape, scratchConfig.order,
+        /*swizzleByteSize=*/0);
+    LinearLayout shmemStoreLayout =
+        isStMatrix ? chooseStMatrixLayout(ctx, op.getSrc().getType(),
+                                          /*swizzleByteSize=*/0)
+                   : srcLayout.invertAndCompose(sharedLayout);
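+    // Note: `srcLayout.invertAndCompose(sharedLayout)` produces a map
+    // (reg, lane, warp) -> (offset, iteration) that sends each register to
+    // the shmem slot whose `sharedLayout` image is the same tensor element,
+    // so the load side (built from `dstLayout` below) finds every element
+    // where the store side put it.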
+
+    const int shmemAllocatedNumElems =
+        getNumScratchElements(scratchConfig.paddedRepShape);
+    assert(shmemStoreLayout.getOutDimSize(kOffset) <= shmemAllocatedNumElems);
+
+    // Layout for the load from shmem to registers.
+    LinearLayout shmemLoadLayout = dstLayout.invertAndCompose(sharedLayout);
+
+    // Check that the `register` fully determines the `iteration`. That is,
+    // each thread does exactly the same reads and writes to shmem on each
+    // iteration, just with different input/output registers.
+    assert(
+        shmemStoreLayout.sublayoutIsZero({kLane, kWarp, kBlock}, {kIteration}));
+    assert(
+        shmemLoadLayout.sublayoutIsZero({kLane, kWarp, kBlock}, {kIteration}));
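+    // (Here `sublayoutIsZero(ins, outs)` asserts that the sublayout from the
+    // given input dims to the given output dims is the zero map: flipping any
+    // lane/warp/block bit can never change the iteration.)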
+
+    // iteration -> registers
+    SmallVector<SmallVector<int>> inRegsForIter =
+        collectRegsForIter(ctx, shmemStoreLayout);
+    SmallVector<SmallVector<int>> outRegsForIter =
+        collectRegsForIter(ctx, shmemLoadLayout);
+
+    Value smemBase =
+        LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op.getOperation());
+    auto sharedPtrTy = smemBase.getType();
+    Type elemTy = inVals[0].getType();
+    auto outSize = shmemLoadLayout.getInDimSize(kRegister);
+    auto iterations = sharedLayout.getInDimSize(kIteration);
+    assert(scratchConfig.inVec * iterations <= inVals.size());
+    assert(scratchConfig.outVec * iterations <= outSize);
+
+    // Check that only one dimension has been padded. That is, the padded
+    // shape and the original shape may differ only in one dimension,
+    // specifically `scratchConfig.order[0]`.
+    auto rank = scratchConfig.repShape.size();
+    for (auto i = 0; i < rank; i++) {
+      if (i == scratchConfig.order[0]) {
+        continue;
+      }
+      assert(scratchConfig.repShape[i] == scratchConfig.paddedRepShape[i]);
+    }
+    auto paddedStride = scratchConfig.repShape[scratchConfig.order[0]];
+    auto paddedSize =
+        scratchConfig.paddedRepShape[scratchConfig.order[0]] - paddedStride;
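+    // For example (hypothetical values): if the fastest-moving dim has
+    // repShape 64 and paddedRepShape 68, then paddedStride = 64 and
+    // paddedSize = 4, i.e. every 64 logical elements are followed by 4
+    // unused padding slots.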
+
+    // The linear layout function is split in two parts below:
+    //
+    //   L(r, t, w, b) = L(0, t, w, b) xor L(r, 0, 0, 0)
+    //   offset = regBase xor regIdx
+    //
+    // It is the same hack as what we've done in the emitIndices function to
+    // get around performance issues on AMD GPUs.
+    auto getVecAddr = [&](LinearLayout &layout, Value &regBase,
+                          int regSlice) -> Value {
+      auto regIdx = layout
+                        .apply({{kRegister, regSlice},
+                                {kLane, 0},
+                                {kWarp, 0},
+                                {kBlock, 0}})[0]
+                        .second;
+      Value offset = b.xor_(regBase, b.i32_val(regIdx));
+      if (paddedSize > 0) {
+        assert(llvm::isPowerOf2_32(paddedStride));
+        assert(llvm::isPowerOf2_32(paddedSize));
+        auto rshiftVal = llvm::Log2_32(paddedStride);
+        auto lshiftVal = llvm::Log2_32(paddedSize);
+        offset = b.add(
+            b.shl(b.lshr(offset, b.i32_val(rshiftVal)), b.i32_val(lshiftVal)),
+            offset);
+      }
+      auto vecAddr = b.gep(sharedPtrTy, elemTy, smemBase, offset);
+      vecAddr.setNoWrapFlags(mlir::LLVM::GEPNoWrapFlags::inbounds);
+      return vecAddr;
+    };
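+    // The xor split above is valid because linear layouts are linear over
+    // GF(2): L(a xor b) = L(a) xor L(b). The thread-dependent part
+    // L(0, t, w, b) is therefore computed once per thread (storeBase and
+    // loadBase below) and combined with each per-register part by one xor.
+    // Worked padding example (hypothetical values): with paddedStride = 64
+    // and paddedSize = 4, a logical offset of 130 lies in row 130 >> 6 = 2,
+    // so 2 << 2 = 8 padding slots precede it, giving physical offset
+    // 130 + 8 = 138.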
+
+    auto storeBase = applyLinearLayout(loc, rewriter, shmemStoreLayout,
+                                       {{kRegister, b.i32_val(0)},
+                                        {kLane, laneId},
+                                        {kWarp, warpId},
+                                        {kBlock, b.i32_val(0)}})[0]
+                         .second;
+    auto loadBase = applyLinearLayout(loc, rewriter, shmemLoadLayout,
+                                      {{kRegister, b.i32_val(0)},
+                                       {kLane, laneId},
+                                       {kWarp, warpId},
+                                       {kBlock, b.i32_val(0)}})[0]
+                        .second;
+    // register idx -> Value
+    llvm::MapVector<int, Value> outVals;
+    for (int i = 0; i < iterations; i++) {
+      if (i != 0)
+        b.barrier();
+
+      auto &inRegs = inRegsForIter[i];
+      auto &outRegs = outRegsForIter[i];
+
+      // When using `stmatrix`, we can store `inVec` elements even if they
+      // are not contiguous.
+      auto inVec = isStMatrix ? shmemStoreLayout.getNumConsecutiveInOut()
+                              : scratchConfig.inVec;
+      for (int j = 0; j < inVals.size() / iterations; j += inVec) {
+        auto inRegSlice = inRegs[j];
+        Value vecAddr = getVecAddr(shmemStoreLayout, storeBase, inRegSlice);
+        SmallVector<Value> inValsVec;
+        for (int k = 0; k < inVec; k++)
+          inValsVec.push_back(inVals[inRegSlice + k]);
+        Value valsVec = packLLVector(loc, inValsVec, rewriter);
+        if (isStMatrix) {
+          targetInfo.storeMatrixShared(rewriter, loc, vecAddr, valsVec);
+        } else {
+          targetInfo.storeDShared(rewriter, loc, vecAddr, std::nullopt,
+                                  valsVec, /*pred=*/b.true_val());
+        }
+      }
+
+      b.barrier();
+
+      for (int j = 0; j < outSize / iterations; j += scratchConfig.outVec) {
+        auto outRegSlice = outRegs[j];
+        auto vecAddr = getVecAddr(shmemLoadLayout, loadBase, outRegSlice);
+        Value valsVec =
+            targetInfo.loadDShared(rewriter, loc, vecAddr, std::nullopt,
+                                   vec_ty(elemTy, scratchConfig.outVec),
+                                   /*pred=*/b.true_val());
+        for (Value v : unpackLLVector(loc, valsVec, rewriter))
+          outVals[outRegSlice++] = v;
+      }
+    }
+
+    SmallVector<Value> outValsVec;
+    for (size_t i = 0; i < outVals.size(); i++)
+      outValsVec.push_back(outVals[i]);
+    return outValsVec;
+  }
+
+  // Determine which registers are read/written in which iteration of the
+  // shmem transfer specified by `layout`.
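+  // For example (hypothetical layout): if registers {0, 1} map to iteration
+  // 0 and registers {2, 3} map to iteration 1, this returns {{0, 1}, {2, 3}}.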
+  SmallVector<SmallVector<int> /*registers*/>
+  collectRegsForIter(MLIRContext *ctx, const LinearLayout &layout) const {
+    StringAttr kRegister = str_attr("register");
+    StringAttr kLane = str_attr("lane");
+    StringAttr kWarp = str_attr("warp");
+    StringAttr kBlock = str_attr("block");
+    StringAttr kIteration = str_attr("iteration");
+
+    // The choice of iteration should be determined only by the register.
+    // That is, it should be correct to split the register dimension into
+    // iterations.
+    assert(layout.sublayoutIsZero({kLane, kWarp, kBlock}, {kIteration}));
+
+    LinearLayout sublayout = layout.sublayout({kRegister}, {kIteration});
+    SmallVector<SmallVector<int>> ret(sublayout.getOutDimSize(kIteration));
+    for (int reg = 0; reg < sublayout.getInDimSize(kRegister); reg++) {
+      auto idx = sublayout.apply({{kRegister, reg}});
+      ret[idx.begin()->second].push_back(reg);
+    }
+    return ret;
+  }
 };
 
 } // namespace