Commit 494d897

Merge branch 'main' into issue2662

2 parents f8b0851 + e4fa38e

114 files changed (+3902, −3080 lines)


.github/workflows/integration-tests.yml

Lines changed: 6 additions & 4 deletions
@@ -279,8 +279,9 @@ jobs:
           ctest -j32
       - name: Run Proton tests
         run: |
-          cd third_party/proton
-          python3 -m pytest -s test
+          cd third_party/proton/test
+          python3 -m pytest -s .
+          cd ..
       - # If we're on branch `main`, save the ccache Triton compilation artifacts
         # to the cache so they can be used by other (non-main) CI runs.
         #
@@ -425,8 +426,9 @@ jobs:
           python3 -m pytest -s -n 8 ./test_cast_matmul.py
       - name: Run Proton tests
         run: |
-          cd third_party/proton
-          python3 -m pytest -s test
+          cd third_party/proton/test
+          python3 -m pytest -s .
+          cd ..
       - name: Run C++ unittests
         run: |
           cd python

.github/workflows/integration-tests.yml.in

Lines changed: 3 additions & 2 deletions
@@ -319,8 +319,9 @@ jobs:
       - &run-proton-tests-step
         name: Run Proton tests
         run: |
-          cd third_party/proton
-          python3 -m pytest -s test
+          cd third_party/proton/test
+          python3 -m pytest -s .
+          cd ..

       # If we're on branch `main`, save the ccache Triton compilation artifacts
       # to the cache so they can be used by other (non-main) CI runs.

.gitignore

Lines changed: 3 additions & 0 deletions
@@ -11,6 +11,9 @@ python/*.whl
 python/triton/_C/*.pyd
 python/triton/_C/*.so
 python/triton/_C/*.dylib
+python/triton/_C/*.pdb
+python/triton/_C/*.exe
+python/triton/_C/*.ilk

 benchmarks/dist
 benchmarks/*.egg-info/

bin/RegisterTritonDialects.h

Lines changed: 1 addition & 0 deletions
@@ -64,6 +64,7 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
   mlir::triton::intel::registerConvertTritonToTritonGPUWarpPass();
   mlir::triton::intel::registerTritonRaiseBlockPointer();
   mlir::triton::registerAllocateSharedMemoryPass();
+  mlir::triton::registerTritonGPUGlobalScratchAllocationPass();
   mlir::triton::registerConvertTritonGPUToLLVMPass();
   mlir::triton::registerConvertNVGPUToLLVMPass();
   mlir::triton::registerDecomposeUnsupportedNVIDIAConversions();

include/triton/Analysis/Utility.h

Lines changed: 1 addition & 1 deletion
@@ -66,7 +66,7 @@ class ReduceOpHelper {
   // The shape of the shared memory space needed for the reduction.
   SmallVector<unsigned> getScratchRepShape();

-  SmallVector<unsigned> getThreadOrderWithAxisAtBeginning();
+  SmallVector<unsigned> getOrderWithAxisAtBeginning();

   unsigned getScratchSizeInBytes();

include/triton/Conversion/TritonGPUToLLVM/Passes.h

Lines changed: 2 additions & 0 deletions
@@ -20,6 +20,8 @@ namespace triton {
 namespace gpu {
 std::unique_ptr<OperationPass<ModuleOp>> createAllocateSharedMemoryPass();

+std::unique_ptr<Pass> createTritonGPUGlobalScratchAllocationPass();
+
 } // namespace gpu

 #define GEN_PASS_REGISTRATION

include/triton/Conversion/TritonGPUToLLVM/Passes.td

Lines changed: 14 additions & 0 deletions
@@ -15,4 +15,18 @@ def AllocateSharedMemory : Pass<"allocate-shared-memory", "mlir::ModuleOp"> {
   let constructor = "mlir::triton::gpu::createAllocateSharedMemoryPass()";
 }

+def TritonGPUGlobalScratchAllocationPass : Pass<"tritongpu-global-scratch-memory-allocation", "mlir::ModuleOp"> {
+  let summary = "Assign global scratch memory allocation";
+
+  let description = [{
+    Decide on global scratch space memory allocation and assign attributes to each allocation.
+  }];
+
+  let constructor = "mlir::triton::gpu::createTritonGPUGlobalScratchAllocationPass()";
+
+  let dependentDialects = [
+    "mlir::triton::gpu::TritonGPUDialect"
+  ];
+}
+
 #endif
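
A minimal sketch of how the new pass might be wired into a ModuleOp pass pipeline, assuming only the two constructors declared in Passes.h above; buildAllocationPasses and the relative ordering are illustrative assumptions, not taken from this commit:

#include "mlir/Pass/PassManager.h"
#include "triton/Conversion/TritonGPUToLLVM/Passes.h"

// Hypothetical helper (not part of this commit): schedule the allocation
// passes on a ModuleOp pass manager.
void buildAllocationPasses(mlir::PassManager &pm) {
  // New pass: decides global scratch allocation and records the result as
  // attributes (e.g. triton_gpu.global_scratch_memory_size on the module).
  pm.addPass(mlir::triton::gpu::createTritonGPUGlobalScratchAllocationPass());
  // Existing shared-memory allocation pass declared in the same header.
  pm.addPass(mlir::triton::gpu::createAllocateSharedMemoryPass());
}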

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 59 additions & 135 deletions
@@ -10,6 +10,7 @@
 #include "triton/Analysis/Utility.h"
 #include "triton/Conversion/MLIRTypes.h"
 #include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h"
+#include "triton/Dialect/Triton/IR/Dialect.h"
 #include "triton/Dialect/Triton/IR/Utility.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
@@ -367,8 +368,9 @@ inline bool isKernel(FunctionOpInterface funcOp) {

 inline Value getStackPointer(RewriterBase &rewriter,
                              FunctionOpInterface funcOp) {
+  // See NOTE: [Additional Function Arguments]
   if (!isKernel(funcOp)) {
-    return funcOp.getArgument(funcOp.getNumArguments() - 1);
+    return funcOp.getArgument(funcOp.getNumArguments() - 2);
   }

   auto mod = funcOp->getParentOfType<ModuleOp>();
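
Taken together with getGlobalScratchPtr below, the index change above implies that a lowered non-kernel function now carries two trailing pointer arguments. A tiny standalone sketch of that convention, with made-up argument names:

#include <cassert>
#include <string>
#include <vector>

int main() {
  // Hypothetical trailing arguments of a lowered non-kernel function:
  // the shared-memory base now sits second-to-last and the global scratch
  // base last, matching numArguments - 2 and numArguments - 1 in the diff.
  std::vector<std::string> args = {"%arg0", "%arg1", "%shared_mem_base",
                                   "%global_scratch_base"};
  const auto n = args.size();
  assert(args[n - 2] == "%shared_mem_base");     // getStackPointer
  assert(args[n - 1] == "%global_scratch_base"); // getGlobalScratchPtr
  return 0;
}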
@@ -377,6 +379,58 @@ inline Value getStackPointer(RewriterBase &rewriter,
   return rewriter.create<LLVM::AddressOfOp>(funcOp.getLoc(), globalBase);
 }

+inline Value getGlobalScratchPtr(Location loc, RewriterBase &rewriter,
+                                 FunctionOpInterface funcOp,
+                                 Value allocOffset = {}) {
+  // See NOTE: [Additional Function Arguments]
+  if (!isKernel(funcOp)) {
+    // Base for this function
+    auto gmemBase = funcOp.getArgument(funcOp.getNumArguments() - 1);
+    if (!allocOffset) {
+      return gmemBase;
+    }
+
+    auto ptrTy = mlir::LLVM::LLVMPointerType::get(rewriter.getContext(), 1);
+    return gep(ptrTy, i8_ty, gmemBase, allocOffset);
+  }
+
+  // Base for entire kernel
+  auto gmemBase = funcOp.getArgument(funcOp.getNumArguments() - 1);
+
+  ModuleOp mod = funcOp.getOperation()->getParentOfType<ModuleOp>();
+  auto allocSizeAttr = mod.getOperation()->getAttrOfType<mlir::IntegerAttr>(
+      "triton_gpu.global_scratch_memory_size");
+  if (!allocSizeAttr) {
+    return gmemBase;
+  }
+
+  Value gridIdx[3];
+  Value gridDim[2];
+  for (int k = 0; k < 3; ++k) {
+    gridIdx[k] = rewriter.create<GetProgramIdOp>(loc, k);
+  }
+  for (int k = 0; k < 2; ++k) {
+    gridDim[k] = rewriter.create<GetNumProgramsOp>(loc, k);
+  }
+
+  Value linearId = gridIdx[2];
+  for (int k = 0; k < 2; ++k) {
+    linearId = add(gridIdx[1 - k], mul(linearId, gridDim[1 - k]));
+  }
+
+  auto allocSize = allocSizeAttr.getValue().getZExtValue();
+
+  Value offset = mul(linearId, i32_val(allocSize));
+  if (allocOffset) {
+    offset = add(offset, allocOffset);
+  }
+
+  auto *ctx = rewriter.getContext();
+  auto res =
+      gep(mlir::LLVM::LLVMPointerType::get(ctx, 1), i8_ty, gmemBase, offset);
+  return res;
+}
+
 inline Value getSharedMemoryBase(Location loc, RewriterBase &rewriter,
                                  const TargetInfoBase &target, Operation *op) {
   auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(),
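
A plain-integer sketch of the per-program offset arithmetic performed by the new getGlobalScratchPtr helper; the grid and size values are made up, and only the linearization and multiply-add mirror the code above:

#include <cassert>
#include <cstdint>

int main() {
  uint32_t gridIdx[3] = {5, 2, 1}; // program id along x, y, z
  uint32_t gridDim[2] = {16, 8};   // number of programs along x, y
  uint32_t allocSize = 1024;       // triton_gpu.global_scratch_memory_size
  uint32_t allocOffset = 256;      // offset of one allocation within the slice

  // Same linearization as the loop in getGlobalScratchPtr:
  // linearId = gridIdx[0] + gridDim[0] * (gridIdx[1] + gridDim[1] * gridIdx[2])
  uint32_t linearId = gridIdx[2];
  for (int k = 0; k < 2; ++k)
    linearId = gridIdx[1 - k] + linearId * gridDim[1 - k];

  // Each program gets its own allocSize-byte slice of the scratch buffer.
  uint32_t offset = linearId * allocSize + allocOffset;
  assert(linearId == 5 + 16 * (2 + 8 * 1));
  assert(offset == 165 * 1024 + 256);
  return 0;
}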
@@ -466,15 +520,16 @@ emitBaseIndexWithinCTAForBlockedLayout(Location loc, RewriterBase &rewriter,
   auto sizePerThread = blockedLayout.getSizePerThread();
   auto threadsPerWarp = blockedLayout.getThreadsPerWarp();
   auto warpsPerCTA = blockedLayout.getWarpsPerCTA();
-  auto order = blockedLayout.getOrder();
+  auto threadOrder = blockedLayout.getThreadOrder();
+  auto warpOrder = blockedLayout.getWarpOrder();
   auto shapePerCTA = triton::gpu::getShapePerCTA(blockedLayout, shape);
   unsigned rank = shape.size();

   // delinearize threadId to get the base index
   SmallVector<Value> multiDimWarpId =
-      delinearize(rewriter, loc, warpId, warpsPerCTA, order);
+      delinearize(rewriter, loc, warpId, warpsPerCTA, warpOrder);
   SmallVector<Value> multiDimThreadId =
-      delinearize(rewriter, loc, laneId, threadsPerWarp, order);
+      delinearize(rewriter, loc, laneId, threadsPerWarp, threadOrder);

   SmallVector<Value> multiDimBase(rank);
   for (unsigned k = 0; k < rank; ++k) {
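
A standalone sketch of why separate warpOrder and threadOrder matter here: the same linear id maps to different multi-dimensional coordinates depending on which dimension is treated as fastest-varying. The order[0]-is-fastest convention below is an assumption about the delinearize helper, used only for illustration:

#include <cstdint>
#include <vector>

// Toy delinearization: order[0] is taken as the fastest-varying dimension.
std::vector<uint32_t> delinearize(uint32_t linear,
                                  const std::vector<uint32_t> &shape,
                                  const std::vector<uint32_t> &order) {
  std::vector<uint32_t> multi(shape.size());
  for (uint32_t dim : order) {
    multi[dim] = linear % shape[dim];
    linear /= shape[dim];
  }
  return multi;
}

int main() {
  // 4 warps over a 2x2 tile of a blocked layout.
  uint32_t warpId = 2;
  auto a = delinearize(warpId, {2, 2}, {0, 1}); // dim 0 fastest -> {0, 1}
  auto b = delinearize(warpId, {2, 2}, {1, 0}); // dim 1 fastest -> {1, 0}
  (void)a;
  (void)b;
  return 0;
}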
@@ -543,122 +598,6 @@ emitOffsetForBlockedLayout(const BlockedEncodingAttr &blockedLayout,
 // Mma layout indices
 // -----------------------------------------------------------------------

-inline SmallVector<Value>
-emitBaseIndexWithinCTAForMmaLayoutV1(Location loc, RewriterBase &rewriter,
-                                     const NvidiaMmaEncodingAttr &mmaLayout,
-                                     RankedTensorType type) {
-  auto shape = type.getShape();
-  auto wpt = mmaLayout.getWarpsPerCTA();
-  static constexpr std::array<int, 3> fpw{{2, 2, 1}};
-  auto [isARow, isBRow, isAVec4, isBVec4, _] =
-      mmaLayout.decodeVoltaLayoutStates();
-
-  Value thread = getThreadId(rewriter, loc);
-  auto *ctx = thread.getContext();
-  Value _1 = i32_val(1);
-  Value _2 = i32_val(2);
-  Value _4 = i32_val(4);
-  Value _16 = i32_val(16);
-  Value _32 = i32_val(32);
-  Value _fpw0 = i32_val(fpw[0]);
-  Value _fpw1 = i32_val(fpw[1]);
-
-  // A info
-  auto aRep = mmaLayout.getMMAv1Rep(0);
-  auto aSpw = mmaLayout.getMMAv1ShapePerWarp(0);
-  // B info
-  auto bSpw = mmaLayout.getMMAv1ShapePerWarp(1);
-  auto bRep = mmaLayout.getMMAv1Rep(1);
-
-  SmallVector<int, 2> rep({aRep[0], bRep[1]});
-  SmallVector<int, 2> spw({aSpw[0], bSpw[1]});
-  SmallVector<unsigned, 2> shapePerCTA({spw[0] * wpt[0], spw[1] * wpt[1]});
-
-  Value lane = urem(thread, _32);
-  Value warp = udiv(thread, _32);
-
-  Value warp0 = urem(warp, i32_val(wpt[0]));
-  Value warp12 = udiv(warp, i32_val(wpt[0]));
-  Value warp1 = urem(warp12, i32_val(wpt[1]));
-
-  // warp offset
-  Value offWarpM = mul(warp0, i32_val(spw[0]));
-  Value offWarpN = mul(warp1, i32_val(spw[1]));
-  // quad offset
-  Value offQuadM = mul(udiv(and_(lane, _16), _4), _fpw0);
-  Value offQuadN = mul(udiv(and_(lane, _16), _4), _fpw1);
-  // pair offset
-  Value offPairM = udiv(urem(lane, _16), _4);
-  offPairM = urem(offPairM, _fpw0);
-  offPairM = mul(offPairM, _4);
-  Value offPairN = udiv(urem(lane, _16), _4);
-  offPairN = udiv(offPairN, _fpw0);
-  offPairN = urem(offPairN, _fpw1);
-  offPairN = mul(offPairN, _4);
-  offPairM = mul(offPairM, i32_val(rep[0] / 2));
-  offQuadM = mul(offQuadM, i32_val(rep[0] / 2));
-  offPairN = mul(offPairN, i32_val(rep[1] / 2));
-  offQuadN = mul(offQuadN, i32_val(rep[1] / 2));
-  // quad pair offset
-  Value offLaneM = add(offPairM, offQuadM);
-  Value offLaneN = add(offPairN, offQuadN);
-  // a, b offset
-  Value offsetAM = add(offWarpM, offLaneM);
-  Value offsetBN = add(offWarpN, offLaneN);
-  // m indices
-  Value offsetCM = add(and_(lane, _1), offsetAM);
-  // n indices
-  Value offsetCN = add((and_(lane, _2)), (add(offWarpN, offPairN)));
-  return {offsetCM, offsetCN};
-}
-
-inline SmallVector<SmallVector<unsigned>>
-emitOffsetForMmaLayoutV1(const NvidiaMmaEncodingAttr &mmaLayout,
-                         RankedTensorType type) {
-  auto shape = type.getShape();
-
-  auto [isARow, isBRow, isAVec4, isBVec4, _] =
-      mmaLayout.decodeVoltaLayoutStates();
-
-  // TODO: seems like the pattern below to get `rep`/`spw` appears quite often
-  // A info
-  auto aRep = mmaLayout.getMMAv1Rep(0);
-  auto aSpw = mmaLayout.getMMAv1ShapePerWarp(0);
-  // B info
-  auto bSpw = mmaLayout.getMMAv1ShapePerWarp(1);
-  auto bRep = mmaLayout.getMMAv1Rep(1);
-
-  auto wpt = mmaLayout.getWarpsPerCTA();
-  static constexpr std::array<int, 3> fpw{{2, 2, 1}};
-  SmallVector<int, 2> rep({aRep[0], bRep[1]});
-  SmallVector<int, 2> spw({aSpw[0], bSpw[1]});
-  SmallVector<unsigned, 2> shapePerCTA({spw[0] * wpt[0], spw[1] * wpt[1]});
-
-  SmallVector<unsigned> idxM;
-  for (unsigned m = 0; m < shape[0]; m += shapePerCTA[0])
-    for (unsigned mm = 0; mm < rep[0]; ++mm)
-      idxM.push_back(m + mm * 2);
-
-  SmallVector<unsigned> idxN;
-  for (int n = 0; n < shape[1]; n += shapePerCTA[1]) {
-    for (int nn = 0; nn < rep[1]; ++nn) {
-      idxN.push_back(n + nn / 2 * 4 + (nn % 2) * 2 * fpw[1] * rep[1]);
-      idxN.push_back(n + nn / 2 * 4 + (nn % 2) * 2 * fpw[1] * rep[1] + 1);
-    }
-  }
-
-  SmallVector<SmallVector<unsigned>> ret;
-  for (unsigned x1 : idxN) {   // N
-    for (unsigned x0 : idxM) { // M
-      SmallVector<unsigned> idx(2);
-      idx[0] = x0; // M
-      idx[1] = x1; // N
-      ret.push_back(std::move(idx));
-    }
-  }
-  return ret;
-}
-
 inline SmallVector<SmallVector<unsigned>>
 emitOffsetForMmaLayoutV2(const NvidiaMmaEncodingAttr &mmaLayout,
                          RankedTensorType type) {
@@ -1124,9 +1063,6 @@ emitBaseIndexForLayoutImpl(Location loc, RewriterBase &rewriter,
     result = emitBaseIndexWithinCTAForBlockedLayout(loc, rewriter,
                                                     blockedLayout, type);
   } else if (auto mmaLayout = mlir::dyn_cast<NvidiaMmaEncodingAttr>(layout)) {
-    if (mmaLayout.isVolta())
-      result =
-          emitBaseIndexWithinCTAForMmaLayoutV1(loc, rewriter, mmaLayout, type);
     if (mmaLayout.isAmpere() || mmaLayout.isHopper())
       result = emitBaseIndexWithinCTAForMmaLayoutV2V3(loc, rewriter, mmaLayout,
                                                       type);
@@ -1481,18 +1417,6 @@ inline Value packLLVector(Location loc, ValueRange vals,
   return vec;
 }

-inline bool isLayoutMmaV1(Attribute layout) {
-  bool isMmaV1 = false;
-  if (auto mmaLayout = dyn_cast<NvidiaMmaEncodingAttr>(layout)) {
-    isMmaV1 = mmaLayout.isVolta();
-  }
-  if (auto sliceLayout = dyn_cast<SliceEncodingAttr>(layout)) {
-    isMmaV1 = isa<NvidiaMmaEncodingAttr>(sliceLayout.getParent()) &&
-              cast<NvidiaMmaEncodingAttr>(sliceLayout.getParent()).isVolta();
-  }
-  return isMmaV1;
-}
-
 } // namespace mlir

 #endif
