Skip to content

Commit 8fdcd12

Browse files
Merge commit '618ec403b6d59b6cd0d45e6170940a7ce16533d5'
2 parents 6299f1a + 618ec40 commit 8fdcd12

File tree

34 files changed

+2019
-1377
lines changed

34 files changed

+2019
-1377
lines changed

include/triton/Dialect/TritonGPU/IR/Dialect.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -299,9 +299,11 @@ LogicalResult verifyMemoryOpTypes(Operation *op, ShapedType srcTy,
299299
// Verify a memory allocation operation.
300300
LogicalResult verifyAllocOp(Operation *op, Value src, MemDescType dstTy);
301301

302-
std::optional<SetVector<int>> getPartitionIds(Operation *op);
303-
std::optional<int> getNumOutputPartitionIds(Operation *op);
304-
std::optional<SetVector<int>> getOutputPartitionIds(Operation *op, int idx);
302+
SetVector<int> getPartitionIds(Operation *op);
303+
SmallVector<SetVector<int>, 4> getPartitionOutputs(Operation *op);
304+
SetVector<int> getPartitionIds(OpOperand *use);
305+
bool hasPartition(Operation *op);
306+
305307
} // namespace mlir::triton::gpu
306308

307309
#endif // TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_

include/triton/Dialect/TritonGPU/Transforms/Partition.h

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -103,18 +103,13 @@ class PartitionSet {
103103
// Utility to be used when the op is known to belong to one partition
104104
Partition *getPartition(Operation *op);
105105

106-
// Check if the operation belongs to all partitions
107-
bool isInRootPartition(Operation *op);
108-
109106
private:
110107
// WarpSpecialization tag
111108
int tag;
112109
// Partitions are numbered [0, N).
113110
SmallVector<std::unique_ptr<Partition>> partitions;
114111
};
115112

116-
bool hasPartition(Operation *op);
117-
118113
// Annotate the op with the partition index or indices, and add the op
119114
// to the partitions it belongs to.
120115
void setPartition(Operation *op, Partition *partition);
@@ -125,7 +120,6 @@ void setPartition(Operation *op, const SetVector<Partition *> &partitions);
125120
void setPartition(Operation *op, const SetVector<int> &partitionIds);
126121
void setPartitionOutputs(Operation *op,
127122
ArrayRef<SetVector<int>> partitionOutputsIds);
128-
SmallVector<SetVector<int>, 4> getPartitionOutputs(Operation *op);
129123

130124
} // namespace mlir::triton::gpu
131125

include/triton/Dialect/TritonGPU/Transforms/Passes.td

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -121,25 +121,6 @@ def TritonGPUAutomaticWarpSpecialization : Pass<"tritongpu-automatic-warp-specia
121121
];
122122
}
123123

