Skip to content

Commit 2156b05

Browse files
authored
[WS] assign partitions to all ops (#8534)
* assign partitions to all ops after the partition scheduler * scf.for/scf.if/tt.reduce ops and the ops in their regions are annotated * region ops also have `ttg.partition.outputs`, a list of partition sets, one for each output * update all passes to remove any inference of, or assumptions about how to infer, annotations, and use the explicit annotations instead * add verifiers that ensure that ttg.ws regions have all ops annotated * update lit tests. This is a first PR to enable nested-loop support. cc @acollins3
1 parent 5cb3a6b commit 2156b05

File tree

17 files changed

+1322
-595
lines changed

17 files changed

+1322
-595
lines changed

include/triton/Dialect/TritonGPU/IR/Dialect.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,11 @@ constexpr static char AttrNumWarpsName[] = "ttg.num-warps";
4545
constexpr static char AttrNumCTAsName[] = "ttg.num-ctas";
4646
constexpr static char AttrTargetName[] = "ttg.target";
4747
constexpr static char AttrNumThreadsPerWarp[] = "ttg.threads-per-warp";
48+
// FIXME: rename to match above
49+
constexpr static char kPartitionAttrName[] = "ttg.partition";
50+
constexpr static char kPartitionOutputsAttrName[] = "ttg.partition.outputs";
51+
constexpr static char kPartitionStagesAttrName[] = "ttg.partition.stages";
52+
constexpr static char kWarpSpecializeTagAttrName[] = "ttg.warp_specialize.tag";
4853

4954
// Find the contextual number of warps on which this operation is executed.
5055
int lookupNumWarps(Operation *op);
@@ -293,6 +298,10 @@ LogicalResult verifyMemoryOpTypes(Operation *op, ShapedType srcTy,
293298
ShapedType dstTy);
294299
// Verify a memory allocation operation.
295300
LogicalResult verifyAllocOp(Operation *op, Value src, MemDescType dstTy);
301+
302+
std::optional<SetVector<int>> getPartitionIds(Operation *op);
303+
std::optional<int> getNumOutputPartitionIds(Operation *op);
304+
std::optional<SetVector<int>> getOutputPartitionIds(Operation *op, int idx);
296305
} // namespace mlir::triton::gpu
297306

298307
#endif // TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_

include/triton/Dialect/TritonGPU/Transforms/Partition.h

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,6 @@ class ForOp;
1616
} // namespace scf
1717
} // namespace mlir
1818

19-
static constexpr char kPartitionAttrName[] = "ttg.partition";
20-
static constexpr char kPartitionOutputsAttrName[] = "ttg.partition.outputs";
21-
static constexpr char kPartitionStagesAttrName[] = "ttg.partition.stages";
22-
static constexpr char kWarpSpecializeTagAttrName[] = "ttg.warp_specialize.tag";
23-
2419
//===----------------------------------------------------------------------===//
2520
// PartitionSet
2621
//===----------------------------------------------------------------------===//
@@ -40,6 +35,7 @@ class Partition {
4035
ArrayRef<Operation *> getOps() const { return ops; }
4136
void addOp(Operation *op) { ops.push_back(op); }
4237
bool hasOp(Operation *op) const;
38+
bool empty() const { return ops.empty(); }
4339

4440
// Iterate the inputs of the partition. Input values are those that originate
4541
// from a different partition or a previous iteration of the current
@@ -127,8 +123,9 @@ void setPartition(Operation *op, const SetVector<Partition *> &partitions);
127123
// which does not work with Partition instances and iterate* functions, since
128124
// it does not keep the op attributes and the op list of a partition in sync.
129125
void setPartition(Operation *op, const SetVector<int> &partitionIds);
130-
131-
std::optional<SetVector<int>> getPartitionIds(Operation *op);
126+
void setPartitionOutputs(Operation *op,
127+
ArrayRef<SetVector<int>> partitionOutputsIds);
128+
SmallVector<SetVector<int>, 4> getPartitionOutputs(Operation *op);
132129

133130
} // namespace mlir::triton::gpu
134131

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include <cstdint>
44
#include <numeric>
55

6+
#include "mlir/Dialect/UB/IR/UBOps.h"
67
#include "mlir/IR/DialectImplementation.h"
78
#include "mlir/IR/OpImplementation.h"
89
#include "mlir/IR/OperationSupport.h"
@@ -573,6 +574,17 @@ static LogicalResult parseBool(AsmParser &parser, const NamedAttribute &attr,
573574
return parseBoolAttrValue(parser, attr.getValue(), value, desc);
574575
};
575576

