Skip to content

Commit faeb1eb

Browse files
authored
[WS] remove std::optional from get-partition-ids (#8629)
- remove `std::optional` from partitionIds interface (cc @acollins3) - remove stray variable from tmem struct
1 parent 167bdc8 commit faeb1eb

File tree

11 files changed

+181
-225
lines changed

11 files changed

+181
-225
lines changed

include/triton/Dialect/TritonGPU/IR/Dialect.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -299,9 +299,11 @@ LogicalResult verifyMemoryOpTypes(Operation *op, ShapedType srcTy,
299299
// Verify a memory allocation operation.
300300
LogicalResult verifyAllocOp(Operation *op, Value src, MemDescType dstTy);
301301

302-
std::optional<SetVector<int>> getPartitionIds(Operation *op);
303-
std::optional<int> getNumOutputPartitionIds(Operation *op);
304-
std::optional<SetVector<int>> getOutputPartitionIds(Operation *op, int idx);
302+
SetVector<int> getPartitionIds(Operation *op);
303+
SmallVector<SetVector<int>, 4> getPartitionOutputs(Operation *op);
304+
SetVector<int> getPartitionIds(OpOperand *use);
305+
bool hasPartition(Operation *op);
306+
305307
} // namespace mlir::triton::gpu
306308

307309
#endif // TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_

include/triton/Dialect/TritonGPU/Transforms/Partition.h

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -103,18 +103,13 @@ class PartitionSet {
103103
// Utility to be used when the op is known to belong to one partition
104104
Partition *getPartition(Operation *op);
105105

106-
// Check if the operation belongs to all partitions
107-
bool isInRootPartition(Operation *op);
108-
109106
private:
110107
// WarpSpecialization tag
111108
int tag;
112109
// Partitions are numbered [0, N).
113110
SmallVector<std::unique_ptr<Partition>> partitions;
114111
};
115112

116-
bool hasPartition(Operation *op);
117-
118113
// Annotate the op with the partition index or indices, and add the op
119114
// to the partitions it belongs to.
120115
void setPartition(Operation *op, Partition *partition);
@@ -125,8 +120,6 @@ void setPartition(Operation *op, const SetVector<Partition *> &partitions);
125120
void setPartition(Operation *op, const SetVector<int> &partitionIds);
126121
void setPartitionOutputs(Operation *op,
127122
ArrayRef<SetVector<int>> partitionOutputsIds);
128-
SmallVector<SetVector<int>, 4> getPartitionOutputs(Operation *op);
129-
SetVector<int> getPartitionIds(OpOperand *use);
130123

131124
} // namespace mlir::triton::gpu
132125

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 28 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -3761,12 +3761,12 @@ LogicalResult TritonGPUDialect::verifyOperationAttribute(Operation *op,
37613761
return childOp.emitOpError("does not have expected attribute ")
37623762
<< kPartitionAttrName
37633763
<< " which is expected for ops whose parent has partitions";
3764-
auto ids = *getPartitionIds(&childOp);
3764+
auto ids = getPartitionIds(&childOp);
37653765
expectedIds.insert(ids.begin(), ids.end());
37663766
}
37673767
}
37683768
}
3769-
auto partitionIds = *getPartitionIds(op);
3769+
auto partitionIds = getPartitionIds(op);
37703770
for (auto id : expectedIds) {
37713771
if (!partitionIds.contains(id)) {
37723772
return op->emitOpError("partition ids in attr ")
@@ -3801,13 +3801,12 @@ LogicalResult TritonGPUDialect::verifyOperationAttribute(Operation *op,
38013801
return op->emitOpError("does not have expected attribute ")
38023802
<< kPartitionAttrName << " which is expected for ops with attr "
38033803
<< kPartitionOutputsAttrName;
3804-
auto partitionIds = *getPartitionIds(op);
3804+
auto partitionIds = getPartitionIds(op);
38053805

38063806
SetVector<int> outputPartitionIdsUnion;
3807-
for (auto idx = 0; idx < *getNumOutputPartitionIds(op); idx++) {
3808-
auto outputPartitionIds = getOutputPartitionIds(op, idx);
3809-
for (auto partitionId : *outputPartitionIds)
3810-
outputPartitionIdsUnion.insert(partitionId);
3807+
for (auto outputPartitionIds : getPartitionOutputs(op)) {
3808+
outputPartitionIdsUnion.insert(outputPartitionIds.begin(),
3809+
outputPartitionIds.end());
38113810
}
38123811
if (!std::all_of(outputPartitionIdsUnion.begin(),
38133812
outputPartitionIdsUnion.end(),
@@ -3919,15 +3918,8 @@ LinearLayout triton::gpu::inferReshapeLinearLayout(TensorOrMemDesc srcTy,
39193918
return dst;
39203919
}
39213920

3922-
std::optional<SetVector<int>> triton::gpu::getPartitionIds(Operation *op) {
3923-
if (!op) {
3924-
return std::nullopt;
3925-
}
3921+
SetVector<int> triton::gpu::getPartitionIds(Operation *op) {
39263922
auto attrs = op->getAttr(kPartitionAttrName);
3927-
if (!attrs) {
3928-
return std::nullopt;
3929-
}
3930-
39313923
SmallVector<int> partitionIds;
39323924
for (auto id : cast<DenseI32ArrayAttr>(attrs).asArrayRef()) {
39333925
partitionIds.push_back(id);
@@ -3936,33 +3928,31 @@ std::optional<SetVector<int>> triton::gpu::getPartitionIds(Operation *op) {
39363928
return SetVector<int>(partitionIds.begin(), partitionIds.end());
39373929
}
39383930

3939-
std::optional<int> triton::gpu::getNumOutputPartitionIds(Operation *op) {
3940-
if (!op) {
3941-
return std::nullopt;
3931+
SmallVector<SetVector<int>, 4> triton::gpu::getPartitionOutputs(Operation *op) {
3932+
SmallVector<SetVector<int>, 4> partitionOutputsIds;
3933+
if (op->getNumResults() == 0) {
3934+
return partitionOutputsIds;
39423935
}
3943-
auto attr = op->getAttr(kPartitionOutputsAttrName);
3944-
if (!attr) {
3945-
return std::nullopt;
3936+
auto arrayAttr = cast<ArrayAttr>(op->getAttr(kPartitionOutputsAttrName));
3937+
for (auto attr : arrayAttr) {
3938+
auto ids = cast<DenseI32ArrayAttr>(attr).asArrayRef();
3939+
partitionOutputsIds.push_back(SetVector<int>(ids.begin(), ids.end()));
39463940
}
3947-
return cast<ArrayAttr>(attr).size();
3941+
return partitionOutputsIds;
39483942
}
39493943

3950-
std::optional<SetVector<int>> triton::gpu::getOutputPartitionIds(Operation *op,
3951-
int idx) {
3952-
if (!op) {
3953-
return std::nullopt;
3954-
}
3955-
auto attr = op->getAttr(kPartitionOutputsAttrName);
3956-
if (!attr) {
3957-
return std::nullopt;
3944+
SetVector<int> triton::gpu::getPartitionIds(OpOperand *use) {
3945+
auto owner = use->getOwner();
3946+
if (isa<scf::YieldOp>(owner)) {
3947+
return getPartitionOutputs(owner->getParentOp())[use->getOperandNumber()];
3948+
} else if (scf::ForOp forOp = dyn_cast<scf::ForOp>(owner)) {
3949+
int idx = use->getOperandNumber() - forOp.getNumControlOperands();
3950+
return idx >= 0 ? getPartitionOutputs(owner)[idx] : getPartitionIds(forOp);
3951+
} else {
3952+
return getPartitionIds(owner);
39583953
}
3959-
assert(idx < cast<ArrayAttr>(attr).size());
3960-
auto attrs = cast<ArrayAttr>(attr)[idx];
3954+
}
39613955

3962-
SmallVector<int> partitionIds;
3963-
for (auto id : cast<DenseI32ArrayAttr>(attrs).asArrayRef()) {
3964-
partitionIds.push_back(id);
3965-
}
3966-
std::sort(partitionIds.begin(), partitionIds.end());
3967-
return SetVector<int>(partitionIds.begin(), partitionIds.end());
3956+
bool triton::gpu::hasPartition(Operation *op) {
3957+
return op && op->hasAttr(kPartitionAttrName);
39683958
}

lib/Dialect/TritonGPU/Transforms/WarpSpecialization/Partition.cpp

Lines changed: 16 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,11 @@ using namespace triton::gpu;
1515
//===----------------------------------------------------------------------===//
1616

1717
bool Partition::hasOp(Operation *op) const {
18-
auto partitionIds = getPartitionIds(op);
19-
if (!partitionIds) {
18+
if (!hasPartition(op)) {
2019
return false;
2120
}
22-
return partitionIds->contains(getIndex());
21+
auto partitionIds = getPartitionIds(op);
22+
return partitionIds.contains(getIndex());
2323
}
2424

2525
void Partition::iterateInputs(scf::ForOp loop,
@@ -28,7 +28,9 @@ void Partition::iterateInputs(scf::ForOp loop,
2828
visitNestedOperands(op, [&](OpOperand &operand) {
2929
// Ignore implicit captures.
3030
Value value = operand.get();
31-
auto partitionIds = getPartitionIds(value.getDefiningOp());
31+
std::optional<SetVector<int>> partitionIds;
32+
if (hasPartition(value.getDefiningOp()))
33+
partitionIds = getPartitionIds(value.getDefiningOp());
3234
if (value.getParentBlock() != loop.getBody())
3335
return;
3436
if (auto arg = dyn_cast<BlockArgument>(value)) {
@@ -59,7 +61,9 @@ void Partition::iterateOutputs(
5961
if (!owner) {
6062
continue;
6163
}
62-
auto partitionIds = getPartitionIds(owner);
64+
std::optional<SetVector<int>> partitionIds;
65+
if (hasPartition(owner))
66+
partitionIds = getPartitionIds(owner);
6367
if (isa<scf::YieldOp>(owner)) {
6468
// This value is used in a subsequent iteration.
6569
callback(owner, use);
@@ -123,13 +127,8 @@ const Partition *PartitionSet::getPartition(unsigned idx) const {
123127

124128
Partition *PartitionSet::getPartition(Operation *op) {
125129
auto id = getPartitionIds(op);
126-
assert(id && id->size() == 1);
127-
return getPartition((*id)[0]);
128-
}
129-
130-
bool PartitionSet::isInRootPartition(Operation *op) {
131-
auto partitionIds = getPartitionIds(op);
132-
return !partitionIds || partitionIds->size() == getNumPartitions();
130+
assert(id.size() == 1);
131+
return getPartition(id[0]);
133132
}
134133

135134
FailureOr<PartitionSet> PartitionSet::fromLoop(scf::ForOp loop) {
@@ -155,13 +154,11 @@ FailureOr<PartitionSet> PartitionSet::fromLoop(scf::ForOp loop) {
155154
}
156155

157156
for (Operation &op : loop.getBody()->without_terminator()) {
158-
if (auto attrs = getPartitionIds(&op)) {
159-
for (auto idx : *attrs) {
160-
if (idx < 0 || idx >= result.partitions.size())
161-
return mlir::emitError(op.getLoc(), "invalid partition index ")
162-
<< idx;
163-
result.partitions[idx]->addOp(&op);
164-
}
157+
auto attrs = getPartitionIds(&op);
158+
for (auto idx : attrs) {
159+
if (idx < 0 || idx >= result.partitions.size())
160+
return mlir::emitError(op.getLoc(), "invalid partition index ") << idx;
161+
result.partitions[idx]->addOp(&op);
165162
}
166163
}
167164

@@ -232,35 +229,4 @@ void setPartition(Operation *op, const SetVector<Partition *> &partitions) {
232229
setPartition(op, partitionIds);
233230
}
234231

235-
SmallVector<SetVector<int>, 4> getPartitionOutputs(Operation *op) {
236-
if (!op) {
237-
return {};
238-
}
239-
240-
auto attrs = op->getAttr(kPartitionOutputsAttrName);
241-
if (!attrs) {
242-
return {};
243-
}
244-
SmallVector<SetVector<int>, 4> partitionOutputsIds;
245-
for (auto attr : cast<ArrayAttr>(attrs)) {
246-
auto ids = cast<DenseI32ArrayAttr>(attr).asArrayRef();
247-
partitionOutputsIds.push_back(SetVector<int>(ids.begin(), ids.end()));
248-
}
249-
return partitionOutputsIds;
250-
}
251-
252-
SetVector<int> getPartitionIds(OpOperand *use) {
253-
auto owner = use->getOwner();
254-
if (isa<scf::YieldOp>(owner)) {
255-
return getPartitionOutputs(owner->getParentOp())[use->getOperandNumber()];
256-
} else if (scf::ForOp forOp = dyn_cast<scf::ForOp>(owner)) {
257-
int idx = use->getOperandNumber() - forOp.getNumControlOperands();
258-
return idx >= 0 ? getPartitionOutputs(owner)[idx] : *getPartitionIds(forOp);
259-
} else {
260-
return *getPartitionIds(owner);
261-
}
262-
}
263-
264-
bool hasPartition(Operation *op) { return getPartitionIds(op) != std::nullopt; }
265-
266232
} // namespace mlir::triton::gpu

lib/Dialect/TritonGPU/Transforms/WarpSpecialization/PartitionLoops.cpp

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ bool isTensorResultComputedBy(scf::ForOp loop, size_t resultIdx,
7272
if (!isa<RankedTensorType>(value.getType()))
7373
return false;
7474
auto defOp = value.getDefiningOp();
75-
auto partitionIds = *getPartitionIds(defOp);
75+
auto partitionIds = getPartitionIds(defOp);
7676
if (auto ifOp = dyn_cast<scf::IfOp>(defOp)) {
7777
partitionIds = getIfOpResultPartitionIds(ifOp, value);
7878
}
@@ -145,7 +145,7 @@ void cloneOpsInBlock(Block *block, SmallVector<WarpGroupBuilder> &builders,
145145

146146
void cloneForOp(scf::ForOp forOp, SmallVector<WarpGroupBuilder> &builders,
147147
const PartitionSet &partitions) {
148-
auto forOpPartitions = *getPartitionIds(forOp);
148+
auto forOpPartitions = getPartitionIds(forOp);
149149

150150
SmallVector<scf::ForOp> newForOps;
151151
for (int i : forOpPartitions) {
@@ -191,7 +191,7 @@ void cloneForOp(scf::ForOp forOp, SmallVector<WarpGroupBuilder> &builders,
191191

192192
void cloneIfOp(scf::IfOp ifOp, SmallVector<WarpGroupBuilder> &builders,
193193
const PartitionSet &partitions) {
194-
auto partitionIndices = *getPartitionIds(ifOp);
194+
auto partitionIndices = getPartitionIds(ifOp);
195195

196196
SmallVector<scf::IfOp> newIfOps;
197197
for (size_t idx : partitionIndices) {
@@ -239,7 +239,7 @@ void cloneIfOp(scf::IfOp ifOp, SmallVector<WarpGroupBuilder> &builders,
239239
void cloneReduceOp(triton::ReduceOp reduceOp,
240240
SmallVector<WarpGroupBuilder> &builders,
241241
const PartitionSet &partitions) {
242-
auto partitionIndices = *getPartitionIds(reduceOp);
242+
auto partitionIndices = getPartitionIds(reduceOp);
243243

244244
SmallVector<ReduceOp> newReduceOps;
245245
for (size_t idx : partitionIndices) {
@@ -308,7 +308,7 @@ void cloneOpsInBlock(Block *block, SmallVector<WarpGroupBuilder> &builders,
308308
}
309309
// empty yield has no partition annotations
310310
assert(hasPartition(op));
311-
auto partitionIndices = *getPartitionIds(op);
311+
auto partitionIndices = getPartitionIds(op);
312312

313313
for (size_t idx : partitionIndices) {
314314
auto &builder = builders[idx];
@@ -341,7 +341,7 @@ void cloneOpsInBlock(Block *block, SmallVector<WarpGroupBuilder> &builders,
341341
}
342342
} else {
343343
assert(hasPartition(op));
344-
auto partitionIndices = *getPartitionIds(op);
344+
auto partitionIndices = getPartitionIds(op);
345345
cloneOp(op, builders, partitionIndices);
346346
}
347347
}
@@ -359,19 +359,15 @@ LogicalResult triton::gpu::partitionLoop(scf::ForOp loop) {
359359
for (const Partition &partition : partitions.getPartitions()) {
360360
bool failed = false;
361361
auto callback = [&](OpResult output, OpOperand &use, unsigned distance) {
362-
if (partitions.isInRootPartition(output.getDefiningOp())) {
363-
return;
364-
}
365362
auto partitionIds = getPartitionIds(use.getOwner());
366-
if (partitions.isInRootPartition(use.getOwner()) ||
367-
llvm::is_contained(*partitionIds, partition.getIndex()))
363+
if (llvm::is_contained(partitionIds, partition.getIndex()))
368364
return;
369365

370366
// check if consumer partition set is a subset of the producer partitions
371367
auto defOpPartitionIds = getPartitionIds(output.getDefiningOp());
372368
bool isValidSubset = std::all_of(
373-
partitionIds->begin(), partitionIds->end(), [&](int consumerId) {
374-
return llvm::is_contained(*defOpPartitionIds, consumerId);
369+
partitionIds.begin(), partitionIds.end(), [&](int consumerId) {
370+
return llvm::is_contained(defOpPartitionIds, consumerId);
375371
});
376372

377373
if (isValidSubset)
@@ -382,7 +378,7 @@ LogicalResult triton::gpu::partitionLoop(scf::ForOp loop) {
382378
mlir::emitWarning(output.getLoc(), "non-root partition #")
383379
<< partition.getIndex() << " has direct SSA consumer";
384380

385-
for (auto partitionId : *partitionIds) {
381+
for (auto partitionId : partitionIds) {
386382
diag.attachNote(use.getOwner()->getLoc())
387383
<< "use at distance " << distance << " in partition #"
388384
<< partitionId << " here";
@@ -438,9 +434,9 @@ LogicalResult triton::gpu::partitionLoop(scf::ForOp loop) {
438434
auto wsTag = op->getAttrOfType<IntegerAttr>(kWarpSpecializeTagAttrName);
439435
if (!wsTag || wsTag.getInt() != partitions.getTag())
440436
continue;
441-
if (auto partitionIds = triton::gpu::getPartitionIds(op);
442-
partitionIds && !isa<scf::ForOp>(op)) {
443-
cloneOp(op, builders, *partitionIds);
437+
if (hasPartition(op) && !isa<scf::ForOp>(op)) {
438+
auto partitionIds = getPartitionIds(op);
439+
cloneOp(op, builders, partitionIds);
444440
opsToErase.push_back(op);
445441
} else {
446442
assert(loop.getOperation() == op && "Unexpected op");
@@ -451,7 +447,7 @@ LogicalResult triton::gpu::partitionLoop(scf::ForOp loop) {
451447

452448
for (auto [b, region, partition] : llvm::zip(
453449
builders, wgOp.getPartitionRegions(), partitions.getPartitions())) {
454-
if (!llvm::is_contained(*getPartitionIds(loop), b.partitionId)) {
450+
if (!llvm::is_contained(getPartitionIds(loop), b.partitionId)) {
455451
b.create<nvws::WarpGroupYieldOp>(wgOp.getLoc(), SmallVector<Value>{});
456452
continue;
457453
}

0 commit comments

Comments
 (0)