Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mlir/include/mlir/Dialect/SCF/IR/SCF.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ ForallOp getForallOpThreadIndexOwner(Value val);
bool insideMutuallyExclusiveBranches(Operation *a, Operation *b);

/// Promotes the loop body of a scf::ForallOp to its containing block.
void promote(RewriterBase &rewriter, scf::ForallOp forallOp);
LogicalResult promote(RewriterBase &rewriter, scf::ForallOp forallOp);

/// An owning vector of values, handy to return from functions.
using ValueVector = SmallVector<Value>;
Expand Down
3 changes: 2 additions & 1 deletion mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1474,7 +1474,8 @@ def Tensor_ParallelInsertSliceOp : Tensor_Op<"parallel_insert_slice", [
AttrSizedOperandSegments,
OffsetSizeAndStrideOpInterface,
DeclareOpInterfaceMethods<ParallelCombiningOpInterface,
["getUpdatedDestinations", "getIteratingParent"]>,
["getUpdatedDestinations", "getIteratingParent",
"promoteInParallelLoop", "canPromoteInParallelLoop"]>,
// TODO: Cannot use an interface here atm, verify this manually for now.
// HasParent<"InParallelOpInterface">
]> {
Expand Down
2 changes: 2 additions & 0 deletions mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
#define MLIR_INTERFACES_PARALLELCOMBININGOPINTERFACE_H_

#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LogicalResult.h"

namespace mlir {
namespace detail {
Expand Down
21 changes: 21 additions & 0 deletions mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.td
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,27 @@ def ParallelCombiningOpInterface : OpInterface<"ParallelCombiningOpInterface"> {
/*methodName=*/"getIteratingParent",
/*args=*/(ins)
>,
InterfaceMethod<
/*desc=*/[{
Promotes this parallel combining op out of its enclosing parallel loop
and returns the value that should replace the destination updated by
this op.
}],
/*retTy=*/"::mlir::FailureOr<::mlir::Value>",
/*methodName=*/"promoteInParallelLoop",
/*args=*/(ins "::mlir::RewriterBase &":$rewriter)
>,
InterfaceMethod<
/*desc=*/[{
Returns true if this op can be promoted out of its enclosing parallel
loop.
}],
/*retTy=*/"bool",
/*methodName=*/"canPromoteInParallelLoop",
/*args=*/(ins "::mlir::RewriterBase &":$rewriter),
/*methodBody=*/"",
/*defaultImplementation=*/[{ return false; }]
>,
];
}

Expand Down
50 changes: 30 additions & 20 deletions mlir/lib/Dialect/SCF/IR/SCF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -651,8 +651,7 @@ LogicalResult scf::ForallOp::promoteIfSingleIteration(RewriterBase &rewriter) {
return failure();
}

promote(rewriter, *this);
return success();
return promote(rewriter, *this);
}

Block::BlockArgListType ForallOp::getRegionIterArgs() {
Expand All @@ -664,10 +663,21 @@ MutableArrayRef<OpOperand> ForallOp::getInitsMutable() {
}

/// Promotes the loop body of a scf::ForallOp to its containing block.
void mlir::scf::promote(RewriterBase &rewriter, scf::ForallOp forallOp) {
LogicalResult mlir::scf::promote(RewriterBase &rewriter, scf::ForallOp forallOp) {
OpBuilder::InsertionGuard g(rewriter);
scf::InParallelOp terminator = forallOp.getTerminator();

// Make sure we can promote all parallel combining ops in terminator:
for (auto &yieldingOp : terminator.getYieldingOps()) {
auto parallelCombiningOp =
dyn_cast<ParallelCombiningOpInterface>(&yieldingOp);
if (!parallelCombiningOp)
continue;
if (!parallelCombiningOp.canPromoteInParallelLoop(rewriter))
return rewriter.notifyMatchFailure(
forallOp, "parallel combining op cannot be promoted");
}

// Replace block arguments with lower bounds (replacements for IVs) and
// outputs.
SmallVector<Value> bbArgReplacements = forallOp.getLowerBound(rewriter);
Expand All @@ -683,30 +693,29 @@ void mlir::scf::promote(RewriterBase &rewriter, scf::ForallOp forallOp) {
SmallVector<Value> results;
results.reserve(forallOp.getResults().size());
for (auto &yieldingOp : terminator.getYieldingOps()) {
auto parallelInsertSliceOp =
dyn_cast<tensor::ParallelInsertSliceOp>(yieldingOp);
if (!parallelInsertSliceOp)
auto parallelCombiningOp =
dyn_cast<ParallelCombiningOpInterface>(&yieldingOp);
if (!parallelCombiningOp)
continue;

Value dst = parallelInsertSliceOp.getDest();
Value src = parallelInsertSliceOp.getSource();
if (llvm::isa<TensorType>(src.getType())) {
results.push_back(tensor::InsertSliceOp::create(
rewriter, forallOp.getLoc(), dst.getType(), src, dst,
parallelInsertSliceOp.getOffsets(), parallelInsertSliceOp.getSizes(),
parallelInsertSliceOp.getStrides(),
parallelInsertSliceOp.getStaticOffsets(),
parallelInsertSliceOp.getStaticSizes(),
parallelInsertSliceOp.getStaticStrides()));
} else {
llvm_unreachable("unsupported terminator");
}
assert(parallelCombiningOp.canPromoteInParallelLoop(rewriter));

FailureOr<Value> promotedValue =
parallelCombiningOp.promoteInParallelLoop(rewriter);
if (failed(promotedValue))
return failure();

results.push_back(*promotedValue);
}
if (results.size() != forallOp.getResults().size())
return rewriter.notifyMatchFailure(
forallOp, "failed to materialize replacements for all results");
Comment on lines +710 to +712
Copy link

Copilot AI Sep 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This check may incorrectly fail when there are non-ParallelCombiningOpInterface operations in the terminator that don't contribute to results. The logic should only count operations that implement ParallelCombiningOpInterface when comparing against forallOp.getResults().size().

Suggested change
if (results.size() != forallOp.getResults().size())
return rewriter.notifyMatchFailure(
forallOp, "failed to materialize replacements for all results");
// Only count yielding ops that implement ParallelCombiningOpInterface.
size_t numParallelCombiningOps = llvm::count_if(
terminator.getYieldingOps(), [](Operation &op) {
return isa<ParallelCombiningOpInterface>(&op);
});
if (results.size() != numParallelCombiningOps)
return rewriter.notifyMatchFailure(
forallOp, "failed to materialize replacements for all parallel combining ops");

Copilot uses AI. Check for mistakes.
rewriter.replaceAllUsesWith(forallOp.getResults(), results);

// Erase the old terminator and the loop.
rewriter.eraseOp(terminator);
rewriter.eraseOp(forallOp);
return success();
}

LoopNest mlir::scf::buildLoopNest(
Expand Down Expand Up @@ -1789,7 +1798,8 @@ struct ForallOpSingleOrZeroIterationDimsFolder

// All of the loop dimensions perform a single iteration. Inline loop body.
if (newMixedLowerBounds.empty()) {
promote(rewriter, op);
if (failed(promote(rewriter, op)))
return failure();
return success();
}

Expand Down
19 changes: 19 additions & 0 deletions mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3947,6 +3947,25 @@ Operation *ParallelInsertSliceOp::getIteratingParent() {
return nullptr;
}

FailureOr<Value>
ParallelInsertSliceOp::promoteInParallelLoop(RewriterBase &rewriter) {
  // Promotion materializes this combining op as a standalone
  // tensor.insert_slice; that is only possible for tensor-typed sources.
  Value source = getSource();
  if (!isa<TensorType>(source.getType()))
    return failure();

  Value destination = getDest();
  // Re-create the same insertion (identical offsets/sizes/strides, both
  // dynamic and static) outside the parallel loop. The resulting value
  // replaces the destination updated by this op.
  Value replacement = tensor::InsertSliceOp::create(
      rewriter, getLoc(), destination.getType(), source, destination,
      getOffsets(), getSizes(), getStrides(), getStaticOffsets(),
      getStaticSizes(), getStaticStrides());
  return replacement;
}

bool ParallelInsertSliceOp::canPromoteInParallelLoop(RewriterBase &) {
  // Promotion rewrites this op into a tensor.insert_slice, which requires
  // both the inserted value and the destination to be tensors.
  if (!isa<TensorType>(getSource().getType()))
    return false;
  return isa<TensorType>(getDest().getType());
}

//===----------------------------------------------------------------------===//
// ScatterOp
//===----------------------------------------------------------------------===//
Expand Down
Loading