577+
// Reads the Type carried by `attr` into `value`.
//
// Succeeds when the attribute's value is a TypeAttr. Otherwise an error that
// names `desc` is emitted at the parser's name location and failure is
// returned; `value` is left untouched in that case.
static LogicalResult parseType(AsmParser &parser, const NamedAttribute &attr,
                               Type &value, StringRef desc) {
  if (auto asTypeAttr = mlir::dyn_cast<TypeAttr>(attr.getValue())) {
    value = asTypeAttr.getValue();
    return success();
  }
  parser.emitError(parser.getNameLoc(), "expected a Type in ") << desc;
  return failure();
}
587+
576588
std::optional<LinearLayout>
577589
parseLinearLayout(const DictionaryAttr &dict, AsmParser &parser,
578590
ArrayRef<std::string> inDimNames) {
@@ -3676,6 +3688,136 @@ LogicalResult TritonGPUDialect::verifyOperationAttribute(Operation *op,
36763688
<< " which is expected only on `module` or `tt.func` ops";
36773689
}
36783690

3691+
// Verify that all ops in a tt.warp_specialize op have partition ids
3692+
if (attr.getName() == "tt.warp_specialize") {
3693+
if (!isa<scf::ForOp>(op)) {
3694+
return op->emitOpError("has unexpected attribute ")
3695+
<< attr.getName() << " which is expected only on `scf.for` ops";
3696+
}
3697+
Operation *failedOp = nullptr;
3698+
op->walk([&](Operation *childOp) {
3699+
if (!childOp->hasAttr(kPartitionAttrName)) {
3700+
failedOp = childOp;
3701+
WalkResult::interrupt();
3702+
}
3703+
});
3704+
if (failedOp) {
3705+
return failedOp->emitOpError("does not have expected attribute ")
3706+
<< kPartitionAttrName
3707+
<< " which is expected on all child ops of an op with "
3708+
"attribute `tt.warp_specialize`";
3709+
}
3710+
}
3711+
3712+
// Verify that partition id lists are non-empty, sorted and have no duplicates
3713+
auto verifyPartitionIds =
3714+
[&](const ArrayRef<int> &partitionIds) -> LogicalResult {
3715+
SetVector<int> idSet;
3716+
for (auto id : partitionIds) {
3717+
if (idSet.contains(id))
3718+
return op->emitOpError("has duplicated partition ids in attribute ")
3719+
<< attr.getName();
3720+
idSet.insert(id);
3721+
}
3722+
if (idSet.empty())
3723+
return op->emitOpError("has no partition ids in attribute ")
3724+
<< attr.getName();
3725+
auto ids = idSet.takeVector();
3726+
SmallVector<int> sortedIds(ids.begin(), ids.end());
3727+
std::sort(sortedIds.begin(), sortedIds.end());
3728+
if (ids != sortedIds)
3729+
return op->emitOpError("partition ids not in sorted order in attribute ")
3730+
<< attr.getName();
3731+
return success();
3732+
};
3733+
3734+
if (attr.getName() == kPartitionAttrName) {
3735+
auto result = verifyPartitionIds(
3736+
cast<DenseI32ArrayAttr>(attr.getValue()).asArrayRef());
3737+
if (failed(result))
3738+
return result;
3739+
}
3740+
if (attr.getName() == kPartitionOutputsAttrName) {
3741+
auto arrayAttr = cast<ArrayAttr>(attr.getValue());
3742+
for (auto idx = 0; idx < arrayAttr.size(); idx++) {
3743+
auto result = verifyPartitionIds(
3744+
cast<DenseI32ArrayAttr>(arrayAttr[idx]).asArrayRef());
3745+
if (failed(result))
3746+
return result;
3747+
}
3748+
}
3749+
3750+
// Verify that op partitions include partitions of all child ops
3751+
if (attr.getName() == kPartitionAttrName && op->getNumRegions() != 0) {
3752+
SetVector<int> expectedIds;
3753+
for (auto &region : op->getRegions()) {
3754+
for (auto &block : region.getBlocks()) {
3755+
for (auto &childOp : block.getOperations()) {
3756+
if (isa<scf::YieldOp, ub::PoisonOp>(childOp)) {
3757+
// yield ops and ub.poison do not need partition ids
3758+
continue;
3759+
}
3760+
if (!childOp.hasAttr(kPartitionAttrName))
3761+
return childOp.emitOpError("does not have expected attribute ")
3762+
<< kPartitionAttrName
3763+
<< " which is expected for ops whose parent has partitions";
3764+
auto ids = *getPartitionIds(&childOp);
3765+
expectedIds.insert(ids.begin(), ids.end());
3766+
}
3767+
}
3768+
}
3769+
auto partitionIds = *getPartitionIds(op);
3770+
for (auto id : expectedIds) {
3771+
if (!partitionIds.contains(id)) {
3772+
return op->emitOpError("partition ids in attr ")
3773+
<< attr.getName()
3774+
<< " does not contain partition ids of all child ops";
3775+
}
3776+
}
3777+
}
3778+
3779+
if (attr.getName() == kPartitionOutputsAttrName) {
3780+
if (!isa<scf::ForOp, scf::IfOp, triton::ReduceOp>(op))
3781+
return op->emitOpError("has unexpected attribute ") << attr.getName();
3782+
3783+
// Verify that number of output partitions matches number of For/If results
3784+
size_t numResults = 0;
3785+
if (isa<scf::ForOp>(op)) {
3786+
numResults = cast<scf::ForOp>(op).getResults().size();
3787+
} else if (isa<scf::IfOp>(op)) {
3788+
numResults = cast<scf::IfOp>(op).getResults().size();
3789+
} else {
3790+
numResults = cast<triton::ReduceOp>(op).getResults().size();
3791+
}
3792+
3793+
if (cast<ArrayAttr>(attr.getValue()).size() != numResults) {
3794+
return op->emitOpError("does not have expected number of output "
3795+
"partition sets in attr ")
3796+
<< attr.getName() << "; should match number of results";
3797+
}
3798+
3799+
// Verify that union of op output partitions is a subset of op partitions
3800+
if (!op->hasAttr(kPartitionAttrName))
3801+
return op->emitOpError("does not have expected attribute ")
3802+
<< kPartitionAttrName << " which is expected for ops with attr "
3803+
<< kPartitionOutputsAttrName;
3804+
auto partitionIds = *getPartitionIds(op);
3805+
3806+
SetVector<int> outputPartitionIdsUnion;
3807+
for (auto idx = 0; idx < *getNumOutputPartitionIds(op); idx++) {
3808+
auto outputPartitionIds = getOutputPartitionIds(op, idx);
3809+
for (auto partitionId : *outputPartitionIds)
3810+
outputPartitionIdsUnion.insert(partitionId);
3811+
}
3812+
if (!std::all_of(outputPartitionIdsUnion.begin(),
3813+
outputPartitionIdsUnion.end(),
3814+
[&](int id) { return partitionIds.contains(id); })) {
3815+
return op->emitOpError("partition ids in attr ")
3816+
<< kPartitionAttrName
3817+
<< " must be the union of all partition ids in " << attr.getName();
3818+
}
3819+
}
3820+
36793821
return success();
36803822
}
36813823

