Commit 983aa34

Merge OpenAI Triton commit e7fb841 (#5577)
This PR changes the Triton base from b116579 to e7fb841 (Nov 19). Pass rate: 95.42%
2 parents ba1d008 + c23297d commit 983aa34


71 files changed: +2382 −746 lines changed


.github/workflows/integration-tests-amd.yml

Lines changed: 6 additions & 1 deletion
@@ -13,7 +13,7 @@ jobs:
   integration-tests-amd:
     runs-on: ${{ matrix.runner }}
     timeout-minutes: 45
-    continue-on-error: ${{ matrix.runner[1] == 'gfx90a' || matrix.runner[0] == 'gfx950' }}
+    continue-on-error: ${{ matrix.runner[1] == 'gfx90a' || matrix.runner[0] == 'amd-gfx950' }}
     strategy:
       matrix:
         runner: ${{ fromJson(inputs.matrix) }}
@@ -39,6 +39,7 @@ jobs:
         --device=/dev/kfd --device=/dev/dri --security-opt seccomp=unconfined --group-add video --user root
         --env-file /etc/podinfo/gha-gpu-isolation-settings
         --volume /home/runner/.triton:/github/home/.triton
+        --volume /triton-data:/triton-data
     env:
       RUNNER_TYPE: ${{ matrix.runner[1] }}
       TRITON_BUILD_WITH_CCACHE: "true"
@@ -104,6 +105,10 @@ jobs:
           pip uninstall -y triton pytorch-triton-rocm

           ccache --zero-stats
+          if [ "${{ matrix.runner[0] }}" = "amd-gfx950" ]; then
+            pip install --cache-dir /triton-data/pip-cache -r python/requirements.txt
+            pip install --cache-dir /triton-data/pip-cache -r python/test-requirements.txt
+          fi
           make dev-install
       - name: Print ccache stats
         run: ccache --print-stats

Makefile

Lines changed: 2 additions & 0 deletions
@@ -82,6 +82,8 @@ test-python: test-unit test-regression test-interpret test-proton

 .PHONY: test-nogpu
 test-nogpu: test-lit test-cpp
+	$(PYTEST) python/test/gluon/test_frontend.py
+	$(PYTEST) python/test/unit/language/test_frontend.py

 .PHONY: test
 test: test-lit test-cpp test-python

include/triton/Dialect/Triton/IR/Utility.h

Lines changed: 8 additions & 0 deletions
@@ -200,6 +200,14 @@ bool isHostSideDescriptor(Value v);
 bool isKernel(FunctionOpInterface funcOp);

 unsigned getBitwidth(RankedTensorType ty);
+
+// If the value "anchor" is compared against a statically-computed bound, return
+// inclusive lower and upper bounds lb <= anchor <= ub. Depending on the
+// comparison operator, one of the bounds is a computed one while the other is
+// derived from the data type of anchor.
+std::optional<ConstantIntRanges> getBoundFromCmpOp(arith::CmpIOp cmpOp,
+                                                   Value anchor);
+
 } // namespace triton
 } // namespace mlir

include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td

Lines changed: 5 additions & 1 deletion
@@ -88,6 +88,9 @@ def TTG_AsyncCopyGlobalToLocalOp : TTG_Op<"async_copy_global_to_local", [
     This is analogue to tt.load except the data are copied to local memory pointed
     to by the memory descriptor instead of a distributed tensor. The rest of the
     operands are the same as tt.load.
+    Contiguity is the maximum number of elements that can be loaded in a single
+    vector with the given layout and mask.
+    This allows the op to use async_copy_global_to_local even if the alignment cannot be proven from the IR.
   }];

   let arguments = (ins
@@ -97,7 +100,8 @@ def TTG_AsyncCopyGlobalToLocalOp : TTG_Op<"async_copy_global_to_local", [
     Optional<TT_Type>:$other,
     DefaultValuedAttr<TT_CacheModifierAttr, "triton::CacheModifier::NONE">:$cache,
     DefaultValuedAttr<TT_EvictionPolicyAttr, "triton::EvictionPolicy::NORMAL">:$evict,
-    DefaultValuedAttr<BoolAttr, "false">:$isVolatile
+    DefaultValuedAttr<BoolAttr, "false">:$isVolatile,
+    DefaultValuedAttr<I32Attr, "1">:$contiguity
   );

   let results = (outs TTG_AsyncToken:$token);
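
To make the new attribute concrete, here is a minimal C++ sketch of how a lowering might consult it when choosing a vector width. Only the generated getContiguity()/setContiguity() accessors come from this change; the helper name pickCopyVecElems and the 16-byte-per-cp.async cap are assumptions for illustration, and the snippet presumes the usual MLIR/Triton headers.

// Sketch only: a possible consumer of the new contiguity attribute.
static unsigned pickCopyVecElems(mlir::triton::gpu::AsyncCopyGlobalToLocalOp copyOp,
                                 unsigned elemBits) {
  unsigned contiguity = copyOp.getContiguity();       // defaults to 1
  unsigned maxElems = std::max(1u, 128u / elemBits);  // assumed 16-byte cap per copy
  return std::min(contiguity, maxElems);
}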

include/triton/Tools/Sys/GetEnv.hpp

Lines changed: 1 addition & 0 deletions
@@ -31,6 +31,7 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
     "MLIR_DISABLE_MULTITHREADING",
     "TRITON_DEFAULT_FP_FUSION",
     "TRITON_DISABLE_LINE_INFO",
+    "TRITON_DUMP_MIR",
     "TRITON_ENABLE_LLVM_DEBUG",
     "TRITON_HIP_USE_ASYNC_COPY",
     "TRITON_HIP_USE_BLOCK_PINGPONG",

lib/Dialect/Triton/IR/Utility.cpp

Lines changed: 78 additions & 0 deletions
@@ -128,3 +128,81 @@ unsigned tt::getBitwidth(RankedTensorType ty) {
   auto isPtr = isa<PointerType>(ty.getElementType());
   return isPtr ? kPtrBitWidth : std::max(ty.getElementTypeBitWidth(), 8u);
 }
+
+std::optional<ConstantIntRanges> tt::getBoundFromCmpOp(arith::CmpIOp cmpOp,
+                                                       Value anchor) {
+  bool isSigned = true;
+  switch (cmpOp.getPredicate()) {
+  case arith::CmpIPredicate::uge:
+  case arith::CmpIPredicate::ugt:
+  case arith::CmpIPredicate::ule:
+  case arith::CmpIPredicate::ult:
+    isSigned = false;
+  default:
+    break;
+  }
+
+  bool anchorIsLhs = cmpOp.getLhs() == anchor;
+  auto maybeConstantIntValue = getConstantIntValue(
+      getAsOpFoldResult(anchorIsLhs ? cmpOp.getRhs() : cmpOp.getLhs()));
+  if (auto constValue = maybeConstantIntValue) {
+    unsigned bitWidth = ConstantIntRanges::getStorageBitwidth(anchor.getType());
+    assert(bitWidth > 0 && "expected non-zero bitwidth");
+    APInt apVal = {bitWidth, static_cast<uint64_t>(*constValue), isSigned};
+    APInt min, max;
+    if (isSigned) {
+      min = APInt::getSignedMinValue(bitWidth);
+      if (llvm::isa_and_nonnull<mlir::triton::GetProgramIdOp,
+                                mlir::triton::GetNumProgramsOp>(
+              anchor.getDefiningOp())) {
+        min = APInt::getZero(bitWidth);
+      } else
+        min = APInt::getSignedMinValue(bitWidth);
+      max = APInt::getSignedMaxValue(bitWidth);
+    } else {
+      min = APInt::getMinValue(bitWidth);
+      max = APInt::getMaxValue(bitWidth);
+    }
+
+    switch (cmpOp.getPredicate()) {
+    case arith::CmpIPredicate::eq:
+      return mlir::ConstantIntRanges::constant(apVal);
+    case arith::CmpIPredicate::uge:
+    case arith::CmpIPredicate::sge: {
+      // K >= apVal implies K ∈ [apVal, max]
+      if (anchorIsLhs)
+        return mlir::ConstantIntRanges::range(apVal, max, isSigned);
+      // apVal >= K implies K ∈ [min, apVal]
+      return mlir::ConstantIntRanges::range(min, apVal, isSigned);
+    }
+    case arith::CmpIPredicate::ugt:
+    case arith::CmpIPredicate::sgt: {
+      // K > apVal implies K >= apVal + 1 implies K ∈ [apVal + 1, max]
+      if (anchorIsLhs)
+        return mlir::ConstantIntRanges::range(apVal + 1, max, isSigned);
+      // apVal > K implies apVal - 1 >= K implies K ∈ [min, apVal - 1]
+      return mlir::ConstantIntRanges::range(min, apVal - 1, isSigned);
+    }
+    case arith::CmpIPredicate::ule:
+    case arith::CmpIPredicate::sle: {
+      // K <= apVal implies K ∈ [min, apVal]
+      if (anchorIsLhs)
+        return mlir::ConstantIntRanges::range(min, apVal, isSigned);
+      // apVal <= K implies K ∈ [apVal, max]
+      return mlir::ConstantIntRanges::range(apVal, max, isSigned);
+    }
+    case arith::CmpIPredicate::ult:
+    case arith::CmpIPredicate::slt: {
+      // K < apVal implies K <= apVal - 1 implies K ∈ [min, apVal - 1]
+      if (anchorIsLhs)
+        return mlir::ConstantIntRanges::range(min, apVal - 1, isSigned);
+      // apVal < K implies apVal + 1 <= K implies K ∈ [apVal + 1, max]
+      return mlir::ConstantIntRanges::range(apVal + 1, max, isSigned);
+    }
+    default:
+      emitRemark(cmpOp.getLoc(), "unsupported cmp predicate for assumption");
+      return {};
+    }
+  }
+  return {};
+}
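
As a usage illustration (a sketch under assumptions, not code from this commit): a caller that holds an assumed condition such as `%cond = arith.cmpi slt, %pid, %c1024`, where %pid comes from tt.get_program_id, could query the helper as below and would get back the inclusive range 0 <= %pid <= 1023, since the lower bound falls back to zero for GetProgramIdOp anchors. The function name printAssumedBound and the surrounding control flow are hypothetical; the snippet presumes the same MLIR/Triton headers as Utility.cpp.

// Hypothetical caller of the new helper, for illustration only.
static void printAssumedBound(mlir::Value cond, mlir::Value anchor) {
  auto cmpOp = cond.getDefiningOp<mlir::arith::CmpIOp>();
  if (!cmpOp)
    return;
  if (std::optional<mlir::ConstantIntRanges> bound =
          mlir::triton::getBoundFromCmpOp(cmpOp, anchor)) {
    // Print the signed bounds derived from the comparison.
    llvm::errs() << "lb = " << bound->smin() << ", ub = " << bound->smax()
                 << "\n";
  }
}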

lib/Dialect/TritonGPU/IR/Types.cpp

Lines changed: 6 additions & 6 deletions
@@ -157,6 +157,12 @@ LogicalResult MemDescType::verify(function_ref<InFlightDiagnostic()> emitError,
              << "memorySpace must be SharedMemorySpace for shared encoding. "
              << "Got " << memorySpace;
     }
+    auto rank = cast<LayoutEncodingTrait>(enc).getRank();
+    if (!(rank == shape.size() || rank == shape.size() - 1)) {
+      return emitError() << "rank must be equal to or one less than "
+                         << "the shape size. Got " << rank << " and "
+                         << shape.size();
+    }
   } else if (auto enc = dyn_cast<nvidia_gpu::TensorMemoryScalesEncodingAttr>(
                  encoding)) {
     if (memorySpace != nvidia_gpu::TensorMemorySpaceAttr::get(ctx)) {
@@ -177,12 +183,6 @@ LogicalResult MemDescType::verify(function_ref<InFlightDiagnostic()> emitError,
   // additional rules to verify.
   if (auto enc = dyn_cast<PaddedSharedEncodingAttr>(encoding)) {
     auto rank = enc.getRank();
-
-    if (rank != shape.size() && rank != shape.size() - 1) {
-      return emitError() << "padding rank must be equal to or one less than "
-                         << "the shape size when pipelining.";
-    }
-
     // Ensure linear component's outDims match the alloc size ignoring
     // pipelining dimension
     auto outDims = standardOutDimNames(ctx, rank);
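
The relocated check accepts encodings whose rank matches the shape rank or is one less, the extra leading dimension being the multi-buffering stage. A minimal restatement of that rule, with a hypothetical helper name and hand-picked example values, not code from the verifier:

// Illustration only: the rank rule now applied to every shared encoding,
// not just PaddedSharedEncodingAttr.
static bool rankMatchesShape(unsigned encRank, size_t shapeRank) {
  return encRank == shapeRank || encRank == shapeRank - 1;
}
// rankMatchesShape(2, 3) -> true: e.g. a (numStages, M, K) pipelined buffer
//                                 carrying a 2-D shared encoding.
// rankMatchesShape(2, 2) -> true
// rankMatchesShape(2, 4) -> false: rejected by the verifier.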

lib/Dialect/TritonGPU/Transforms/CoalesceAsyncCopy.cpp

Lines changed: 13 additions & 1 deletion
@@ -1,5 +1,6 @@
 #include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/Passes.h"
+#include "triton/Analysis/AxisInfo.h"
 #include "triton/Analysis/Utility.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/Transforms/Utility.h"
@@ -32,7 +33,11 @@ namespace gpu {
 // global data.
 struct ClipAsyncCopySizePerThread
     : public OpRewritePattern<AsyncCopyGlobalToLocalOp> {
+  ModuleAxisInfoAnalysis &axisInfoAnalysis;
   using OpRewritePattern::OpRewritePattern;
+  ClipAsyncCopySizePerThread(ModuleAxisInfoAnalysis &axisInfoAnalysis,
+                             MLIRContext *context)
+      : OpRewritePattern(context), axisInfoAnalysis(axisInfoAnalysis) {}

   LogicalResult matchAndRewrite(AsyncCopyGlobalToLocalOp copyOp,
                                 PatternRewriter &rewriter) const override {
@@ -94,12 +99,18 @@ struct ClipAsyncCopySizePerThread
     if (other)
       other = convertBlockLayout(other, newBlockEnc);

+    unsigned contiguity = axisInfoAnalysis.getContiguity(src);
+    if (mask)
+      contiguity = std::min<unsigned>(contiguity,
+                                      axisInfoAnalysis.getMaskAlignment(mask));
+
     rewriter.modifyOpInPlace(copyOp, [&]() {
       copyOp.getSrcMutable().assign(src);
       if (mask)
         copyOp.getMaskMutable().assign(mask);
       if (other)
         copyOp.getOtherMutable().assign(other);
+      copyOp.setContiguity(contiguity);
     });

     return success();
@@ -112,10 +123,11 @@ struct CoalesceAsyncCopyPass

   void runOnOperation() override {
     ModuleOp m = getOperation();
+    triton::ModuleAxisInfoAnalysis axisInfoAnalysis(m);
     MLIRContext *context = &getContext();

     mlir::RewritePatternSet patterns(context);
-    patterns.add<ClipAsyncCopySizePerThread>(context);
+    patterns.add<ClipAsyncCopySizePerThread>(axisInfoAnalysis, context);

     if (failed(applyPatternsGreedily(m, std::move(patterns))))
       signalPassFailure();
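
The value recorded on the op is the AxisInfo contiguity of the source pointer, clamped by the mask alignment whenever a mask is present. A compact, self-contained restatement of that rule follows; the helper clampContiguity is hypothetical and is not part of the pass.

#include <algorithm>
#include <optional>

// Sketch of the rule the pattern applies before calling setContiguity().
static unsigned clampContiguity(unsigned ptrContiguity,
                                std::optional<unsigned> maskAlignment) {
  unsigned contiguity = ptrContiguity;   // AxisInfo contiguity of the src pointer
  if (maskAlignment)                     // a mask can only shorten the proven run
    contiguity = std::min(contiguity, *maskAlignment);
  return contiguity;
}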

lib/Dialect/TritonGPU/Transforms/Pipeliner/LowerLoops.cpp

Lines changed: 15 additions & 3 deletions
@@ -156,7 +156,7 @@ static Value createAlloc(scf::ForOp &forOp, Operation *loadOp,
 }

 void createAsyncCopy(scf::ForOp forOp, tt::LoadOp loadOp, Value alloc,
-                     Value insertIdx, Value extractIdx,
+                     Value insertIdx, Value extractIdx, int contiguity,
                      CoarseSchedule &schedule) {
   OpBuilderForStage builder(loadOp.getLoc(), forOp, schedule);
   Value zero = arith::ConstantIntOp::create(builder, forOp.getLoc(), 0, 32);
@@ -176,7 +176,7 @@ void createAsyncCopy(scf::ForOp forOp, tt::LoadOp loadOp, Value alloc,
   Value view = createSingleBufferView(builder, alloc, insertIdx);
   Operation *copy = ttg::AsyncCopyGlobalToLocalOp::create(
       builder, src, view, mask, other, loadOp.getCache(), loadOp.getEvict(),
-      loadOp.getIsVolatile());
+      loadOp.getIsVolatile(), contiguity);
   Operation *commit =
       ttg::AsyncCommitGroupOp::create(builder, copy->getResult(0));

@@ -274,6 +274,7 @@ void createTMAAsyncGather(scf::ForOp forOp, tt::DescriptorGatherOp gatherOp,

 struct AsyncLoad {
   int stageDiff;
+  int contiguity = 1;
   Value alloc;
   Value barrier;
   Operation *waitOp;
@@ -459,6 +460,7 @@ scf::ForOp lowerLoads(scf::ForOp forOp, CoarseSchedule &schedule,
     }
     SharedEncodingTrait sharedEncoding;
     bool canUseAsyncCp = false;
+    int contiguity = 1;
     if (!isa<RankedTensorType>(op.getResultTypes()[0])) {
       canUseAsyncCp = op.getResultTypes()[0].getIntOrFloatBitWidth() >= 32;
       sharedEncoding = ttg::SwizzledSharedEncodingAttr::get(
@@ -478,6 +480,15 @@ scf::ForOp lowerLoads(scf::ForOp forOp, CoarseSchedule &schedule,
           cast<RankedTensorType>(op.getResultTypes()[0]), sharedEncoding);

       canUseAsyncCp &= copyVecBytes >= 4;
+      if (canUseAsyncCp) {
+        auto loadOp = cast<tt::LoadOp>(op);
+        auto ptr = loadOp.getPtr();
+        unsigned vec = axisInfoAnalysis.getContiguity(ptr);
+        if (auto mask = loadOp.getMask())
+          vec = std::min<unsigned>(vec,
+                                   axisInfoAnalysis.getMaskAlignment(mask));
+        contiguity = vec;
+      }
     }
     if (canUseAsyncCp || isTMALoad(&op)) {
       if (loadRequiresAdditionalBuffer(&op)) {
@@ -486,6 +497,7 @@ scf::ForOp lowerLoads(scf::ForOp forOp, CoarseSchedule &schedule,
       }
       auto &asyncLoad = asyncLoads[&op];
       asyncLoad.stageDiff = stageDiff;
+      asyncLoad.contiguity = contiguity;
       asyncLoad.sharedEncoding = sharedEncoding;
     } else if (stageDiff > 1) {
       // Distance-1 loads can in most cases be pipelined in registers without
@@ -589,7 +601,7 @@ scf::ForOp lowerLoads(scf::ForOp forOp, CoarseSchedule &schedule,
     auto [insertIdx, extractIdx, phase, _] = loadGroups[asyncLoad.stageDiff];
     if (auto loadOp = dyn_cast<tt::LoadOp>(op)) {
      createAsyncCopy(forOp, loadOp, asyncLoad.alloc, insertIdx, extractIdx,
-                      schedule);
+                      asyncLoad.contiguity, schedule);
       hasAsyncLoads = true;
     } else if (auto loadOp = dyn_cast<tt::DescriptorLoadOp>(op)) {
       createTMAAsyncLoad(forOp, loadOp, asyncLoad.alloc, insertIdx, extractIdx,

python/src/gluon_ir.cc

Lines changed: 1 addition & 2 deletions
@@ -867,8 +867,7 @@ void init_gluon_ir(py::module &&m) {
           })
       .def("create_async_tdm_copy_global_to_local",
            [](GluonOpBuilder &self, Value descPtr, std::vector<Value> &indices,
-              Value result, Value barrier) {
-             Value pred = self.create<arith::ConstantIntOp>(1, 1);
+              Value result, Value pred, Value barrier) {
             self.create<ttag::AsyncTDMCopyGlobalToLocalOp>(
                 descPtr, indices, result, pred, barrier);
           })
