Skip to content

Commit 4fc492c

Browse files
authored
Merge OpenAI Triton commit 22b1a44 (#5229)
This PR changes the Triton base from 407b8a3 to 22b1a44 (Sep 25). Pass rate: 97.26%->92.66%
2 parents 05a7055 + 9cb4b9b commit 4fc492c

File tree

53 files changed

+3446
-3558
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below for content that may be hidden.

53 files changed

+3446
-3558
lines changed

.github/workflows/integration-tests-amd.yml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -173,9 +173,7 @@ jobs:
173173
LD_PRELOAD=$CLANG_ASAN_LIB:$HIP_ASAN_LIB python3 -m pytest -s test_address_sanitizer.py
174174
- name: Run regression tests
175175
run: |
176-
# Reenable test_functional_regression.py once it's fixed
177-
cd python/test/regression
178-
python3 -m pytest -s -n 8 ./test_cast_matmul.py
176+
make test-regression
179177
- name: Run microbenchmark tests
180178
run: |
181179
python3 python/test/microbenchmark/launch_overhead.py

include/triton/Analysis/AxisInfo.h

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -161,10 +161,13 @@ class AxisInfo {
161161

162162
std::optional<int64_t> getConstantValue() const { return constantValue; }
163163

164-
template <class T>
165-
static void
166-
initPessimisticStateFromFunc(int argNumber, T funcOp, DimVectorT *contiguity,
167-
DimVectorT *divisibility, DimVectorT *constancy);
164+
static void initPessimisticStateFromFunc(int argNumber,
165+
FunctionOpInterface funcOp,
166+
DimVectorT *contiguity,
167+
DimVectorT *divisibility,
168+
DimVectorT *constancy);
169+
170+
static void initDimVectorFromHint(Attribute attr, DimVectorT *vec);
168171

169172
bool operator==(const AxisInfo &other) const {
170173
return contiguity == other.contiguity &&

include/triton/Dialect/TritonGPU/Transforms/Passes.td

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -176,24 +176,6 @@ def TritonGPUPartitionScheduling : Pass<"tritongpu-partition-scheduling", "mlir:
176176
}];
177177
}
178178