@@ -3776,3 +3918,51 @@ LinearLayout triton::gpu::inferReshapeLinearLayout(TensorOrMemDesc srcTy,
37763918
auto dst = reshapeLayout(ctx, src, dstShape);
37773919
return dst;
37783920
}
3921+
3922+
// Returns the partition ids stored under `kPartitionAttrName` on `op`, in
// ascending order with duplicates collapsed.
//
// Yields std::nullopt when `op` is null or does not carry the attribute. The
// attribute, when present, is cast to DenseI32ArrayAttr unconditionally.
std::optional<SetVector<int>> triton::gpu::getPartitionIds(Operation *op) {
  if (!op)
    return std::nullopt;

  if (auto partitionAttr = op->getAttr(kPartitionAttrName)) {
    auto rawIds = cast<DenseI32ArrayAttr>(partitionAttr).asArrayRef();
    // Sort a private copy first so the SetVector's insertion order is the
    // ascending order of the ids.
    SmallVector<int> orderedIds(rawIds.begin(), rawIds.end());
    std::sort(orderedIds.begin(), orderedIds.end());
    return SetVector<int>(orderedIds.begin(), orderedIds.end());
  }
  return std::nullopt;
}
3938+
3939+
// Returns how many entries the `kPartitionOutputsAttrName` array attribute on
// `op` holds, or std::nullopt when `op` is null or lacks that attribute.
std::optional<int> triton::gpu::getNumOutputPartitionIds(Operation *op) {
  if (op) {
    if (auto outputsAttr = op->getAttr(kPartitionOutputsAttrName))
      return static_cast<int>(cast<ArrayAttr>(outputsAttr).size());
  }
  return std::nullopt;
}
3949+
3950+
// Returns the partition ids recorded for output number `idx` in the
// `kPartitionOutputsAttrName` attribute of `op`, in ascending order with
// duplicates collapsed.
//
// Yields std::nullopt when `op` is null or lacks the attribute. `idx` must be
// smaller than the number of entries in the attribute (asserted).
std::optional<SetVector<int>> triton::gpu::getOutputPartitionIds(Operation *op,
                                                                 int idx) {
  if (!op)
    return std::nullopt;

  auto outputsAttr = op->getAttr(kPartitionOutputsAttrName);
  if (!outputsAttr)
    return std::nullopt;

  auto perOutput = cast<ArrayAttr>(outputsAttr);
  assert(idx < perOutput.size());

  auto rawIds = cast<DenseI32ArrayAttr>(perOutput[idx]).asArrayRef();
  // Sort a private copy first so the SetVector's insertion order is the
  // ascending order of the ids.
  SmallVector<int> orderedIds(rawIds.begin(), rawIds.end());
  std::sort(orderedIds.begin(), orderedIds.end());
  return SetVector<int>(orderedIds.begin(), orderedIds.end());
}

