
Commit 8922df9

[TritonGPU] Add a separate pass for optimizing partition num warps (triton-lang#6323)
PartitionLoops no longer changes the number of warps of the partitions; they are always left at the default. A separate pass runs after the important DCE to analyze the IR and determine the number of warps each partition needs. As before, if a partition contains no tensor computations, its number of warps is set to 1. However, the pass now also looks at tensor computations and shrinks the number of warps where it can: it makes a very rough estimate of the partition's register usage and compares it against the pool of registers available for a given number of warps. If the number of warps can be reduced, the algorithm iterates to a fixed point. This is important, e.g., for TMA gather or scatter partitions, which can still contain tensor ops, but generally on small 1D tensors.

This still assumes uniform register allocation per warp. In the future, the pass can be made much more sophisticated by allowing nonuniform register allocation with setmaxnreg.

The PR implements a pretty major hack to re-layout the IR: it wipes out the existing layouts and reruns `convert-triton-to-tritongpu`, which sadly breaks the layering of TTGIR vs. TTIR.
1 parent 708b5f1 commit 8922df9
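The shrinking heuristic described above can be pictured as a fixed-point loop over a rough register estimate. The following is an illustrative sketch only; the helper names and the register budget are hypothetical and are not the code added by this PR:

#include <functional>

// Hypothetical sketch of the warp-shrinking heuristic described in the commit
// message. `estimateRegs(numWarps)` stands in for the pass's rough estimate of
// the registers a partition needs when spread over `numWarps` warps.
static int shrinkPartitionWarps(int numWarps, bool hasTensorOps,
                                int regsPerThread, int threadsPerWarp,
                                const std::function<int(int)> &estimateRegs) {
  if (!hasTensorOps)
    return 1; // no tensor computation: a single warp suffices
  // Iterate to a fixed point: keep reducing while the estimate still fits the
  // register pool available at the smaller warp count.
  while (numWarps > 1) {
    int candidate = numWarps / 2;
    int budget = candidate * threadsPerWarp * regsPerThread;
    if (estimateRegs(candidate) > budget)
      break;
    numWarps = candidate;
  }
  return numWarps;
}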

File tree

17 files changed (+576, -96 lines)


include/triton/Conversion/TritonToTritonGPU/Passes.td

Lines changed: 4 additions & 1 deletion
@@ -36,7 +36,10 @@ def ConvertTritonToTritonGPU: Pass<"convert-triton-to-tritongpu", "mlir::ModuleO
            "number of ctas in a cga">,
     Option<"target", "target",
            "std::string", /*default*/"\"\"",
-           "the GPU target, e.g., cuda:80, hip:gfx942">
+           "the GPU target, e.g., cuda:80, hip:gfx942">,
+    Option<"enableSourceRemat", "enable-source-remat",
+           "bool", /*default*/"false",
+           "enable trivial source rematerialization">,
   ];
 }


include/triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h

Lines changed: 2 additions & 1 deletion
@@ -18,7 +18,8 @@ std::unique_ptr<OperationPass<ModuleOp>> createConvertTritonToTritonGPUPass();
 // Create the pass with numWarps set explicitly.
 std::unique_ptr<OperationPass<ModuleOp>>
 createConvertTritonToTritonGPUPass(const std::string &target, int numWarps,
-                                   int threadsPerWarp = 32, int numCTAs = 1);
+                                   int threadsPerWarp = 32, int numCTAs = 1,
+                                   bool enableSourceRemat = false);

 } // namespace triton
 } // namespace mlir
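
For callers, the new overload is a drop-in extension of the existing factory. A minimal usage sketch, assuming an otherwise default configuration (the target string and warp/CTA counts are illustrative, not values from this commit):

#include "mlir/Pass/PassManager.h"
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h"

// Sketch: schedule the conversion with trivial source rematerialization
// enabled via the new trailing parameter.
static void addTritonToTritonGPU(mlir::PassManager &pm) {
  pm.addPass(mlir::triton::createConvertTritonToTritonGPUPass(
      /*target=*/"cuda:80", /*numWarps=*/4, /*threadsPerWarp=*/32,
      /*numCTAs=*/1, /*enableSourceRemat=*/true));
}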

include/triton/Dialect/TritonGPU/Transforms/Passes.td

Lines changed: 10 additions & 0 deletions
@@ -155,6 +155,16 @@ def TritonGPUPartitionLoops : Pass<"tritongpu-partition-loops", "mlir::ModuleOp"
   let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect"];
 }

+def TritonGPUOptimizePartitionWarps : Pass<"tritongpu-optimize-partition-warps", "mlir::ModuleOp"> {
+  let summary = "optimize the number of warps assigned to partitions";
+
+  let description = [{
+    The `tritongpu-optimize-partition-warps` pass analyzes the partitions
+    of `ttg.warp_specialize` ops and attempts to reduce the number of warps
+    assigned to them and optimize the register usage of the partitions.
+  }];
+}
+
 def TritonGPULoadMMASpecialization : Pass<"tritongpu-load-mma-specialization", "mlir::ModuleOp"> {
   let summary = "load MMA specialization";

include/triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h

Lines changed: 1 addition & 1 deletion
@@ -14,7 +14,7 @@ namespace mlir {
 class TritonGPUTypeConverter : public TypeConverter {
 public:
   TritonGPUTypeConverter(MLIRContext *context, int numWarps, int threadsPerWarp,
-                         int numCTAs);
+                         int numCTAs, bool enableSourceRemat);
   int getNumWarps() const { return numWarps; }
   int getThreadsPerWarp() const { return threadsPerWarp; }
   int getNumCTAs() const { return numCTAs; }

lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp

Lines changed: 0 additions & 1 deletion
@@ -136,7 +136,6 @@ struct ReduceOpConversion
       uniqueOffsets.insert({offsets[i], i});
     }

-    unsigned srcElems = getTotalElemsPerThread(operandType);
     auto *combineOp = &op.getCombineOp();
     auto srcIndices = emitIndices(op.getLoc(), rewriter, targetInfo,
                                   helper.getSrcLayout(), operandType, true);

lib/Conversion/TritonToTritonGPU/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
@@ -12,5 +12,4 @@ add_triton_library(TritonToTritonGPU
   TritonIR
   ProtonIR
   TritonGPUIR
-  TritonGPUTransforms
 )

lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp

Lines changed: 14 additions & 11 deletions
@@ -9,6 +9,7 @@
 #include "triton/Dialect/Triton/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/Transforms/Utility.h"
+#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"

 using namespace mlir;
 using namespace mlir::triton::gpu;
@@ -18,7 +19,8 @@ using namespace mlir::triton::gpu;
 //
 TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
                                                int numWarps, int threadsPerWarp,
-                                               int numCTAs)
+                                               int numCTAs,
+                                               bool enableSourceRemat)
     : context(context), numWarps(numWarps), threadsPerWarp(threadsPerWarp),
       numCTAs(numCTAs) {
   addConversion([](Type type) { return type; });
@@ -55,28 +57,28 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
   //
   // This will be called when (newArgType != origArgType)
   // This will create newArg, and map(origArg, newArg)
-  addArgumentMaterialization([&](OpBuilder &builder,
-                                 RankedTensorType tensorType, ValueRange inputs,
-                                 Location loc) -> Value {
+  addArgumentMaterialization([](OpBuilder &builder, RankedTensorType tensorType,
+                                ValueRange inputs, Location loc) -> Value {
     llvm_unreachable("Argument rematerialization should not happen in Triton "
                      "-> TritonGPU conversion");
     return {};
   });

   // If the origValue still has live user(s), use this to
   // convert origValue to newValue
-  addSourceMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
+  addSourceMaterialization([=](OpBuilder &builder, RankedTensorType tensorType,
                                ValueRange inputs, Location loc) -> Value {
-    llvm_unreachable("Source rematerialization should not happen in Triton -> "
-                     "TritonGPU Conversion");
-    return {};
+    assert(enableSourceRemat && "Source rematerialization should not happen in "
+                                "Triton -> TritonGPU Conversion");
+    return builder.create<UnrealizedConversionCastOp>(loc, tensorType, inputs)
+        .getResult(0);
   });

   // This will be called when (desiredType != newOperandType)
   // where, desiredType = typeConverter->convertType(origType)
   // NOTE: only for remapped values.
-  addTargetMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
-                               ValueRange inputs, Location loc) {
+  addTargetMaterialization([](OpBuilder &builder, RankedTensorType tensorType,
+                              ValueRange inputs, Location loc) {
     auto cast =
         builder.create<triton::gpu::ConvertLayoutOp>(loc, tensorType, inputs);
     return cast.getResult();
@@ -98,7 +100,8 @@ TritonGPUConversionTarget::TritonGPUConversionTarget(

   addDynamicallyLegalDialect<arith::ArithDialect, math::MathDialect,
                              triton::TritonDialect, cf::ControlFlowDialect,
-                             scf::SCFDialect, ub::UBDialect>(
+                             scf::SCFDialect, ub::UBDialect,
+                             triton::nvidia_gpu::TritonNvidiaGPUDialect>(
       [&](Operation *op) {
         bool hasLegalRegions = true;
         for (auto &region : op->getRegions()) {

lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp

Lines changed: 43 additions & 49 deletions
@@ -10,6 +10,7 @@
 #include "triton/Dialect/Triton/IR/Utility.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
+#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
 #include "llvm/ADT/APSInt.h"
 #include <numeric>

@@ -431,50 +432,37 @@ static RankedTensorType getNewIndicesType(RankedTensorType type,
                               newEncoding);
 }

-struct TritonDescriptorGatherPattern
-    : public OpConversionPattern<triton::DescriptorGatherOp> {
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(triton::DescriptorGatherOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    auto numThreads = lookupThreadsPerWarp(rewriter);
-    auto numWarps = lookupNumWarps(op);
-    RankedTensorType newType = getNewIndicesType(
-        cast<RankedTensorType>(adaptor.getXOffsets().getType()), numThreads,
-        numWarps);
-    if (!newType)
-      return failure();
-
-    Value newInd = rewriter.create<ConvertLayoutOp>(op.getLoc(), newType,
-                                                    adaptor.getXOffsets());
-    rewriter.replaceOpWithNewOp<triton::DescriptorGatherOp>(
-        op, getTypeConverter()->convertType(op.getType()), adaptor.getDesc(),
-        newInd, adaptor.getYOffset());
-    return success();
-  }
-};
+// Function for converting any gather or scatter op that requires a specific
+// index layout. This also handles converting result types if there are any.
+static LogicalResult convertGatherScatterOp(Operation *op, OpOperand &indices,
+                                            ConversionPatternRewriter &b) {
+  auto type = cast<RankedTensorType>(indices.get().getType());
+  RankedTensorType newType =
+      getNewIndicesType(type, lookupThreadsPerWarp(b), lookupNumWarps(op));
+  if (!newType)
+    return failure();
+  Value index = b.create<ConvertLayoutOp>(op->getLoc(), newType, indices.get());
+  indices.set(index);
+  return success();
+}

-struct TritonDescriptorScatterPattern
-    : public OpConversionPattern<triton::DescriptorScatterOp> {
-  using OpConversionPattern::OpConversionPattern;
+template <typename OpT>
+struct GatherScatterOpPattern : public OpConversionPattern<OpT> {
+  using OpConversionPattern<OpT>::OpConversionPattern;

   LogicalResult
-  matchAndRewrite(triton::DescriptorScatterOp op, OpAdaptor adaptor,
+  matchAndRewrite(OpT op, typename OpT::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto numThreads = lookupThreadsPerWarp(rewriter);
-    auto numWarps = lookupNumWarps(op);
-    RankedTensorType newType = getNewIndicesType(
-        cast<RankedTensorType>(adaptor.getXOffsets().getType()), numThreads,
-        numWarps);
-    if (!newType)
-      return failure();
-
-    Value newInd = rewriter.create<ConvertLayoutOp>(op.getLoc(), newType,
-                                                    adaptor.getXOffsets());
-    rewriter.replaceOpWithNewOp<triton::DescriptorScatterOp>(
-        op, adaptor.getDesc(), newInd, adaptor.getYOffset(), adaptor.getSrc());
-    return success();
+    LogicalResult result = success();
+    rewriter.modifyOpInPlace(op, [&] {
+      for (auto [operand, value] :
+           llvm::zip(op->getOpOperands(), adaptor.getOperands()))
+        operand.set(value);
+      for (OpResult result : op->getOpResults())
+        result.setType(this->typeConverter->convertType(result.getType()));
+      result = convertGatherScatterOp(op, op.getXOffsetsMutable(), rewriter);
+    });
+    return result;
   }
 };

@@ -619,10 +607,13 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
       GenericOpPattern<triton::ReduceReturnOp>, TritonScanPattern,
       GenericOpPattern<triton::ScanReturnOp>,
       GenericOpPattern<triton::MakeRangeOp>, TritonExpandDimsPattern,
-      TritonTransPattern, TritonDotPattern, TritonDescriptorGatherPattern,
-      TritonDescriptorScatterPattern, GenericOpPattern<triton::LoadOp>,
-      GenericOpPattern<triton::StoreOp>, GenericOpPattern<triton::HistogramOp>,
-      GenericOpPattern<triton::GatherOp>,
+      TritonTransPattern, TritonDotPattern,
+      GatherScatterOpPattern<DescriptorGatherOp>,
+      GatherScatterOpPattern<DescriptorScatterOp>,
+      GatherScatterOpPattern<triton::nvidia_gpu::AsyncTMAGatherOp>,
+      GatherScatterOpPattern<triton::nvidia_gpu::AsyncTMAScatterOp>,
+      GenericOpPattern<triton::LoadOp>, GenericOpPattern<triton::StoreOp>,
+      GenericOpPattern<triton::HistogramOp>, GenericOpPattern<triton::GatherOp>,
       GenericOpPattern<triton::ExternElementwiseOp>,
       GenericOpPattern<triton::PrintOp>, GenericOpPattern<triton::AssertOp>,
       GenericOpPattern<triton::AtomicCASOp>,
@@ -840,11 +831,13 @@ class ConvertTritonToTritonGPU
   ConvertTritonToTritonGPU() = default;
   // constructor with some parameters set explicitly.
   ConvertTritonToTritonGPU(const std::string &target, int numWarps,
-                           int threadsPerWarp, int numCTAs) {
+                           int threadsPerWarp, int numCTAs,
+                           bool enableSourceRemat) {
     this->numWarps = numWarps;
     this->threadsPerWarp = threadsPerWarp;
     this->numCTAs = numCTAs;
     this->target = target;
+    this->enableSourceRemat = enableSourceRemat;
   }

   void runOnOperation() override {
@@ -859,7 +852,7 @@ class ConvertTritonToTritonGPU
     ModuleOp mod = getOperation();
     // type converter
     TritonGPUTypeConverter typeConverter(context, numWarps, threadsPerWarp,
-                                         numCTAs);
+                                         numCTAs, enableSourceRemat);
     TritonGPUConversionTarget target(*context, typeConverter);
     // rewrite patterns
     RewritePatternSet patterns(context);
@@ -898,9 +891,10 @@ std::unique_ptr<OperationPass<ModuleOp>>
 mlir::triton::createConvertTritonToTritonGPUPass(const std::string &target,
                                                  int numWarps,
                                                  int threadsPerWarp,
-                                                 int numCTAs) {
-  return std::make_unique<::ConvertTritonToTritonGPU>(target, numWarps,
-                                                      threadsPerWarp, numCTAs);
+                                                 int numCTAs,
+                                                 bool enableSourceRemat) {
+  return std::make_unique<::ConvertTritonToTritonGPU>(
+      target, numWarps, threadsPerWarp, numCTAs, enableSourceRemat);
 }

 std::unique_ptr<OperationPass<ModuleOp>>

lib/Dialect/TritonGPU/Transforms/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
@@ -34,6 +34,7 @@ add_triton_library(TritonGPUTransforms
   Utility.cpp
   WarpSpecialization/AutomaticWarpSpecialization.cpp
   WarpSpecialization/LoadMMASpecialization.cpp
+  WarpSpecialization/OptimizePartitionWarps.cpp
   WarpSpecialization/PartitionLoops.cpp
   WarpSpecialization/RewritePartitionDependencies.cpp

@@ -47,5 +48,6 @@ add_triton_library(TritonGPUTransforms
   TritonIR
   TritonGPUIR
   TritonNvidiaGPUIR
+  TritonToTritonGPU
   MLIRTransformUtils
 )

lib/Dialect/TritonGPU/Transforms/WarpSpecialization/AutomaticWarpSpecialization.cpp

Lines changed: 5 additions & 0 deletions
@@ -51,4 +51,9 @@ void AutomaticWarpSpecialization::runOnOperation() {
   WarpSpecializeOp::getCanonicalizationPatterns(patterns, &getContext());
   if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
     return signalPassFailure();
+
+  pm.clear();
+  pm.addPass(createTritonGPUOptimizePartitionWarps());
+  if (failed(runPipeline(pm, getOperation())))
+    return signalPassFailure();
 }
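
The hunk above uses MLIR's nested-pipeline idiom: the pass populates a local OpPassManager and runs it on its own operation via Pass::runPipeline. A stripped-down sketch of that idiom with a generic nested pass (this is not the surrounding AutomaticWarpSpecialization code):

#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/Passes.h"

// Hypothetical pass showing the nested-pipeline pattern used above.
struct ExamplePass
    : mlir::PassWrapper<ExamplePass, mlir::OperationPass<mlir::ModuleOp>> {
  void runOnOperation() override {
    // Build a pipeline anchored on ModuleOp and run it on this module.
    mlir::OpPassManager pm(mlir::ModuleOp::getOperationName());
    pm.addPass(mlir::createCanonicalizerPass()); // any nested pass
    if (failed(runPipeline(pm, getOperation())))
      return signalPassFailure();
  }
};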