179-
def TritonGPULoadMMASpecialization : Pass<"tritongpu-load-mma-specialization", "mlir::ModuleOp"> {
180-
let summary = "load MMA specialization";
181-
182-
let description = [{
183-
The `tritongpu-load-mma-specialization` pass looks for matmul loops in the
184-
module and attempts to create a partition schedule, separating async loads
185-
and async MMAs into separate partitions.
186-
}];
187-
188-
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
189-
"mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect"];
190-
191-
let options = [
192-
Option<"numStages", "num-stages", "int32_t", /*default*/"3",
193-
"number of pipeline stages">
194-
];
195-
}
196-
197179
def TritonGPUF32DotTC : Pass<"tritongpu-F32DotTC", "mlir::ModuleOp"> {
198180
let summary = "3xTF32 trick";
199181

lib/Analysis/AxisInfo.cpp

Lines changed: 25 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1084,18 +1084,12 @@ LogicalResult AxisInfoAnalysis::visitOperation(
10841084
auto newContiguity = curr.getContiguity();
10851085
auto newDivisibility = curr.getDivisibility();
10861086
auto newConstancy = curr.getConstancy();
1087-
if (Attribute attr = op->getDiscardableAttr("tt.contiguity")) {
1088-
auto vals = cast<DenseElementsAttr>(attr).getValues<int>();
1089-
newContiguity = AxisInfo::DimVectorT(vals.begin(), vals.end());
1090-
}
1091-
if (Attribute attr = op->getDiscardableAttr("tt.divisibility")) {
1092-
auto vals = cast<DenseElementsAttr>(attr).getValues<int>();
1093-
newDivisibility = AxisInfo::DimVectorT(vals.begin(), vals.end());
1094-
}
1095-
if (Attribute attr = op->getDiscardableAttr("tt.constancy")) {
1096-
auto vals = cast<DenseElementsAttr>(attr).getValues<int>();
1097-
newConstancy = AxisInfo::DimVectorT(vals.begin(), vals.end());
1098-
}
1087+
AxisInfo::initDimVectorFromHint(op->getDiscardableAttr("tt.contiguity"),
1088+
&newContiguity);
1089+
AxisInfo::initDimVectorFromHint(op->getDiscardableAttr("tt.divisibility"),
1090+
&newDivisibility);
1091+
AxisInfo::initDimVectorFromHint(op->getDiscardableAttr("tt.constancy"),
1092+
&newConstancy);
10991093
curr = AxisInfo(newContiguity, newDivisibility, newConstancy,
11001094
curr.getConstantValue());
11011095
// join all lattice elements
@@ -1140,25 +1134,29 @@ void AxisInfoAnalysis::visitWarpSpecializeExplicitCaptures(
11401134

11411135
} // anonymous namespace
11421136

1143-
template <class T>
1144-
void AxisInfo::initPessimisticStateFromFunc(int argNumber, T funcOp,
1137+
void AxisInfo::initPessimisticStateFromFunc(int argNumber,
1138+
FunctionOpInterface funcOp,
11451139
DimVectorT *contiguity,
11461140
DimVectorT *divisibility,
11471141
DimVectorT *constancy) {
1148-
// liast of attributes that we care about
1142+
// list of attributes that we care about
11491143
SmallVector<std::pair<DimVectorT *, std::string>> retVecs;
11501144
retVecs.push_back({contiguity, "tt.contiguity"});
11511145
retVecs.push_back({divisibility, "tt.divisibility"});
11521146
retVecs.push_back({constancy, "tt.constancy"});
11531147
// initialize attributes one by one
11541148
for (auto [vec, attrName] : retVecs) {
11551149
Attribute attr = funcOp.getArgAttr(argNumber, attrName);
1156-
if (auto int_attr = dyn_cast_or_null<IntegerAttr>(attr))
1157-
*vec = DimVectorT(contiguity->size(), int_attr.getValue().getZExtValue());
1158-
if (auto dense_attr = dyn_cast_or_null<DenseElementsAttr>(attr)) {
1159-
auto vals = dense_attr.getValues<int>();
1160-
*vec = DimVectorT(vals.begin(), vals.end());
1161-
}
1150+
AxisInfo::initDimVectorFromHint(attr, vec);
1151+
}
1152+
}
1153+
1154+
void AxisInfo::initDimVectorFromHint(Attribute attr, DimVectorT *vec) {
1155+
if (auto int_attr = dyn_cast_or_null<IntegerAttr>(attr))
1156+
*vec = DimVectorT(1, int_attr.getValue().getZExtValue());
1157+
if (auto dense_attr = dyn_cast_or_null<DenseElementsAttr>(attr)) {
1158+
auto vals = dense_attr.getValues<int>();
1159+
*vec = DimVectorT(vals.begin(), vals.end());
11621160
}
11631161
}
11641162

@@ -1202,18 +1200,12 @@ void AxisInfo::initPessimisticStateFromFunc(int argNumber, T funcOp,
12021200
}
12031201
// Other operations are conservatively initialized with the lowest possible
12041202
// divisibility, contiguity, and constancy unless they have specified.
1205-
if (Attribute attr = op->getDiscardableAttr("tt.divisibility")) {
1206-
auto vals = cast<DenseElementsAttr>(attr).getValues<int>();
1207-
knownDivisibility = DimVectorT(vals.begin(), vals.end());
1208-
}
1209-
if (Attribute attr = op->getDiscardableAttr("tt.contiguity")) {
1210-
auto vals = cast<DenseElementsAttr>(attr).getValues<int>();
1211-
knownContiguity = DimVectorT(vals.begin(), vals.end());
1212-
}
1213-
if (Attribute attr = op->getDiscardableAttr("tt.constancy")) {
1214-
auto vals = cast<DenseElementsAttr>(attr).getValues<int>();
1215-
knownConstancy = DimVectorT(vals.begin(), vals.end());
1216-
}
1203+
AxisInfo::initDimVectorFromHint(op->getDiscardableAttr("tt.divisibility"),
1204+
&knownDivisibility);
1205+
AxisInfo::initDimVectorFromHint(op->getDiscardableAttr("tt.contiguity"),
1206+
&knownContiguity);
1207+
AxisInfo::initDimVectorFromHint(op->getDiscardableAttr("tt.constancy"),
1208+
&knownConstancy);
12171209
}
12181210

12191211
return AxisInfo(knownContiguity, knownDivisibility, knownConstancy);

lib/Dialect/TritonGPU/Transforms/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ add_triton_library(TritonGPUTransforms
2929
Utility.cpp
3030
LayoutPropagationUtility.cpp
3131
WarpSpecialization/AutomaticWarpSpecialization.cpp
32-
WarpSpecialization/LoadMMASpecialization.cpp
3332
WarpSpecialization/Partition.cpp
3433
WarpSpecialization/OptimizePartitionWarps.cpp
3534
WarpSpecialization/PartitionBuilder.cpp

lib/Dialect/TritonGPU/Transforms/WarpSpecialization/AutomaticWarpSpecialization.cpp

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,7 @@ void AutomaticWarpSpecialization::runOnOperation() {
3636
OpPassManager pm;
3737
pm.addPass(createTritonGPUPartitionScheduling());
3838
pm.addPass(createNVWSInsertAref());
39-
#if 0
4039
pm.addPass(createNVWSInsertTmemAref());
41-
#else
42-
pm.addPass(createTritonGPULoadMMASpecialization({numStages}));
43-
#endif
4440
pm.addPass(createTritonGPURewritePartitionDependencies());
4541
// `int-range-optimizations` and SCCP are good at cleaning up loop arithmetic.
4642
// FIXME: Re-enable integer range analysis once it is fixed.
@@ -50,19 +46,6 @@ void AutomaticWarpSpecialization::runOnOperation() {
5046
pm.addPass(createNVWSLowerAref({numStages}));
5147
pm.addPass(createTritonGPUPartitionLoops());
5248
pm.addPass(createNVWSLowerWarpGroup());
53-
if (failed(runPipeline(pm, getOperation())))
54-
return signalPassFailure();
55-
56-
// Cleanup code generated by warp specialization.
57-
RewritePatternSet patterns(&getContext());
58-
populateForOpDeadArgumentElimination(patterns);
59-
scf::ForOp::getCanonicalizationPatterns(patterns, &getContext());
60-
scf::IfOp::getCanonicalizationPatterns(patterns, &getContext());
61-
WarpSpecializeOp::getCanonicalizationPatterns(patterns, &getContext());
62-
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
63-
return signalPassFailure();
64-
65-
pm.clear();
6649
pm.addPass(createTritonGPUOptimizePartitionWarps());
6750
pm.addPass(createTritonGPUScheduleLoops());
6851
if (failed(runPipeline(pm, getOperation())))

0 commit comments

Comments
 (0)