|
3 | 3 | #include <cstdint> |
4 | 4 | #include <numeric> |
5 | 5 |
|
| 6 | +#include "mlir/Dialect/UB/IR/UBOps.h" |
6 | 7 | #include "mlir/IR/DialectImplementation.h" |
7 | 8 | #include "mlir/IR/OpImplementation.h" |
8 | 9 | #include "mlir/IR/OperationSupport.h" |
@@ -573,6 +574,17 @@ static LogicalResult parseBool(AsmParser &parser, const NamedAttribute &attr, |
573 | 574 | return parseBoolAttrValue(parser, attr.getValue(), value, desc); |
574 | 575 | }; |
575 | 576 |
|
| 577 | +static LogicalResult parseType(AsmParser &parser, const NamedAttribute &attr, |
| 578 | + Type &value, StringRef desc) { |
| 579 | + auto typeAttr = mlir::dyn_cast<TypeAttr>(attr.getValue()); |
| 580 | + if (!typeAttr) { |
| 581 | + parser.emitError(parser.getNameLoc(), "expected a Type in ") << desc; |
| 582 | + return failure(); |
| 583 | + } |
| 584 | + value = typeAttr.getValue(); |
| 585 | + return success(); |
| 586 | +} |
| 587 | + |
576 | 588 | std::optional<LinearLayout> |
577 | 589 | parseLinearLayout(const DictionaryAttr &dict, AsmParser &parser, |
578 | 590 | ArrayRef<std::string> inDimNames) { |
@@ -3676,6 +3688,136 @@ LogicalResult TritonGPUDialect::verifyOperationAttribute(Operation *op, |
3676 | 3688 | << " which is expected only on `module` or `tt.func` ops"; |
3677 | 3689 | } |
3678 | 3690 |
|
| 3691 | + // Verify that all ops in a tt.warp_specialize op have partition ids |
| 3692 | + if (attr.getName() == "tt.warp_specialize") { |
| 3693 | + if (!isa<scf::ForOp>(op)) { |
| 3694 | + return op->emitOpError("has unexpected attribute ") |
| 3695 | + << attr.getName() << " which is expected only on `scf.for` ops"; |
| 3696 | + } |
| 3697 | + Operation *failedOp = nullptr; |
| 3698 | + op->walk([&](Operation *childOp) { |
| 3699 | + if (!childOp->hasAttr(kPartitionAttrName)) { |
| 3700 | + failedOp = childOp; |
| 3701 | + WalkResult::interrupt(); |
| 3702 | + } |
| 3703 | + }); |
| 3704 | + if (failedOp) { |
| 3705 | + return failedOp->emitOpError("does not have expected attribute ") |
| 3706 | + << kPartitionAttrName |
| 3707 | + << " which is expected on all child ops of an op with " |
| 3708 | + "attribute `tt.warp_specialize`"; |
| 3709 | + } |
| 3710 | + } |
| 3711 | + |
| 3712 | + // Verify that partition id lists are non-empty, sorted and have no duplicates |
| 3713 | + auto verifyPartitionIds = |
| 3714 | + [&](const ArrayRef<int> &partitionIds) -> LogicalResult { |
| 3715 | + SetVector<int> idSet; |
| 3716 | + for (auto id : partitionIds) { |
| 3717 | + if (idSet.contains(id)) |
| 3718 | + return op->emitOpError("has duplicated partition ids in attribute ") |
| 3719 | + << attr.getName(); |
| 3720 | + idSet.insert(id); |
| 3721 | + } |
| 3722 | + if (idSet.empty()) |
| 3723 | + return op->emitOpError("has no partition ids in attribute ") |
| 3724 | + << attr.getName(); |
| 3725 | + auto ids = idSet.takeVector(); |
| 3726 | + SmallVector<int> sortedIds(ids.begin(), ids.end()); |
| 3727 | + std::sort(sortedIds.begin(), sortedIds.end()); |
| 3728 | + if (ids != sortedIds) |
| 3729 | + return op->emitOpError("partition ids not in sorted order in attribute ") |
| 3730 | + << attr.getName(); |
| 3731 | + return success(); |
| 3732 | + }; |
| 3733 | + |
| 3734 | + if (attr.getName() == kPartitionAttrName) { |
| 3735 | + auto result = verifyPartitionIds( |
| 3736 | + cast<DenseI32ArrayAttr>(attr.getValue()).asArrayRef()); |
| 3737 | + if (failed(result)) |
| 3738 | + return result; |
| 3739 | + } |
| 3740 | + if (attr.getName() == kPartitionOutputsAttrName) { |
| 3741 | + auto arrayAttr = cast<ArrayAttr>(attr.getValue()); |
| 3742 | + for (auto idx = 0; idx < arrayAttr.size(); idx++) { |
| 3743 | + auto result = verifyPartitionIds( |
| 3744 | + cast<DenseI32ArrayAttr>(arrayAttr[idx]).asArrayRef()); |
| 3745 | + if (failed(result)) |
| 3746 | + return result; |
| 3747 | + } |
| 3748 | + } |
| 3749 | + |
| 3750 | + // Verify that op partitions include partitions of all child ops |
| 3751 | + if (attr.getName() == kPartitionAttrName && op->getNumRegions() != 0) { |
| 3752 | + SetVector<int> expectedIds; |
| 3753 | + for (auto ®ion : op->getRegions()) { |
| 3754 | + for (auto &block : region.getBlocks()) { |
| 3755 | + for (auto &childOp : block.getOperations()) { |
| 3756 | + if (isa<scf::YieldOp, ub::PoisonOp>(childOp)) { |
| 3757 | + // yield ops and ub.poison do not need partition ids |
| 3758 | + continue; |
| 3759 | + } |
| 3760 | + if (!childOp.hasAttr(kPartitionAttrName)) |
| 3761 | + return childOp.emitOpError("does not have expected attribute ") |
| 3762 | + << kPartitionAttrName |
| 3763 | + << " which is expected for ops whose parent has partitions"; |
| 3764 | + auto ids = *getPartitionIds(&childOp); |
| 3765 | + expectedIds.insert(ids.begin(), ids.end()); |
| 3766 | + } |
| 3767 | + } |
| 3768 | + } |
| 3769 | + auto partitionIds = *getPartitionIds(op); |
| 3770 | + for (auto id : expectedIds) { |
| 3771 | + if (!partitionIds.contains(id)) { |
| 3772 | + return op->emitOpError("partition ids in attr ") |
| 3773 | + << attr.getName() |
| 3774 | + << " does not contain partition ids of all child ops"; |
| 3775 | + } |
| 3776 | + } |
| 3777 | + } |
| 3778 | + |
| 3779 | + if (attr.getName() == kPartitionOutputsAttrName) { |
| 3780 | + if (!isa<scf::ForOp, scf::IfOp, triton::ReduceOp>(op)) |
| 3781 | + return op->emitOpError("has unexpected attribute ") << attr.getName(); |
| 3782 | + |
| 3783 | + // Verify that number of output partitions matches number of For/If results |
| 3784 | + size_t numResults = 0; |
| 3785 | + if (isa<scf::ForOp>(op)) { |
| 3786 | + numResults = cast<scf::ForOp>(op).getResults().size(); |
| 3787 | + } else if (isa<scf::IfOp>(op)) { |
| 3788 | + numResults = cast<scf::IfOp>(op).getResults().size(); |
| 3789 | + } else { |
| 3790 | + numResults = cast<triton::ReduceOp>(op).getResults().size(); |
| 3791 | + } |
| 3792 | + |
| 3793 | + if (cast<ArrayAttr>(attr.getValue()).size() != numResults) { |
| 3794 | + return op->emitOpError("does not have expected number of output " |
| 3795 | + "partition sets in attr ") |
| 3796 | + << attr.getName() << "; should match number of results"; |
| 3797 | + } |
| 3798 | + |
| 3799 | + // Verify that union of op output partitions is a subset of op partitions |
| 3800 | + if (!op->hasAttr(kPartitionAttrName)) |
| 3801 | + return op->emitOpError("does not have expected attribute ") |
| 3802 | + << kPartitionAttrName << " which is expected for ops with attr " |
| 3803 | + << kPartitionOutputsAttrName; |
| 3804 | + auto partitionIds = *getPartitionIds(op); |
| 3805 | + |
| 3806 | + SetVector<int> outputPartitionIdsUnion; |
| 3807 | + for (auto idx = 0; idx < *getNumOutputPartitionIds(op); idx++) { |
| 3808 | + auto outputPartitionIds = getOutputPartitionIds(op, idx); |
| 3809 | + for (auto partitionId : *outputPartitionIds) |
| 3810 | + outputPartitionIdsUnion.insert(partitionId); |
| 3811 | + } |
| 3812 | + if (!std::all_of(outputPartitionIdsUnion.begin(), |
| 3813 | + outputPartitionIdsUnion.end(), |
| 3814 | + [&](int id) { return partitionIds.contains(id); })) { |
| 3815 | + return op->emitOpError("partition ids in attr ") |
| 3816 | + << kPartitionAttrName |
| 3817 | + << " must be the union of all partition ids in " << attr.getName(); |
| 3818 | + } |
| 3819 | + } |
| 3820 | + |
3679 | 3821 | return success(); |
3680 | 3822 | } |
3681 | 3823 |
|
@@ -3776,3 +3918,51 @@ LinearLayout triton::gpu::inferReshapeLinearLayout(TensorOrMemDesc srcTy, |
3776 | 3918 | auto dst = reshapeLayout(ctx, src, dstShape); |
3777 | 3919 | return dst; |
3778 | 3920 | } |
| 3921 | + |
| 3922 | +std::optional<SetVector<int>> triton::gpu::getPartitionIds(Operation *op) { |
| 3923 | + if (!op) { |
| 3924 | + return std::nullopt; |
| 3925 | + } |
| 3926 | + auto attrs = op->getAttr(kPartitionAttrName); |
| 3927 | + if (!attrs) { |
| 3928 | + return std::nullopt; |
| 3929 | + } |
| 3930 | + |
| 3931 | + SmallVector<int> partitionIds; |
| 3932 | + for (auto id : cast<DenseI32ArrayAttr>(attrs).asArrayRef()) { |
| 3933 | + partitionIds.push_back(id); |
| 3934 | + } |
| 3935 | + std::sort(partitionIds.begin(), partitionIds.end()); |
| 3936 | + return SetVector<int>(partitionIds.begin(), partitionIds.end()); |
| 3937 | +} |
| 3938 | + |
| 3939 | +std::optional<int> triton::gpu::getNumOutputPartitionIds(Operation *op) { |
| 3940 | + if (!op) { |
| 3941 | + return std::nullopt; |
| 3942 | + } |
| 3943 | + auto attr = op->getAttr(kPartitionOutputsAttrName); |
| 3944 | + if (!attr) { |
| 3945 | + return std::nullopt; |
| 3946 | + } |
| 3947 | + return cast<ArrayAttr>(attr).size(); |
| 3948 | +} |
| 3949 | + |
| 3950 | +std::optional<SetVector<int>> triton::gpu::getOutputPartitionIds(Operation *op, |
| 3951 | + int idx) { |
| 3952 | + if (!op) { |
| 3953 | + return std::nullopt; |
| 3954 | + } |
| 3955 | + auto attr = op->getAttr(kPartitionOutputsAttrName); |
| 3956 | + if (!attr) { |
| 3957 | + return std::nullopt; |
| 3958 | + } |
| 3959 | + assert(idx < cast<ArrayAttr>(attr).size()); |
| 3960 | + auto attrs = cast<ArrayAttr>(attr)[idx]; |
| 3961 | + |
| 3962 | + SmallVector<int> partitionIds; |
| 3963 | + for (auto id : cast<DenseI32ArrayAttr>(attrs).asArrayRef()) { |
| 3964 | + partitionIds.push_back(id); |
| 3965 | + } |
| 3966 | + std::sort(partitionIds.begin(), partitionIds.end()); |
| 3967 | + return SetVector<int>(partitionIds.begin(), partitionIds.end()); |
| 3968 | +} |
0 commit comments