lib/Dialect/TritonGPU/Transforms/WarpSpecialization/Partition.cpp

Lines changed: 41 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,9 @@ void Partition::iterateOutputs(
5656
for (Operation *op : getOps()) {
5757
for (OpOperand &use : op->getUses()) {
5858
Operation *owner = loop.getBody()->findAncestorOpInBlock(*use.getOwner());
59+
if (!owner) {
60+
continue;
61+
}
5962
auto partitionIds = getPartitionIds(owner);
6063
if (isa<scf::YieldOp>(owner)) {
6164
// This value is used in a subsequent iteration.
@@ -88,6 +91,9 @@ void Partition::iterateUses(
8891
while (!uses.empty()) {
8992
auto [output, use, distance] = uses.pop_back_val();
9093
Operation *owner = loop.getBody()->findAncestorOpInBlock(*use->getOwner());
94+
if (!owner) {
95+
continue;
96+
}
9197
if (!isa<scf::YieldOp>(owner)) {
9298
callback(output, *use, distance);
9399
continue;
@@ -179,7 +185,31 @@ namespace mlir::triton::gpu {
179185

180186
// Annotates `op` with `partitionIds` under `kPartitionAttrName`.
//
// The ids are stored in ascending order. The same attribute is also attached
// to the terminator of every block in every region of `op`.
void setPartition(Operation *op, ArrayRef<int> partitionIds) {
  auto ordered = llvm::to_vector(partitionIds);
  llvm::sort(ordered);

  // Build the attribute once; the op and all of its block terminators share it.
  Builder builder(op->getContext());
  auto idsAttr = builder.getDenseI32ArrayAttr(ordered);

  op->setAttr(kPartitionAttrName, idsAttr);
  for (auto &region : op->getRegions())
    for (auto &block : region.getBlocks())
      block.getTerminator()->setAttr(kPartitionAttrName, idsAttr);
}
198+
199+
// Stores one partition-id list per entry of `partitionOutputsIds` on `op`
// under `kPartitionOutputsAttrName`, as an array attribute whose elements are
// dense i32 arrays sorted in ascending order.
//
// An empty `partitionOutputsIds` removes the attribute from `op` instead.
void setPartitionOutputs(Operation *op,
                         ArrayRef<SetVector<int>> partitionOutputsIds) {
  if (partitionOutputsIds.empty()) {
    op->removeAttr(kPartitionOutputsAttrName);
    return;
  }

  Builder b(op->getContext());
  SmallVector<Attribute> attrs;
  attrs.reserve(partitionOutputsIds.size());
  // Iterate by const reference: taking each SetVector by value would deep-copy
  // it on every iteration, only for it to be copied again into `sorted`.
  for (const SetVector<int> &partitionIds : partitionOutputsIds) {
    auto sorted = llvm::to_vector(partitionIds);
    llvm::sort(sorted);
    attrs.push_back(b.getDenseI32ArrayAttr(sorted));
  }
  op->setAttr(kPartitionOutputsAttrName, b.getArrayAttr(attrs));
}
184214

185215
void setPartition(Operation *op, const SetVector<int> &partitionIds) {
@@ -202,22 +232,21 @@ void setPartition(Operation *op, const SetVector<Partition *> &partitions) {
202232
setPartition(op, partitionIds);
203233
}
204234

205-
std::optional<SetVector<int>> getPartitionIds(Operation *op) {
235+
// Reads the `kPartitionOutputsAttrName` attribute of `op` and returns one
// SetVector of partition ids per array entry, with ids in the order they are
// stored in the attribute.
//
// Returns an empty vector when `op` is null or lacks the attribute.
SmallVector<SetVector<int>, 4> getPartitionOutputs(Operation *op) {
  SmallVector<SetVector<int>, 4> outputs;
  if (!op)
    return outputs;

  auto outputsAttr = op->getAttr(kPartitionOutputsAttrName);
  if (!outputsAttr)
    return outputs;

  for (auto entry : cast<ArrayAttr>(outputsAttr)) {
    auto storedIds = cast<DenseI32ArrayAttr>(entry).asArrayRef();
    outputs.emplace_back(storedIds.begin(), storedIds.end());
  }
  return outputs;
}
222251

223252
// True when getPartitionIds(op) yields a value, false when it yields
// std::nullopt.
bool hasPartition(Operation *op) { return getPartitionIds(op).has_value(); }

0 commit comments

Comments
 (0)