 #include "triton/Analysis/Utility.h"
 #include "triton/Conversion/MLIRTypes.h"
 #include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h"
+#include "triton/Dialect/Triton/IR/Dialect.h"
 #include "triton/Dialect/Triton/IR/Utility.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
@@ -367,8 +368,9 @@ inline bool isKernel(FunctionOpInterface funcOp) {
 
 inline Value getStackPointer(RewriterBase &rewriter,
                              FunctionOpInterface funcOp) {
+  // See NOTE: [Additional Function Arguments]
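+  // Device functions take the shared memory base pointer as their
+  // second-to-last argument; the last argument is the global scratch pointer.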
   if (!isKernel(funcOp)) {
-    return funcOp.getArgument(funcOp.getNumArguments() - 1);
+    return funcOp.getArgument(funcOp.getNumArguments() - 2);
   }
 
   auto mod = funcOp->getParentOfType<ModuleOp>();
@@ -377,6 +379,58 @@ inline Value getStackPointer(RewriterBase &rewriter,
   return rewriter.create<LLVM::AddressOfOp>(funcOp.getLoc(), globalBase);
 }
 
+inline Value getGlobalScratchPtr(Location loc, RewriterBase &rewriter,
+                                 FunctionOpInterface funcOp,
+                                 Value allocOffset = {}) {
+  // See NOTE: [Additional Function Arguments]
+  if (!isKernel(funcOp)) {
+    // Base for this function
+    auto gmemBase = funcOp.getArgument(funcOp.getNumArguments() - 1);
+    if (!allocOffset) {
+      return gmemBase;
+    }
+
+    auto ptrTy = mlir::LLVM::LLVMPointerType::get(rewriter.getContext(), 1);
+    return gep(ptrTy, i8_ty, gmemBase, allocOffset);
+  }
+
+  // Base for entire kernel
+  auto gmemBase = funcOp.getArgument(funcOp.getNumArguments() - 1);
+
+  ModuleOp mod = funcOp.getOperation()->getParentOfType<ModuleOp>();
+  auto allocSizeAttr = mod.getOperation()->getAttrOfType<mlir::IntegerAttr>(
+      "triton_gpu.global_scratch_memory_size");
+  if (!allocSizeAttr) {
+    return gmemBase;
+  }
+
+  Value gridIdx[3];
+  Value gridDim[2];
+  for (int k = 0; k < 3; ++k) {
+    gridIdx[k] = rewriter.create<GetProgramIdOp>(loc, k);
+  }
+  for (int k = 0; k < 2; ++k) {
+    gridDim[k] = rewriter.create<GetNumProgramsOp>(loc, k);
+  }
+
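+  // Linearize the 3D program ID in row-major order:
+  //   linearId = (gridIdx[2] * gridDim[1] + gridIdx[1]) * gridDim[0] + gridIdx[0]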
+  Value linearId = gridIdx[2];
+  for (int k = 0; k < 2; ++k) {
+    linearId = add(gridIdx[1 - k], mul(linearId, gridDim[1 - k]));
+  }
+
+  auto allocSize = allocSizeAttr.getValue().getZExtValue();
+
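+  // Each program instance owns an allocSize-byte slice of the global scratch
+  // buffer, indexed by its linear program ID.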
+  Value offset = mul(linearId, i32_val(allocSize));
+  if (allocOffset) {
+    offset = add(offset, allocOffset);
+  }
+
+  auto *ctx = rewriter.getContext();
+  auto res =
+      gep(mlir::LLVM::LLVMPointerType::get(ctx, 1), i8_ty, gmemBase, offset);
+  return res;
+}
+
 inline Value getSharedMemoryBase(Location loc, RewriterBase &rewriter,
                                  const TargetInfoBase &target, Operation *op) {
   auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(),
@@ -466,15 +520,16 @@ emitBaseIndexWithinCTAForBlockedLayout(Location loc, RewriterBase &rewriter,
   auto sizePerThread = blockedLayout.getSizePerThread();
   auto threadsPerWarp = blockedLayout.getThreadsPerWarp();
   auto warpsPerCTA = blockedLayout.getWarpsPerCTA();
-  auto order = blockedLayout.getOrder();
+  auto threadOrder = blockedLayout.getThreadOrder();
+  auto warpOrder = blockedLayout.getWarpOrder();
   auto shapePerCTA = triton::gpu::getShapePerCTA(blockedLayout, shape);
   unsigned rank = shape.size();
 
   // delinearize threadId to get the base index
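+  // The warp and thread orders may differ, so each is delinearized with its
+  // own order rather than a single shared order.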
   SmallVector<Value> multiDimWarpId =
-      delinearize(rewriter, loc, warpId, warpsPerCTA, order);
+      delinearize(rewriter, loc, warpId, warpsPerCTA, warpOrder);
   SmallVector<Value> multiDimThreadId =
-      delinearize(rewriter, loc, laneId, threadsPerWarp, order);
+      delinearize(rewriter, loc, laneId, threadsPerWarp, threadOrder);
 
   SmallVector<Value> multiDimBase(rank);
   for (unsigned k = 0; k < rank; ++k) {
@@ -543,122 +598,6 @@ emitOffsetForBlockedLayout(const BlockedEncodingAttr &blockedLayout,
 // Mma layout indices
 // -----------------------------------------------------------------------
 
-inline SmallVector<Value>
-emitBaseIndexWithinCTAForMmaLayoutV1(Location loc, RewriterBase &rewriter,
-                                     const NvidiaMmaEncodingAttr &mmaLayout,
-                                     RankedTensorType type) {
-  auto shape = type.getShape();
-  auto wpt = mmaLayout.getWarpsPerCTA();
-  static constexpr std::array<int, 3> fpw{{2, 2, 1}};
-  auto [isARow, isBRow, isAVec4, isBVec4, _] =
-      mmaLayout.decodeVoltaLayoutStates();
-
-  Value thread = getThreadId(rewriter, loc);
-  auto *ctx = thread.getContext();
-  Value _1 = i32_val(1);
-  Value _2 = i32_val(2);
-  Value _4 = i32_val(4);
-  Value _16 = i32_val(16);
-  Value _32 = i32_val(32);
-  Value _fpw0 = i32_val(fpw[0]);
-  Value _fpw1 = i32_val(fpw[1]);
-
-  // A info
-  auto aRep = mmaLayout.getMMAv1Rep(0);
-  auto aSpw = mmaLayout.getMMAv1ShapePerWarp(0);
-  // B info
-  auto bSpw = mmaLayout.getMMAv1ShapePerWarp(1);
-  auto bRep = mmaLayout.getMMAv1Rep(1);
-
-  SmallVector<int, 2> rep({aRep[0], bRep[1]});
-  SmallVector<int, 2> spw({aSpw[0], bSpw[1]});
-  SmallVector<unsigned, 2> shapePerCTA({spw[0] * wpt[0], spw[1] * wpt[1]});
-
-  Value lane = urem(thread, _32);
-  Value warp = udiv(thread, _32);
-
-  Value warp0 = urem(warp, i32_val(wpt[0]));
-  Value warp12 = udiv(warp, i32_val(wpt[0]));
-  Value warp1 = urem(warp12, i32_val(wpt[1]));
-
-  // warp offset
-  Value offWarpM = mul(warp0, i32_val(spw[0]));
-  Value offWarpN = mul(warp1, i32_val(spw[1]));
-  // quad offset
-  Value offQuadM = mul(udiv(and_(lane, _16), _4), _fpw0);
-  Value offQuadN = mul(udiv(and_(lane, _16), _4), _fpw1);
-  // pair offset
-  Value offPairM = udiv(urem(lane, _16), _4);
-  offPairM = urem(offPairM, _fpw0);
-  offPairM = mul(offPairM, _4);
-  Value offPairN = udiv(urem(lane, _16), _4);
-  offPairN = udiv(offPairN, _fpw0);
-  offPairN = urem(offPairN, _fpw1);
-  offPairN = mul(offPairN, _4);
-  offPairM = mul(offPairM, i32_val(rep[0] / 2));
-  offQuadM = mul(offQuadM, i32_val(rep[0] / 2));
-  offPairN = mul(offPairN, i32_val(rep[1] / 2));
-  offQuadN = mul(offQuadN, i32_val(rep[1] / 2));
-  // quad pair offset
-  Value offLaneM = add(offPairM, offQuadM);
-  Value offLaneN = add(offPairN, offQuadN);
-  // a, b offset
-  Value offsetAM = add(offWarpM, offLaneM);
-  Value offsetBN = add(offWarpN, offLaneN);
-  // m indices
-  Value offsetCM = add(and_(lane, _1), offsetAM);
-  // n indices
-  Value offsetCN = add((and_(lane, _2)), (add(offWarpN, offPairN)));
-  return {offsetCM, offsetCN};
-}
-
-inline SmallVector<SmallVector<unsigned>>
-emitOffsetForMmaLayoutV1(const NvidiaMmaEncodingAttr &mmaLayout,
-                         RankedTensorType type) {
-  auto shape = type.getShape();
-
-  auto [isARow, isBRow, isAVec4, isBVec4, _] =
-      mmaLayout.decodeVoltaLayoutStates();
-
-  // TODO: seems like the pattern below to get `rep`/`spw` appears quite often
-  // A info
-  auto aRep = mmaLayout.getMMAv1Rep(0);
-  auto aSpw = mmaLayout.getMMAv1ShapePerWarp(0);
-  // B info
-  auto bSpw = mmaLayout.getMMAv1ShapePerWarp(1);
-  auto bRep = mmaLayout.getMMAv1Rep(1);
-
-  auto wpt = mmaLayout.getWarpsPerCTA();
-  static constexpr std::array<int, 3> fpw{{2, 2, 1}};
-  SmallVector<int, 2> rep({aRep[0], bRep[1]});
-  SmallVector<int, 2> spw({aSpw[0], bSpw[1]});
-  SmallVector<unsigned, 2> shapePerCTA({spw[0] * wpt[0], spw[1] * wpt[1]});
-
-  SmallVector<unsigned> idxM;
-  for (unsigned m = 0; m < shape[0]; m += shapePerCTA[0])
-    for (unsigned mm = 0; mm < rep[0]; ++mm)
-      idxM.push_back(m + mm * 2);
-
-  SmallVector<unsigned> idxN;
-  for (int n = 0; n < shape[1]; n += shapePerCTA[1]) {
-    for (int nn = 0; nn < rep[1]; ++nn) {
-      idxN.push_back(n + nn / 2 * 4 + (nn % 2) * 2 * fpw[1] * rep[1]);
-      idxN.push_back(n + nn / 2 * 4 + (nn % 2) * 2 * fpw[1] * rep[1] + 1);
-    }
-  }
-
-  SmallVector<SmallVector<unsigned>> ret;
-  for (unsigned x1 : idxN) {   // N
-    for (unsigned x0 : idxM) { // M
-      SmallVector<unsigned> idx(2);
-      idx[0] = x0; // M
-      idx[1] = x1; // N
-      ret.push_back(std::move(idx));
-    }
-  }
-  return ret;
-}
-
 inline SmallVector<SmallVector<unsigned>>
 emitOffsetForMmaLayoutV2(const NvidiaMmaEncodingAttr &mmaLayout,
                          RankedTensorType type) {
@@ -1124,9 +1063,6 @@ emitBaseIndexForLayoutImpl(Location loc, RewriterBase &rewriter,
     result = emitBaseIndexWithinCTAForBlockedLayout(loc, rewriter,
                                                     blockedLayout, type);
   } else if (auto mmaLayout = mlir::dyn_cast<NvidiaMmaEncodingAttr>(layout)) {
-    if (mmaLayout.isVolta())
-      result =
-          emitBaseIndexWithinCTAForMmaLayoutV1(loc, rewriter, mmaLayout, type);
     if (mmaLayout.isAmpere() || mmaLayout.isHopper())
       result = emitBaseIndexWithinCTAForMmaLayoutV2V3(loc, rewriter, mmaLayout,
                                                       type);
@@ -1481,18 +1417,6 @@ inline Value packLLVector(Location loc, ValueRange vals,
   return vec;
 }
 
-inline bool isLayoutMmaV1(Attribute layout) {
-  bool isMmaV1 = false;
-  if (auto mmaLayout = dyn_cast<NvidiaMmaEncodingAttr>(layout)) {
-    isMmaV1 = mmaLayout.isVolta();
-  }
-  if (auto sliceLayout = dyn_cast<SliceEncodingAttr>(layout)) {
-    isMmaV1 = isa<NvidiaMmaEncodingAttr>(sliceLayout.getParent()) &&
-              cast<NvidiaMmaEncodingAttr>(sliceLayout.getParent()).isVolta();
-  }
-  return isMmaV1;
-}
-
 } // namespace mlir
 
 #endif