@@ -3814,12 +3814,12 @@ LogicalResult TritonGPUDialect::verifyOperationAttribute(Operation *op,
38143814 return childOp.emitOpError (" does not have expected attribute " )
38153815 << kPartitionAttrName
38163816 << " which is expected for ops whose parent has partitions" ;
3817- auto ids = * getPartitionIds (&childOp);
3817+ auto ids = getPartitionIds (&childOp);
38183818 expectedIds.insert (ids.begin (), ids.end ());
38193819 }
38203820 }
38213821 }
3822- auto partitionIds = * getPartitionIds (op);
3822+ auto partitionIds = getPartitionIds (op);
38233823 for (auto id : expectedIds) {
38243824 if (!partitionIds.contains (id)) {
38253825 return op->emitOpError (" partition ids in attr " )
@@ -3854,13 +3854,12 @@ LogicalResult TritonGPUDialect::verifyOperationAttribute(Operation *op,
38543854 return op->emitOpError (" does not have expected attribute " )
38553855 << kPartitionAttrName << " which is expected for ops with attr "
38563856 << kPartitionOutputsAttrName ;
3857- auto partitionIds = * getPartitionIds (op);
3857+ auto partitionIds = getPartitionIds (op);
38583858
38593859 SetVector<int > outputPartitionIdsUnion;
3860- for (auto idx = 0 ; idx < *getNumOutputPartitionIds (op); idx++) {
3861- auto outputPartitionIds = getOutputPartitionIds (op, idx);
3862- for (auto partitionId : *outputPartitionIds)
3863- outputPartitionIdsUnion.insert (partitionId);
3860+ for (auto outputPartitionIds : getPartitionOutputs (op)) {
3861+ outputPartitionIdsUnion.insert (outputPartitionIds.begin (),
3862+ outputPartitionIds.end ());
38643863 }
38653864 if (!std::all_of (outputPartitionIdsUnion.begin (),
38663865 outputPartitionIdsUnion.end (),
@@ -3972,15 +3971,8 @@ LinearLayout triton::gpu::inferReshapeLinearLayout(TensorOrMemDesc srcTy,
39723971 return dst;
39733972}
39743973
3975- std::optional<SetVector<int >> triton::gpu::getPartitionIds (Operation *op) {
3976- if (!op) {
3977- return std::nullopt ;
3978- }
3974+ SetVector<int > triton::gpu::getPartitionIds (Operation *op) {
39793975 auto attrs = op->getAttr (kPartitionAttrName );
3980- if (!attrs) {
3981- return std::nullopt ;
3982- }
3983-
39843976 SmallVector<int > partitionIds;
39853977 for (auto id : cast<DenseI32ArrayAttr>(attrs).asArrayRef ()) {
39863978 partitionIds.push_back (id);
@@ -3989,33 +3981,31 @@ std::optional<SetVector<int>> triton::gpu::getPartitionIds(Operation *op) {
39893981 return SetVector<int >(partitionIds.begin (), partitionIds.end ());
39903982}
39913983
3992- std::optional<int > triton::gpu::getNumOutputPartitionIds (Operation *op) {
3993- if (!op) {
3994- return std::nullopt ;
3984+ SmallVector<SetVector<int >, 4 > triton::gpu::getPartitionOutputs (Operation *op) {
3985+ SmallVector<SetVector<int >, 4 > partitionOutputsIds;
3986+ if (op->getNumResults () == 0 ) {
3987+ return partitionOutputsIds;
39953988 }
3996- auto attr = op->getAttr (kPartitionOutputsAttrName );
3997- if (!attr) {
3998- return std::nullopt ;
3989+ auto arrayAttr = cast<ArrayAttr>(op->getAttr (kPartitionOutputsAttrName ));
3990+ for (auto attr : arrayAttr) {
3991+ auto ids = cast<DenseI32ArrayAttr>(attr).asArrayRef ();
3992+ partitionOutputsIds.push_back (SetVector<int >(ids.begin (), ids.end ()));
39993993 }
4000- return cast<ArrayAttr>(attr). size () ;
3994+ return partitionOutputsIds ;
40013995}
40023996
4003- std::optional<SetVector<int >> triton::gpu::getOutputPartitionIds (Operation *op,
4004- int idx) {
4005- if (!op) {
4006- return std::nullopt ;
4007- }
4008- auto attr = op->getAttr (kPartitionOutputsAttrName );
4009- if (!attr) {
4010- return std::nullopt ;
3997+ SetVector<int > triton::gpu::getPartitionIds (OpOperand *use) {
3998+ auto owner = use->getOwner ();
3999+ if (isa<scf::YieldOp>(owner)) {
4000+ return getPartitionOutputs (owner->getParentOp ())[use->getOperandNumber ()];
4001+ } else if (scf::ForOp forOp = dyn_cast<scf::ForOp>(owner)) {
4002+ int idx = use->getOperandNumber () - forOp.getNumControlOperands ();
4003+ return idx >= 0 ? getPartitionOutputs (owner)[idx] : getPartitionIds (forOp);
4004+ } else {
4005+ return getPartitionIds (owner);
40114006 }
4012- assert (idx < cast<ArrayAttr>(attr).size ());
4013- auto attrs = cast<ArrayAttr>(attr)[idx];
4007+ }
40144008
4015- SmallVector<int > partitionIds;
4016- for (auto id : cast<DenseI32ArrayAttr>(attrs).asArrayRef ()) {
4017- partitionIds.push_back (id);
4018- }
4019- std::sort (partitionIds.begin (), partitionIds.end ());
4020- return SetVector<int >(partitionIds.begin (), partitionIds.end ());
4009+ bool triton::gpu::hasPartition (Operation *op) {
4010+ return op && op->hasAttr (kPartitionAttrName );
40214011}
0 commit comments