[SCF] Add interface methods to ParallelCombiningOp for promotion #159840
@llvm/pr-subscribers-mlir-tensor @llvm/pr-subscribers-mlir

Author: Alan Li (lialan)

Changes
This patch adds interface methods for the optimizer to promote ops.
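In brief, the new hooks let the promotion driver treat every combining op in the scf.forall.in_parallel terminator generically instead of hard-coding tensor.parallel_insert_slice. A condensed sketch of the driver logic, reusing the names from the SCF.cpp hunk below:

```cpp
// Condensed sketch of the generic driver inside scf::promote; see the full
// SCF.cpp hunk below. `terminator` is the scf::InParallelOp terminator and
// `forallOp` is the loop being promoted.
// 1) Pre-flight: verify every combining op is promotable before mutating IR.
for (Operation &yieldingOp : terminator.getYieldingOps())
  if (auto op = dyn_cast<ParallelCombiningOpInterface>(&yieldingOp))
    if (!op.canPromoteInParallelLoop(rewriter))
      return rewriter.notifyMatchFailure(
          forallOp, "parallel combining op cannot be promoted");
// 2) Rewrite: each combining op materializes the value that replaces the
//    destination it was updating.
SmallVector<Value> results;
for (Operation &yieldingOp : terminator.getYieldingOps())
  if (auto op = dyn_cast<ParallelCombiningOpInterface>(&yieldingOp)) {
    FailureOr<Value> replacement = op.promoteInParallelLoop(rewriter);
    if (failed(replacement))
      return failure();
    results.push_back(*replacement);
  }
rewriter.replaceAllUsesWith(forallOp.getResults(), results);
```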
Full diff: https://github.com/llvm/llvm-project/pull/159840.diff

6 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCF.h b/mlir/include/mlir/Dialect/SCF/IR/SCF.h
index ba648181daecb..830b49321c2e4 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCF.h
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCF.h
@@ -58,7 +58,7 @@ ForallOp getForallOpThreadIndexOwner(Value val);
bool insideMutuallyExclusiveBranches(Operation *a, Operation *b);
/// Promotes the loop body of a scf::ForallOp to its containing block.
-void promote(RewriterBase &rewriter, scf::ForallOp forallOp);
+LogicalResult promote(RewriterBase &rewriter, scf::ForallOp forallOp);
/// An owning vector of values, handy to return from functions.
using ValueVector = SmallVector<Value>;
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 2453cf5b5b5a4..4fb4cc8410230 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1474,7 +1474,8 @@ def Tensor_ParallelInsertSliceOp : Tensor_Op<"parallel_insert_slice", [
AttrSizedOperandSegments,
OffsetSizeAndStrideOpInterface,
DeclareOpInterfaceMethods<ParallelCombiningOpInterface,
- ["getUpdatedDestinations", "getIteratingParent"]>,
+ ["getUpdatedDestinations", "getIteratingParent",
+ "promoteInParallelLoop", "canPromoteInParallelLoop"]>,
// TODO: Cannot use an interface here atm, verify this manually for now.
// HasParent<"InParallelOpInterface">
]> {
diff --git a/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.h b/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.h
index 82ab427699f64..85cc18c47a527 100644
--- a/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.h
+++ b/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.h
@@ -15,6 +15,8 @@
#define MLIR_INTERFACES_PARALLELCOMBININGOPINTERFACE_H_
#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Support/LogicalResult.h"
namespace mlir {
namespace detail {
diff --git a/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.td b/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.td
index ace26f723ef53..1a333d82d8468 100644
--- a/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.td
+++ b/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.td
@@ -106,6 +106,27 @@ def ParallelCombiningOpInterface : OpInterface<"ParallelCombiningOpInterface"> {
/*methodName=*/"getIteratingParent",
/*args=*/(ins)
>,
+ InterfaceMethod<
+ /*desc=*/[{
+ Promotes this parallel combining op out of its enclosing parallel loop
+ and returns the value that should replace the destination updated by
+ this op.
+ }],
+ /*retTy=*/"::mlir::FailureOr<::mlir::Value>",
+ /*methodName=*/"promoteInParallelLoop",
+ /*args=*/(ins "::mlir::RewriterBase &":$rewriter)
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Returns true if this op can be promoted out of its enclosing parallel
+ loop.
+ }],
+ /*retTy=*/"bool",
+ /*methodName=*/"canPromoteInParallelLoop",
+ /*args=*/(ins "::mlir::RewriterBase &":$rewriter),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{ return false; }]
+ >,
];
}
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index c35989ecba6cd..04737738d8593 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -651,8 +651,7 @@ LogicalResult scf::ForallOp::promoteIfSingleIteration(RewriterBase &rewriter) {
return failure();
}
- promote(rewriter, *this);
- return success();
+ return promote(rewriter, *this);
}
Block::BlockArgListType ForallOp::getRegionIterArgs() {
@@ -664,10 +663,21 @@ MutableArrayRef<OpOperand> ForallOp::getInitsMutable() {
}
/// Promotes the loop body of a scf::ForallOp to its containing block.
-void mlir::scf::promote(RewriterBase &rewriter, scf::ForallOp forallOp) {
+LogicalResult mlir::scf::promote(RewriterBase &rewriter, scf::ForallOp forallOp) {
OpBuilder::InsertionGuard g(rewriter);
scf::InParallelOp terminator = forallOp.getTerminator();
+ // Make sure we can promote all parallel combining ops in terminator:
+ for (auto &yieldingOp : terminator.getYieldingOps()) {
+ auto parallelCombiningOp =
+ dyn_cast<ParallelCombiningOpInterface>(&yieldingOp);
+ if (!parallelCombiningOp)
+ continue;
+ if (!parallelCombiningOp.canPromoteInParallelLoop(rewriter))
+ return rewriter.notifyMatchFailure(
+ forallOp, "parallel combining op cannot be promoted");
+ }
+
// Replace block arguments with lower bounds (replacements for IVs) and
// outputs.
SmallVector<Value> bbArgReplacements = forallOp.getLowerBound(rewriter);
@@ -683,30 +693,29 @@ void mlir::scf::promote(RewriterBase &rewriter, scf::ForallOp forallOp) {
SmallVector<Value> results;
results.reserve(forallOp.getResults().size());
for (auto &yieldingOp : terminator.getYieldingOps()) {
- auto parallelInsertSliceOp =
- dyn_cast<tensor::ParallelInsertSliceOp>(yieldingOp);
- if (!parallelInsertSliceOp)
+ auto parallelCombiningOp =
+ dyn_cast<ParallelCombiningOpInterface>(&yieldingOp);
+ if (!parallelCombiningOp)
continue;
- Value dst = parallelInsertSliceOp.getDest();
- Value src = parallelInsertSliceOp.getSource();
- if (llvm::isa<TensorType>(src.getType())) {
- results.push_back(tensor::InsertSliceOp::create(
- rewriter, forallOp.getLoc(), dst.getType(), src, dst,
- parallelInsertSliceOp.getOffsets(), parallelInsertSliceOp.getSizes(),
- parallelInsertSliceOp.getStrides(),
- parallelInsertSliceOp.getStaticOffsets(),
- parallelInsertSliceOp.getStaticSizes(),
- parallelInsertSliceOp.getStaticStrides()));
- } else {
- llvm_unreachable("unsupported terminator");
- }
+ assert(parallelCombiningOp.canPromoteInParallelLoop(rewriter));
+
+ FailureOr<Value> promotedValue =
+ parallelCombiningOp.promoteInParallelLoop(rewriter);
+ if (failed(promotedValue))
+ return failure();
+
+ results.push_back(*promotedValue);
}
+ if (results.size() != forallOp.getResults().size())
+ return rewriter.notifyMatchFailure(
+ forallOp, "failed to materialize replacements for all results");
rewriter.replaceAllUsesWith(forallOp.getResults(), results);
// Erase the old terminator and the loop.
rewriter.eraseOp(terminator);
rewriter.eraseOp(forallOp);
+ return success();
}
LoopNest mlir::scf::buildLoopNest(
@@ -1789,7 +1798,8 @@ struct ForallOpSingleOrZeroIterationDimsFolder
// All of the loop dimensions perform a single iteration. Inline loop body.
if (newMixedLowerBounds.empty()) {
- promote(rewriter, op);
+ if (failed(promote(rewriter, op)))
+ return failure();
return success();
}
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index fa97b49a41d97..f05c58a40fde0 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -3947,6 +3947,25 @@ Operation *ParallelInsertSliceOp::getIteratingParent() {
return nullptr;
}
+FailureOr<Value>
+ParallelInsertSliceOp::promoteInParallelLoop(RewriterBase &rewriter) {
+ Value dst = getDest();
+ Value src = getSource();
+ if (!isa<TensorType>(src.getType()))
+ return failure();
+
+ Value inserted = tensor::InsertSliceOp::create(
+ rewriter, getLoc(), dst.getType(), src, dst, getOffsets(), getSizes(),
+ getStrides(), getStaticOffsets(), getStaticSizes(), getStaticStrides());
+
+ return inserted;
+}
+
+bool ParallelInsertSliceOp::canPromoteInParallelLoop(RewriterBase &) {
+ return isa<TensorType>(getSource().getType()) &&
+ isa<TensorType>(getDest().getType());
+}
+
//===----------------------------------------------------------------------===//
// ScatterOp
//===----------------------------------------------------------------------===//
Pull Request Overview
This PR adds interface methods to ParallelCombiningOp to enable promotion of operations out of parallel loops. The changes introduce two new interface methods that allow the optimizer to make decisions about and perform promotion in trivial iteration cases.
- Adds `canPromoteInParallelLoop` and `promoteInParallelLoop` interface methods
- Implements these methods for `ParallelInsertSliceOp`
- Updates the promotion logic to use the new interface methods instead of hardcoded behavior (see the caller-side sketch after this list)
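Because scf::promote now returns LogicalResult, call sites must propagate failure instead of assuming promotion always succeeds. A minimal caller-side sketch: the pattern itself is hypothetical, only the failure-aware signatures come from this PR:

```cpp
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// Hypothetical pattern that inlines single-iteration scf.forall loops.
struct PromoteSingleIterationForall : OpRewritePattern<scf::ForallOp> {
  using OpRewritePattern<scf::ForallOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(scf::ForallOp op,
                                PatternRewriter &rewriter) const override {
    // promoteIfSingleIteration bails out unless every dimension performs a
    // single iteration, then delegates to scf::promote, whose failures
    // (e.g. an unpromotable combining op) now propagate here.
    return op.promoteIfSingleIteration(rewriter);
  }
};
```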
Reviewed Changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.td | Adds new interface methods for promotion support |
| mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.h | Adds required header includes for the new interface methods |
| mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td | Updates ParallelInsertSliceOp to declare the new interface methods |
| mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | Implements the new interface methods for ParallelInsertSliceOp |
| mlir/include/mlir/Dialect/SCF/IR/SCF.h | Changes promote function signature to return LogicalResult |
| mlir/lib/Dialect/SCF/IR/SCF.cpp | Refactors promotion logic to use new interface and handle failures |
On the new check in SCF.cpp:

if (results.size() != forallOp.getResults().size())
  return rewriter.notifyMatchFailure(
      forallOp, "failed to materialize replacements for all results");
Copilot AI (Sep 19, 2025):
This check may incorrectly fail when there are non-ParallelCombiningOpInterface operations in the terminator that don't contribute to results. The logic should only count operations that implement ParallelCombiningOpInterface when comparing against forallOp.getResults().size().
Suggested change:

-if (results.size() != forallOp.getResults().size())
-  return rewriter.notifyMatchFailure(
-      forallOp, "failed to materialize replacements for all results");
+// Only count yielding ops that implement ParallelCombiningOpInterface.
+size_t numParallelCombiningOps = llvm::count_if(
+    terminator.getYieldingOps(), [](Operation &op) {
+      return isa<ParallelCombiningOpInterface>(&op);
+    });
+if (results.size() != numParallelCombiningOps)
+  return rewriter.notifyMatchFailure(
+      forallOp, "failed to materialize replacements for all parallel combining ops");
You can test this locally with the following command:

git-clang-format --diff origin/main HEAD --extensions h,cpp -- mlir/include/mlir/Dialect/SCF/IR/SCF.h mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.h mlir/lib/Dialect/SCF/IR/SCF.cpp mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

View the diff from clang-format here:

diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 455f40eb8..a34e427bb 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -695,7 +695,8 @@ MutableArrayRef<OpOperand> ForallOp::getInitsMutable() {
}
/// Promotes the loop body of a scf::ForallOp to its containing block.
-LogicalResult mlir::scf::promote(RewriterBase &rewriter, scf::ForallOp forallOp) {
+LogicalResult mlir::scf::promote(RewriterBase &rewriter,
+ scf::ForallOp forallOp) {
OpBuilder::InsertionGuard g(rewriter);
scf::InParallelOp terminator = forallOp.getTerminator();
`ParallelCombiningOp` adds extensibility to the parallel insertion performed by an `scf.forall.in_parallel` op. This patch adds interface methods for the optimizer to promote such ops:

- `canPromoteInParallelLoop` decides whether we can fold/promote in trivial iteration cases.
- `promoteInParallelLoop` does the actual work.
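For a downstream dialect, wiring a custom combining op into this mechanism means implementing the two hooks, much as tensor.parallel_insert_slice does in this patch. A sketch under stated assumptions: MyCombineOp, its accessors, and the createCombinedValue helper are hypothetical; only the hook signatures come from ParallelCombiningOpInterface.td:

```cpp
// Hypothetical downstream op implementing the new hooks. Only the method
// signatures come from this patch; everything else is illustrative.
bool MyCombineOp::canPromoteInParallelLoop(RewriterBase &) {
  // This sketch only supports tensor-typed combines.
  return isa<TensorType>(getSource().getType());
}

FailureOr<Value> MyCombineOp::promoteInParallelLoop(RewriterBase &rewriter) {
  if (!canPromoteInParallelLoop(rewriter))
    return failure();
  // Recreate the combining semantics as a standalone op at the rewriter's
  // current insertion point (now outside the loop), and return the value
  // that replaces the destination this op was updating.
  // createCombinedValue is a placeholder for op-specific construction.
  Value combined =
      createCombinedValue(rewriter, getLoc(), getSource(), getDest());
  return combined;
}
```

The driver in scf::promote then substitutes the returned value for the corresponding scf.forall result, exactly as the SCF.cpp hunk above does for tensor.parallel_insert_slice.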