124-
def TritonGPURewritePartitionDependencies : Pass<"tritongpu-rewrite-partition-dependencies", "mlir::ModuleOp"> {
125-
let summary = "test pass for rewriting partition dependencies";
126-
127-
let description = [{
128-
The `tritongpu-rewrite-partition-dependencies` pass analyzes the partitions
129-
assigned to a loop and their SSA dependencies. It rewrites the dependencies
130-
to be passed through shared memory, applying multi-buffering according to
131-
the assigned stages of the partitions.
132-
}];
133-
134-
let dependentDialects = [
135-
"mlir::triton::gpu::TritonGPUDialect",
136-
"mlir::scf::SCFDialect",
137-
"mlir::arith::ArithDialect",
138-
"mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect",
139-
"mlir::triton::nvws::NVWSDialect"
140-
];
141-
}
142-
143124
def TritonGPUPartitionLoops : Pass<"tritongpu-partition-loops", "mlir::ModuleOp"> {
144125
let summary = "split scheduled loops into `ttg.warp_specialize`";
145126

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 28 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -3814,12 +3814,12 @@ LogicalResult TritonGPUDialect::verifyOperationAttribute(Operation *op,
38143814
return childOp.emitOpError("does not have expected attribute ")
38153815
<< kPartitionAttrName
38163816
<< " which is expected for ops whose parent has partitions";
3817-
auto ids = *getPartitionIds(&childOp);
3817+
auto ids = getPartitionIds(&childOp);
38183818
expectedIds.insert(ids.begin(), ids.end());
38193819
}
38203820
}
38213821
}
3822-
auto partitionIds = *getPartitionIds(op);
3822+
auto partitionIds = getPartitionIds(op);
38233823
for (auto id : expectedIds) {
38243824
if (!partitionIds.contains(id)) {
38253825
return op->emitOpError("partition ids in attr ")
@@ -3854,13 +3854,12 @@ LogicalResult TritonGPUDialect::verifyOperationAttribute(Operation *op,
38543854
return op->emitOpError("does not have expected attribute ")
38553855
<< kPartitionAttrName << " which is expected for ops with attr "
38563856
<< kPartitionOutputsAttrName;
3857-
auto partitionIds = *getPartitionIds(op);
3857+
auto partitionIds = getPartitionIds(op);
38583858

38593859
SetVector<int> outputPartitionIdsUnion;
3860-
for (auto idx = 0; idx < *getNumOutputPartitionIds(op); idx++) {
3861-
auto outputPartitionIds = getOutputPartitionIds(op, idx);
3862-
for (auto partitionId : *outputPartitionIds)
3863-
outputPartitionIdsUnion.insert(partitionId);
3860+
for (auto outputPartitionIds : getPartitionOutputs(op)) {
3861+
outputPartitionIdsUnion.insert(outputPartitionIds.begin(),
3862+
outputPartitionIds.end());
38643863
}
38653864
if (!std::all_of(outputPartitionIdsUnion.begin(),
38663865
outputPartitionIdsUnion.end(),
@@ -3972,15 +3971,8 @@ LinearLayout triton::gpu::inferReshapeLinearLayout(TensorOrMemDesc srcTy,
39723971
return dst;
39733972
}
39743973

3975-
std::optional<SetVector<int>> triton::gpu::getPartitionIds(Operation *op) {
3976-
if (!op) {
3977-
return std::nullopt;
3978-
}
3974+
SetVector<int> triton::gpu::getPartitionIds(Operation *op) {
39793975
auto attrs = op->getAttr(kPartitionAttrName);
3980-
if (!attrs) {
3981-
return std::nullopt;
3982-
}
3983-
39843976
SmallVector<int> partitionIds;
39853977
for (auto id : cast<DenseI32ArrayAttr>(attrs).asArrayRef()) {
39863978
partitionIds.push_back(id);
@@ -3989,33 +3981,31 @@ std::optional<SetVector<int>> triton::gpu::getPartitionIds(Operation *op) {
39893981
return SetVector<int>(partitionIds.begin(), partitionIds.end());
39903982
}
39913983

3992-
std::optional<int> triton::gpu::getNumOutputPartitionIds(Operation *op) {
3993-
if (!op) {
3994-
return std::nullopt;
3984+
SmallVector<SetVector<int>, 4> triton::gpu::getPartitionOutputs(Operation *op) {
3985+
SmallVector<SetVector<int>, 4> partitionOutputsIds;
3986+
if (op->getNumResults() == 0) {
3987+
return partitionOutputsIds;
39953988
}
3996-
auto attr = op->getAttr(kPartitionOutputsAttrName);
3997-
if (!attr) {
3998-
return std::nullopt;
3989+
auto arrayAttr = cast<ArrayAttr>(op->getAttr(kPartitionOutputsAttrName));
3990+
for (auto attr : arrayAttr) {
3991+
auto ids = cast<DenseI32ArrayAttr>(attr).asArrayRef();
3992+
partitionOutputsIds.push_back(SetVector<int>(ids.begin(), ids.end()));
39993993
}
4000-
return cast<ArrayAttr>(attr).size();
3994+
return partitionOutputsIds;
40013995
}
40023996

4003-
std::optional<SetVector<int>> triton::gpu::getOutputPartitionIds(Operation *op,
4004-
int idx) {
4005-
if (!op) {
4006-
return std::nullopt;
4007-
}
4008-
auto attr = op->getAttr(kPartitionOutputsAttrName);
4009-
if (!attr) {
4010-
return std::nullopt;
3997+
SetVector<int> triton::gpu::getPartitionIds(OpOperand *use) {
3998+
auto owner = use->getOwner();
3999+
if (isa<scf::YieldOp>(owner)) {
4000+
return getPartitionOutputs(owner->getParentOp())[use->getOperandNumber()];
4001+
} else if (scf::ForOp forOp = dyn_cast<scf::ForOp>(owner)) {
4002+
int idx = use->getOperandNumber() - forOp.getNumControlOperands();
4003+
return idx >= 0 ? getPartitionOutputs(owner)[idx] : getPartitionIds(forOp);
4004+
} else {
4005+
return getPartitionIds(owner);
40114006
}
4012-
assert(idx < cast<ArrayAttr>(attr).size());
4013-
auto attrs = cast<ArrayAttr>(attr)[idx];
4007+
}
40144008

4015-
SmallVector<int> partitionIds;
4016-
for (auto id : cast<DenseI32ArrayAttr>(attrs).asArrayRef()) {
4017-
partitionIds.push_back(id);
4018-
}
4019-
std::sort(partitionIds.begin(), partitionIds.end());
4020-
return SetVector<int>(partitionIds.begin(), partitionIds.end());
4009+
bool triton::gpu::hasPartition(Operation *op) {
4010+
return op && op->hasAttr(kPartitionAttrName);
40214011
}

lib/Dialect/TritonGPU/Transforms/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@ add_triton_library(TritonGPUTransforms
3434
WarpSpecialization/PartitionBuilder.cpp
3535
WarpSpecialization/PartitionLoops.cpp
3636
WarpSpecialization/PartitionScheduling.cpp
37-
WarpSpecialization/RewritePartitionDependencies.cpp
3837

3938
DEPENDS
4039
TritonGPUTransformsIncGen

lib/Dialect/TritonGPU/Transforms/Pipeliner/MMAv5PipelineUtility.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,8 @@ bool ttng::isOperandPipelineableBase(
6969
return true;
7070
}
7171
auto localAllocSrc = localAlloc.getSrc().getDefiningOp();
72-
if (!isa<tt::LoadOp, tt::DescriptorLoadOp, tt::DescriptorGatherOp>(
73-
localAllocSrc)) {
72+
if (!isa_and_nonnull<tt::LoadOp, tt::DescriptorLoadOp,
73+
tt::DescriptorGatherOp>(localAllocSrc)) {
7474
return false;
7575
}
7676
foundDef = localAllocSrc;

lib/Dialect/TritonGPU/Transforms/WarpSpecialization/AutomaticWarpSpecialization.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@ void AutomaticWarpSpecialization::runOnOperation() {
3737
pm.addPass(createTritonGPUPartitionScheduling());
3838
pm.addPass(createNVWSInsertAref());
3939
pm.addPass(createNVWSInsertTmemAref());
40-
pm.addPass(createTritonGPURewritePartitionDependencies());
4140
// `int-range-optimizations` and SCCP are good at cleaning up loop arithmetic.
4241
// FIXME: Re-enable integer range analysis once it is fixed.
4342
// pm.addPass(arith::createIntRangeOptimizationsPass());

lib/Dialect/TritonGPU/Transforms/WarpSpecialization/Partition.cpp

Lines changed: 16 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,11 @@ using namespace triton::gpu;
1515
//===----------------------------------------------------------------------===//
1616

1717
bool Partition::hasOp(Operation *op) const {
18-
auto partitionIds = getPartitionIds(op);
19-
if (!partitionIds) {
18+
if (!hasPartition(op)) {
2019
return false;
2120
}
22-
return partitionIds->contains(getIndex());
21+
auto partitionIds = getPartitionIds(op);
22+
return partitionIds.contains(getIndex());
2323
}
2424

2525
void Partition::iterateInputs(scf::ForOp loop,
@@ -28,7 +28,9 @@ void Partition::iterateInputs(scf::ForOp loop,
2828
visitNestedOperands(op, [&](OpOperand &operand) {
2929
// Ignore implicit captures.
3030
Value value = operand.get();
31-
auto partitionIds = getPartitionIds(value.getDefiningOp());
31+
std::optional<SetVector<int>> partitionIds;
32+
if (hasPartition(value.getDefiningOp()))
33+
partitionIds = getPartitionIds(value.getDefiningOp());
3234
if (value.getParentBlock() != loop.getBody())
3335
return;
3436
if (auto arg = dyn_cast<BlockArgument>(value)) {
@@ -59,7 +61,9 @@ void Partition::iterateOutputs(
5961
if (!owner) {
6062
continue;
6163
}
62-
auto partitionIds = getPartitionIds(owner);
64+
std::optional<SetVector<int>> partitionIds;
65+
if (hasPartition(owner))
66+
partitionIds = getPartitionIds(owner);
6367
if (isa<scf::YieldOp>(owner)) {
6468
// This value is used in a subsequent iteration.
6569
callback(owner, use);
@@ -123,13 +127,8 @@ const Partition *PartitionSet::getPartition(unsigned idx) const {
123127

124128
Partition *PartitionSet::getPartition(Operation *op) {
125129
auto id = getPartitionIds(op);
126-
assert(id && id->size() == 1);
127-
return getPartition((*id)[0]);
128-
}
129-
130-
bool PartitionSet::isInRootPartition(Operation *op) {
131-
auto partitionIds = getPartitionIds(op);
132-
return !partitionIds || partitionIds->size() == getNumPartitions();
130+
assert(id.size() == 1);
131+
return getPartition(id[0]);
133132
}
134133

135134
FailureOr<PartitionSet> PartitionSet::fromLoop(scf::ForOp loop) {
@@ -155,13 +154,11 @@ FailureOr<PartitionSet> PartitionSet::fromLoop(scf::ForOp loop) {
155154
}
156155

157156
for (Operation &op : loop.getBody()->without_terminator()) {
158-
if (auto attrs = getPartitionIds(&op)) {
159-
for (auto idx : *attrs) {
160-
if (idx < 0 || idx >= result.partitions.size())
161-
return mlir::emitError(op.getLoc(), "invalid partition index ")
162-
<< idx;
163-
result.partitions[idx]->addOp(&op);
164-
}
157+
auto attrs = getPartitionIds(&op);
158+
for (auto idx : attrs) {
159+
if (idx < 0 || idx >= result.partitions.size())
160+
return mlir::emitError(op.getLoc(), "invalid partition index ") << idx;
161+
result.partitions[idx]->addOp(&op);
165162
}
166163
}
167164

@@ -232,23 +229,4 @@ void setPartition(Operation *op, const SetVector<Partition *> &partitions) {
232229
setPartition(op, partitionIds);
233230
}
234231

235-
SmallVector<SetVector<int>, 4> getPartitionOutputs(Operation *op) {
236-
if (!op) {
237-
return {};
238-
}
239-
240-
auto attrs = op->getAttr(kPartitionOutputsAttrName);
241-
if (!attrs) {
242-
return {};
243-
}
244-
SmallVector<SetVector<int>, 4> partitionOutputsIds;
245-
for (auto attr : cast<ArrayAttr>(attrs)) {
246-
auto ids = cast<DenseI32ArrayAttr>(attr).asArrayRef();
247-
partitionOutputsIds.push_back(SetVector<int>(ids.begin(), ids.end()));
248-
}
249-
return partitionOutputsIds;
250-
}
251-
252-
bool hasPartition(Operation *op) { return getPartitionIds(op) != std::nullopt; }
253-
254232
} // namespace mlir::triton::gpu

0 commit comments

Comments
 (0)