
Commit 4bbf937

Merge commit '4480f86045cfa036618dba6c86664715e211b21e'
2 parents: d34a1f3 + 4480f86

68 files changed: +1390, -1557 lines


bin/RegisterTritonDialects.h

Lines changed: 1 addition & 0 deletions
@@ -64,6 +64,7 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
   mlir::triton::intel::registerConvertTritonToTritonGPUWarpPass();
   mlir::triton::intel::registerTritonRaiseBlockPointer();
   mlir::triton::registerAllocateSharedMemoryPass();
+  mlir::triton::registerTritonGPUGlobalScratchAllocationPass();
   mlir::triton::registerConvertTritonGPUToLLVMPass();
   mlir::triton::registerConvertNVGPUToLLVMPass();
   mlir::triton::registerDecomposeUnsupportedNVIDIAConversions();

include/triton/Analysis/Utility.h

Lines changed: 1 addition & 1 deletion
@@ -66,7 +66,7 @@ class ReduceOpHelper {
   // The shape of the shared memory space needed for the reduction.
   SmallVector<unsigned> getScratchRepShape();

-  SmallVector<unsigned> getThreadOrderWithAxisAtBeginning();
+  SmallVector<unsigned> getOrderWithAxisAtBeginning();

   unsigned getScratchSizeInBytes();

include/triton/Conversion/TritonGPUToLLVM/Passes.h

Lines changed: 2 additions & 0 deletions
@@ -20,6 +20,8 @@ namespace triton {
 namespace gpu {
 std::unique_ptr<OperationPass<ModuleOp>> createAllocateSharedMemoryPass();

+std::unique_ptr<Pass> createTritonGPUGlobalScratchAllocationPass();
+
 } // namespace gpu

 #define GEN_PASS_REGISTRATION
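
The new entry point plugs into a standard MLIR pass pipeline. A minimal sketch of driving it programmatically (the PassManager setup and the `module` value are illustrative, not part of this commit):

#include "mlir/Pass/PassManager.h"
#include "triton/Conversion/TritonGPUToLLVM/Passes.h"

// Illustrative: schedule the new allocation pass on a ModuleOp `module`.
mlir::PassManager pm(module->getContext());
pm.addPass(mlir::triton::gpu::createTritonGPUGlobalScratchAllocationPass());
if (mlir::failed(pm.run(module)))
  llvm::report_fatal_error("global scratch allocation failed");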

include/triton/Conversion/TritonGPUToLLVM/Passes.td

Lines changed: 14 additions & 0 deletions
@@ -15,4 +15,18 @@ def AllocateSharedMemory : Pass<"allocate-shared-memory", "mlir::ModuleOp"> {
   let constructor = "mlir::triton::gpu::createAllocateSharedMemoryPass()";
 }

+def TritonGPUGlobalScratchAllocationPass : Pass<"tritongpu-global-scratch-memory-allocation", "mlir::ModuleOp"> {
+  let summary = "Assign global scratch memory allocation";
+
+  let description = [{
+    Decide on global scratch space memory allocation and assign attributes to each allocation.
+  }];
+
+  let constructor = "mlir::triton::gpu::createTritonGPUGlobalScratchAllocationPass()";
+
+  let dependentDialects = [
+    "mlir::triton::gpu::TritonGPUDialect"
+  ];
+}
+
 #endif
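
Downstream lowering reads the pass's result off the module; the attribute name appears in the Utility.h change below. A hedged sketch of a consumer, assuming `mod` is the `mlir::ModuleOp` the pass ran on:

// Sketch: query the per-program scratch size recorded by the pass
// (attribute name taken from the lowering code in this commit).
auto sizeAttr = mod->getAttrOfType<mlir::IntegerAttr>(
    "triton_gpu.global_scratch_memory_size");
uint64_t scratchBytesPerProgram =
    sizeAttr ? sizeAttr.getValue().getZExtValue() : 0;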

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 59 additions & 135 deletions
@@ -10,6 +10,7 @@
 #include "triton/Analysis/Utility.h"
 #include "triton/Conversion/MLIRTypes.h"
 #include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h"
+#include "triton/Dialect/Triton/IR/Dialect.h"
 #include "triton/Dialect/Triton/IR/Utility.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
@@ -367,8 +368,9 @@ inline bool isKernel(FunctionOpInterface funcOp) {

 inline Value getStackPointer(RewriterBase &rewriter,
                              FunctionOpInterface funcOp) {
+  // See NOTE: [Additional Function Arguments]
   if (!isKernel(funcOp)) {
-    return funcOp.getArgument(funcOp.getNumArguments() - 1);
+    return funcOp.getArgument(funcOp.getNumArguments() - 2);
   }

   auto mod = funcOp->getParentOfType<ModuleOp>();
@@ -377,6 +379,58 @@ inline Value getStackPointer(RewriterBase &rewriter,
   return rewriter.create<LLVM::AddressOfOp>(funcOp.getLoc(), globalBase);
 }

+inline Value getGlobalScratchPtr(Location loc, RewriterBase &rewriter,
+                                 FunctionOpInterface funcOp,
+                                 Value allocOffset = {}) {
+  // See NOTE: [Additional Function Arguments]
+  if (!isKernel(funcOp)) {
+    // Base for this function
+    auto gmemBase = funcOp.getArgument(funcOp.getNumArguments() - 1);
+    if (!allocOffset) {
+      return gmemBase;
+    }
+
+    auto ptrTy = mlir::LLVM::LLVMPointerType::get(rewriter.getContext(), 1);
+    return gep(ptrTy, i8_ty, gmemBase, allocOffset);
+  }
+
+  // Base for entire kernel
+  auto gmemBase = funcOp.getArgument(funcOp.getNumArguments() - 1);
+
+  ModuleOp mod = funcOp.getOperation()->getParentOfType<ModuleOp>();
+  auto allocSizeAttr = mod.getOperation()->getAttrOfType<mlir::IntegerAttr>(
+      "triton_gpu.global_scratch_memory_size");
+  if (!allocSizeAttr) {
+    return gmemBase;
+  }
+
+  Value gridIdx[3];
+  Value gridDim[2];
+  for (int k = 0; k < 3; ++k) {
+    gridIdx[k] = rewriter.create<GetProgramIdOp>(loc, k);
+  }
+  for (int k = 0; k < 2; ++k) {
+    gridDim[k] = rewriter.create<GetNumProgramsOp>(loc, k);
+  }
+
+  Value linearId = gridIdx[2];
+  for (int k = 0; k < 2; ++k) {
+    linearId = add(gridIdx[1 - k], mul(linearId, gridDim[1 - k]));
+  }
+
+  auto allocSize = allocSizeAttr.getValue().getZExtValue();
+
+  Value offset = mul(linearId, i32_val(allocSize));
+  if (allocOffset) {
+    offset = add(offset, allocOffset);
+  }
+
+  auto *ctx = rewriter.getContext();
+  auto res =
+      gep(mlir::LLVM::LLVMPointerType::get(ctx, 1), i8_ty, gmemBase, offset);
+  return res;
+}
+
 inline Value getSharedMemoryBase(Location loc, RewriterBase &rewriter,
                                  const TargetInfoBase &target, Operation *op) {
   auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(),
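
Two conventions are worth noting here. Non-kernel functions now carry two trailing arguments (shared-memory base at numArguments - 2, global-scratch base at numArguments - 1), which is why getStackPointer above moved to - 2. And the kernel path linearizes the 3-D program id so each program gets a disjoint allocSize-byte slice of the scratch buffer. A plain-C++ sketch of that linearization (assumed semantics, illustrative only):

// linearId = x + gridDimX * (y + gridDimY * z), matching the loop above.
unsigned linearProgramId(unsigned x, unsigned y, unsigned z,
                         unsigned gridDimX, unsigned gridDimY) {
  unsigned id = z;        // start from the slowest-varying axis
  id = y + id * gridDimY; // fold in y
  id = x + id * gridDimX; // fold in x, the fastest-varying axis
  return id;              // byte offset into scratch = id * allocSize
}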
@@ -466,15 +520,16 @@ emitBaseIndexWithinCTAForBlockedLayout(Location loc, RewriterBase &rewriter,
   auto sizePerThread = blockedLayout.getSizePerThread();
   auto threadsPerWarp = blockedLayout.getThreadsPerWarp();
   auto warpsPerCTA = blockedLayout.getWarpsPerCTA();
-  auto order = blockedLayout.getOrder();
+  auto threadOrder = blockedLayout.getThreadOrder();
+  auto warpOrder = blockedLayout.getWarpOrder();
   auto shapePerCTA = triton::gpu::getShapePerCTA(blockedLayout, shape);
   unsigned rank = shape.size();

   // delinearize threadId to get the base index
   SmallVector<Value> multiDimWarpId =
-      delinearize(rewriter, loc, warpId, warpsPerCTA, order);
+      delinearize(rewriter, loc, warpId, warpsPerCTA, warpOrder);
   SmallVector<Value> multiDimThreadId =
-      delinearize(rewriter, loc, laneId, threadsPerWarp, order);
+      delinearize(rewriter, loc, laneId, threadsPerWarp, threadOrder);

   SmallVector<Value> multiDimBase(rank);
   for (unsigned k = 0; k < rank; ++k) {
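
The change splits the warp and lane ids using the layout's dedicated warp and thread orders rather than one shared order. For reference, a scalar sketch of what delinearize computes (assumed semantics: order[0] names the fastest-varying dimension; illustrative, not repo code):

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"

// Decompose `linear` into per-dimension coordinates following `order`.
llvm::SmallVector<unsigned> delinearizeRef(unsigned linear,
                                           llvm::ArrayRef<unsigned> shape,
                                           llvm::ArrayRef<unsigned> order) {
  llvm::SmallVector<unsigned> coords(shape.size());
  for (unsigned dim : order) { // order[0] varies fastest
    coords[dim] = linear % shape[dim];
    linear /= shape[dim];
  }
  return coords;
}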
@@ -543,122 +598,6 @@ emitOffsetForBlockedLayout(const BlockedEncodingAttr &blockedLayout,
 // Mma layout indices
 // -----------------------------------------------------------------------

-inline SmallVector<Value>
-emitBaseIndexWithinCTAForMmaLayoutV1(Location loc, RewriterBase &rewriter,
-                                     const NvidiaMmaEncodingAttr &mmaLayout,
-                                     RankedTensorType type) {
-  auto shape = type.getShape();
-  auto wpt = mmaLayout.getWarpsPerCTA();
-  static constexpr std::array<int, 3> fpw{{2, 2, 1}};
-  auto [isARow, isBRow, isAVec4, isBVec4, _] =
-      mmaLayout.decodeVoltaLayoutStates();
-
-  Value thread = getThreadId(rewriter, loc);
-  auto *ctx = thread.getContext();
-  Value _1 = i32_val(1);
-  Value _2 = i32_val(2);
-  Value _4 = i32_val(4);
-  Value _16 = i32_val(16);
-  Value _32 = i32_val(32);
-  Value _fpw0 = i32_val(fpw[0]);
-  Value _fpw1 = i32_val(fpw[1]);
-
-  // A info
-  auto aRep = mmaLayout.getMMAv1Rep(0);
-  auto aSpw = mmaLayout.getMMAv1ShapePerWarp(0);
-  // B info
-  auto bSpw = mmaLayout.getMMAv1ShapePerWarp(1);
-  auto bRep = mmaLayout.getMMAv1Rep(1);
-
-  SmallVector<int, 2> rep({aRep[0], bRep[1]});
-  SmallVector<int, 2> spw({aSpw[0], bSpw[1]});
-  SmallVector<unsigned, 2> shapePerCTA({spw[0] * wpt[0], spw[1] * wpt[1]});
-
-  Value lane = urem(thread, _32);
-  Value warp = udiv(thread, _32);
-
-  Value warp0 = urem(warp, i32_val(wpt[0]));
-  Value warp12 = udiv(warp, i32_val(wpt[0]));
-  Value warp1 = urem(warp12, i32_val(wpt[1]));
-
-  // warp offset
-  Value offWarpM = mul(warp0, i32_val(spw[0]));
-  Value offWarpN = mul(warp1, i32_val(spw[1]));
-  // quad offset
-  Value offQuadM = mul(udiv(and_(lane, _16), _4), _fpw0);
-  Value offQuadN = mul(udiv(and_(lane, _16), _4), _fpw1);
-  // pair offset
-  Value offPairM = udiv(urem(lane, _16), _4);
-  offPairM = urem(offPairM, _fpw0);
-  offPairM = mul(offPairM, _4);
-  Value offPairN = udiv(urem(lane, _16), _4);
-  offPairN = udiv(offPairN, _fpw0);
-  offPairN = urem(offPairN, _fpw1);
-  offPairN = mul(offPairN, _4);
-  offPairM = mul(offPairM, i32_val(rep[0] / 2));
-  offQuadM = mul(offQuadM, i32_val(rep[0] / 2));
-  offPairN = mul(offPairN, i32_val(rep[1] / 2));
-  offQuadN = mul(offQuadN, i32_val(rep[1] / 2));
-  // quad pair offset
-  Value offLaneM = add(offPairM, offQuadM);
-  Value offLaneN = add(offPairN, offQuadN);
-  // a, b offset
-  Value offsetAM = add(offWarpM, offLaneM);
-  Value offsetBN = add(offWarpN, offLaneN);
-  // m indices
-  Value offsetCM = add(and_(lane, _1), offsetAM);
-  // n indices
-  Value offsetCN = add((and_(lane, _2)), (add(offWarpN, offPairN)));
-  return {offsetCM, offsetCN};
-}
-
-inline SmallVector<SmallVector<unsigned>>
-emitOffsetForMmaLayoutV1(const NvidiaMmaEncodingAttr &mmaLayout,
-                         RankedTensorType type) {
-  auto shape = type.getShape();
-
-  auto [isARow, isBRow, isAVec4, isBVec4, _] =
-      mmaLayout.decodeVoltaLayoutStates();
-
-  // TODO: seems like the pattern below to get `rep`/`spw` appears quite often
-  // A info
-  auto aRep = mmaLayout.getMMAv1Rep(0);
-  auto aSpw = mmaLayout.getMMAv1ShapePerWarp(0);
-  // B info
-  auto bSpw = mmaLayout.getMMAv1ShapePerWarp(1);
-  auto bRep = mmaLayout.getMMAv1Rep(1);
-
-  auto wpt = mmaLayout.getWarpsPerCTA();
-  static constexpr std::array<int, 3> fpw{{2, 2, 1}};
-  SmallVector<int, 2> rep({aRep[0], bRep[1]});
-  SmallVector<int, 2> spw({aSpw[0], bSpw[1]});
-  SmallVector<unsigned, 2> shapePerCTA({spw[0] * wpt[0], spw[1] * wpt[1]});
-
-  SmallVector<unsigned> idxM;
-  for (unsigned m = 0; m < shape[0]; m += shapePerCTA[0])
-    for (unsigned mm = 0; mm < rep[0]; ++mm)
-      idxM.push_back(m + mm * 2);
-
-  SmallVector<unsigned> idxN;
-  for (int n = 0; n < shape[1]; n += shapePerCTA[1]) {
-    for (int nn = 0; nn < rep[1]; ++nn) {
-      idxN.push_back(n + nn / 2 * 4 + (nn % 2) * 2 * fpw[1] * rep[1]);
-      idxN.push_back(n + nn / 2 * 4 + (nn % 2) * 2 * fpw[1] * rep[1] + 1);
-    }
-  }
-
-  SmallVector<SmallVector<unsigned>> ret;
-  for (unsigned x1 : idxN) {   // N
-    for (unsigned x0 : idxM) { // M
-      SmallVector<unsigned> idx(2);
-      idx[0] = x0; // M
-      idx[1] = x1; // N
-      ret.push_back(std::move(idx));
-    }
-  }
-  return ret;
-}
-
 inline SmallVector<SmallVector<unsigned>>
 emitOffsetForMmaLayoutV2(const NvidiaMmaEncodingAttr &mmaLayout,
                          RankedTensorType type) {
@@ -1124,9 +1063,6 @@ emitBaseIndexForLayoutImpl(Location loc, RewriterBase &rewriter,
     result = emitBaseIndexWithinCTAForBlockedLayout(loc, rewriter,
                                                     blockedLayout, type);
   } else if (auto mmaLayout = mlir::dyn_cast<NvidiaMmaEncodingAttr>(layout)) {
-    if (mmaLayout.isVolta())
-      result =
-          emitBaseIndexWithinCTAForMmaLayoutV1(loc, rewriter, mmaLayout, type);
     if (mmaLayout.isAmpere() || mmaLayout.isHopper())
       result = emitBaseIndexWithinCTAForMmaLayoutV2V3(loc, rewriter, mmaLayout,
                                                       type);
@@ -1481,18 +1417,6 @@ inline Value packLLVector(Location loc, ValueRange vals,
   return vec;
 }

-inline bool isLayoutMmaV1(Attribute layout) {
-  bool isMmaV1 = false;
-  if (auto mmaLayout = dyn_cast<NvidiaMmaEncodingAttr>(layout)) {
-    isMmaV1 = mmaLayout.isVolta();
-  }
-  if (auto sliceLayout = dyn_cast<SliceEncodingAttr>(layout)) {
-    isMmaV1 = isa<NvidiaMmaEncodingAttr>(sliceLayout.getParent()) &&
-              cast<NvidiaMmaEncodingAttr>(sliceLayout.getParent()).isVolta();
-  }
-  return isMmaV1;
-}
-
 } // namespace mlir

 #endif

include/triton/Dialect/Triton/IR/TritonOps.td

Lines changed: 35 additions & 0 deletions
@@ -947,6 +947,41 @@ def TT_MakeTensorPtrOp : TT_Op<"make_tensor_ptr",
   ];
 }

+//
+// Make Tensor Descriptor Op
+//
+def TT_MakeTensorDescOp : TT_Op<"make_tensor_descriptor",
+                                [Pure,
+                                 SameVariadicOperandSize]> {
+  let summary = "Make a tensor descriptor type with meta information of the parent tensor and block size";
+
+  let description = [{
+    `tt.make_tensor_descriptor` takes both meta information of the parent tensor and the block size,
+    and returns a descriptor object which can be used to load/store from the tensor in global memory.
+  }];
+
+  let arguments = (ins
+    TT_Ptr:$base,
+    Variadic<I32>:$shape,
+    Variadic<I64>:$strides,
+    DenseI32ArrayAttr:$tensorShape
+  );
+
+  // TODO(peterbell10): define a custom IR type to represent descriptors
+  let results = (outs TT_Ptr:$result);
+
+  let assemblyFormat = "$base `,` `[` $shape `]` `,` `[` $strides `]` attr-dict `:` type($base) `,` type($result)";
+
+  let builders = [
+    OpBuilder<(ins
+      "Value":$base,
+      "ValueRange":$shape,
+      "ValueRange":$strides,
+      "ArrayRef<int32_t>":$tensorShape
+    )>
+  ];
+}
+
 // The following ops, including `call`, `func`, and `return` are copied and modified from
 // https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/Func/IR/FuncOps.td
 // We could revert it back once MLIR has a better inliner interface.
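
A hedged sketch of creating the op through the builder declared above, assuming the usual TableGen naming (generated C++ class MakeTensorDescOp); `builder`, `loc`, and all operand values are hypothetical placeholders:

// Illustrative: describe a 2-D tensor in global memory with a 64x64
// block shape. `basePtr`, `dimM`, `dimN`, `strideM`, `strideN` are
// hypothetical SSA values of the expected types (ptr, i32, i64).
auto desc = builder.create<mlir::triton::MakeTensorDescOp>(
    loc, basePtr,
    /*shape=*/mlir::ValueRange{dimM, dimN},
    /*strides=*/mlir::ValueRange{strideM, strideN},
    /*tensorShape=*/llvm::ArrayRef<int32_t>{64, 64});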

include/triton/Dialect/TritonGPU/IR/Dialect.h

Lines changed: 2 additions & 3 deletions
@@ -76,9 +76,8 @@ SmallVector<unsigned>
 getWarpsPerCTAWithUniqueData(Attribute layout, ArrayRef<int64_t> tensorShape);

 // Returns the dimensions of the tensor from minor (fast-varying) to
-// major (slow-varying). For blocked, mma, and dotOperand layouts,
-// though the elements are in registers, the order refers to memory
-// layout of the original tensor in global memory.
+// major (slow-varying). For distributed layouts, this represents
+// the order of the elements within a thread.
 // For shared Layout, the order refers to which dimension of the original tensor
 // is contiguous in shared memory.
 SmallVector<unsigned> getOrder(Attribute layout);
