From a9073c056efe40e7a36917aec3025c97b668ed35 Mon Sep 17 00:00:00 2001
From: Alan Li
Date: Thu, 18 Sep 2025 13:29:16 -0400
Subject: [PATCH 1/2] Adding a missing op for ParallelCombiningOpInterface

---
 mlir/include/mlir/Dialect/SCF/IR/SCF.h        |  2 +-
 .../mlir/Dialect/Tensor/IR/TensorOps.td       |  3 +-
 .../Interfaces/ParallelCombiningOpInterface.h |  3 ++
 .../ParallelCombiningOpInterface.td           | 20 +++++++
 mlir/lib/Dialect/SCF/IR/SCF.cpp               | 52 ++++++++++++-------
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp      | 15 ++++++
 6 files changed, 73 insertions(+), 22 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCF.h b/mlir/include/mlir/Dialect/SCF/IR/SCF.h
index ba648181daecb..830b49321c2e4 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCF.h
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCF.h
@@ -58,7 +58,7 @@ ForallOp getForallOpThreadIndexOwner(Value val);
 bool insideMutuallyExclusiveBranches(Operation *a, Operation *b);

 /// Promotes the loop body of a scf::ForallOp to its containing block.
-void promote(RewriterBase &rewriter, scf::ForallOp forallOp);
+LogicalResult promote(RewriterBase &rewriter, scf::ForallOp forallOp);

 /// An owning vector of values, handy to return from functions.
 using ValueVector = SmallVector<Value>;
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 2453cf5b5b5a4..be04c3a4aebbe 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1474,7 +1474,8 @@ def Tensor_ParallelInsertSliceOp : Tensor_Op<"parallel_insert_slice", [
     AttrSizedOperandSegments,
     OffsetSizeAndStrideOpInterface,
     DeclareOpInterfaceMethods<ParallelCombiningOpInterface,
-        ["getUpdatedDestinations", "getIteratingParent"]>,
+        ["getUpdatedDestinations", "getIteratingParent",
+         "promoteInParallelLoop"]>,
     // TODO: Cannot use an interface here atm, verify this manually for now.
     // HasParent<"InParallelOpInterface">
   ]> {
diff --git a/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.h b/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.h
index 82ab427699f64..ff4e5a87d05c7 100644
--- a/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.h
+++ b/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.h
@@ -15,6 +15,9 @@
 #define MLIR_INTERFACES_PARALLELCOMBININGOPINTERFACE_H_

 #include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/SmallVector.h"

 namespace mlir {
 namespace detail {
diff --git a/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.td b/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.td
index ace26f723ef53..632371b2777fd 100644
--- a/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.td
+++ b/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.td
@@ -106,6 +106,26 @@ def ParallelCombiningOpInterface : OpInterface<"ParallelCombiningOpInterface"> {
       /*methodName=*/"getIteratingParent",
       /*args=*/(ins)
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Promotes this parallel combining op out of its enclosing parallel loop
+        and returns the values that should replace the destinations updated by
+        this op.
+      }],
+      /*retTy=*/"::mlir::FailureOr<::llvm::SmallVector<::mlir::Value>>",
+      /*methodName=*/"promoteInParallelLoop",
+      /*args=*/(ins "::mlir::RewriterBase &":$rewriter)
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Returns true if this op can be promoted out of its enclosing parallel
+        loop.
+      }],
+      /*retTy=*/"bool",
+      /*methodName=*/"canPromoteInParallelLoop",
+      /*args=*/(ins "::mlir::RewriterBase &":$rewriter),
+      /*methodBody=*/[{ return true; }]
+    >,
   ];
 }
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index c35989ecba6cd..4115ca00f64b5 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -651,8 +651,7 @@ LogicalResult scf::ForallOp::promoteIfSingleIteration(RewriterBase &rewriter) {
       return failure();
   }

-  promote(rewriter, *this);
-  return success();
+  return promote(rewriter, *this);
 }

 Block::BlockArgListType ForallOp::getRegionIterArgs() {
@@ -664,10 +663,23 @@ MutableArrayRef<OpOperand> ForallOp::getInitsMutable() {
 }

 /// Promotes the loop body of a scf::ForallOp to its containing block.
-void mlir::scf::promote(RewriterBase &rewriter, scf::ForallOp forallOp) {
+LogicalResult mlir::scf::promote(RewriterBase &rewriter, scf::ForallOp forallOp) {
   OpBuilder::InsertionGuard g(rewriter);
   scf::InParallelOp terminator = forallOp.getTerminator();
+
+  // Make sure we can promote all parallel combining ops in terminator:
+  for (auto &yieldingOp : terminator.getYieldingOps()) {
+    auto parallelCombiningOp =
+        dyn_cast<ParallelCombiningOpInterface>(&yieldingOp);
+    if (!parallelCombiningOp)
+      return rewriter.notifyMatchFailure(
+          forallOp, "terminator has non-parallel-combining op");
+    if (!parallelCombiningOp.canPromoteInParallelLoop(rewriter))
+      return rewriter.notifyMatchFailure(
+          forallOp, "parallel combining op cannot be promoted");
+  }
+

   // Replace block arguments with lower bounds (replacements for IVs) and
   // outputs.
   SmallVector<Value> bbArgReplacements = forallOp.getLowerBound(rewriter);
@@ -683,30 +695,29 @@ void mlir::scf::promote(RewriterBase &rewriter, scf::ForallOp forallOp) {

   SmallVector<Value> results;
   results.reserve(forallOp.getResults().size());
   for (auto &yieldingOp : terminator.getYieldingOps()) {
-    auto parallelInsertSliceOp =
-        dyn_cast<tensor::ParallelInsertSliceOp>(yieldingOp);
-    if (!parallelInsertSliceOp)
+    auto parallelCombiningOp =
+        dyn_cast<ParallelCombiningOpInterface>(&yieldingOp);
+    if (!parallelCombiningOp)
       continue;
-    Value dst = parallelInsertSliceOp.getDest();
-    Value src = parallelInsertSliceOp.getSource();
-    if (llvm::isa<TensorType>(src.getType())) {
-      results.push_back(tensor::InsertSliceOp::create(
-          rewriter, forallOp.getLoc(), dst.getType(), src, dst,
-          parallelInsertSliceOp.getOffsets(), parallelInsertSliceOp.getSizes(),
-          parallelInsertSliceOp.getStrides(),
-          parallelInsertSliceOp.getStaticOffsets(),
-          parallelInsertSliceOp.getStaticSizes(),
-          parallelInsertSliceOp.getStaticStrides()));
-    } else {
-      llvm_unreachable("unsupported terminator");
-    }
+    assert(parallelCombiningOp.canPromoteInParallelLoop(rewriter));
+
+    FailureOr<SmallVector<Value>> promotedValues =
+        parallelCombiningOp.promoteInParallelLoop(rewriter);
+    if (failed(promotedValues))
+      return failure();
+
+    results.append(promotedValues->begin(), promotedValues->end());
   }
+  if (results.size() != forallOp.getResults().size())
+    return rewriter.notifyMatchFailure(
+        forallOp, "failed to materialize replacements for all results");
   rewriter.replaceAllUsesWith(forallOp.getResults(), results);

   // Erase the old terminator and the loop.
   rewriter.eraseOp(terminator);
   rewriter.eraseOp(forallOp);
+  return success();
 }

 LoopNest mlir::scf::buildLoopNest(
@@ -1789,7 +1800,8 @@ struct ForallOpSingleOrZeroIterationDimsFolder

     // All of the loop dimensions perform a single iteration. Inline loop body.
     if (newMixedLowerBounds.empty()) {
-      promote(rewriter, op);
+      if (failed(promote(rewriter, op)))
+        return failure();
       return success();
     }

diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index fa97b49a41d97..2932000b85b3b 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -3947,6 +3947,21 @@ Operation *ParallelInsertSliceOp::getIteratingParent() {
   return nullptr;
 }

+FailureOr<SmallVector<Value>>
+ParallelInsertSliceOp::promoteInParallelLoop(RewriterBase &rewriter) {
+  Value dst = getDest();
+  Value src = getSource();
+  if (!isa<TensorType>(src.getType()))
+    return rewriter.notifyMatchFailure(getOperation(),
+                                       "expected tensor source");
+
+  Value inserted = tensor::InsertSliceOp::create(
+      rewriter, getLoc(), dst.getType(), src, dst, getOffsets(), getSizes(),
+      getStrides(), getStaticOffsets(), getStaticSizes(), getStaticStrides());
+
+  return SmallVector<Value>{inserted};
+}
+
 //===----------------------------------------------------------------------===//
 // ScatterOp
 //===----------------------------------------------------------------------===//

From 551f35e3d6c4890b2c5f16e48bd2c1d1d9645bed Mon Sep 17 00:00:00 2001
From: Alan Li
Date: Thu, 18 Sep 2025 14:00:55 -0400
Subject: [PATCH 2/2] Update.

Related test: test/Dialect/SCF/transform-ops.mlir
---
 mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td     |  2 +-
 .../mlir/Interfaces/ParallelCombiningOpInterface.h   |  1 -
 .../mlir/Interfaces/ParallelCombiningOpInterface.td  |  7 ++++---
 mlir/lib/Dialect/SCF/IR/SCF.cpp                      | 10 ++++------
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp             | 12 ++++++++----
 5 files changed, 17 insertions(+), 15 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index be04c3a4aebbe..4fb4cc8410230 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1475,7 +1475,7 @@ def Tensor_ParallelInsertSliceOp : Tensor_Op<"parallel_insert_slice", [
     OffsetSizeAndStrideOpInterface,
     DeclareOpInterfaceMethods<ParallelCombiningOpInterface,
         ["getUpdatedDestinations", "getIteratingParent",
-         "promoteInParallelLoop"]>,
+         "promoteInParallelLoop", "canPromoteInParallelLoop"]>,
     // TODO: Cannot use an interface here atm, verify this manually for now.
     // HasParent<"InParallelOpInterface">
   ]> {
diff --git a/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.h b/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.h
index ff4e5a87d05c7..85cc18c47a527 100644
--- a/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.h
+++ b/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.h
@@ -17,7 +17,6 @@
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Support/LogicalResult.h"
-#include "llvm/ADT/SmallVector.h"

 namespace mlir {
 namespace detail {
diff --git a/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.td b/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.td
index 632371b2777fd..1a333d82d8468 100644
--- a/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.td
+++ b/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.td
@@ -109,10 +109,10 @@ def ParallelCombiningOpInterface : OpInterface<"ParallelCombiningOpInterface"> {
     InterfaceMethod<
       /*desc=*/[{
         Promotes this parallel combining op out of its enclosing parallel loop
-        and returns the values that should replace the destinations updated by
+        and returns the value that should replace the destination updated by
         this op.
       }],
-      /*retTy=*/"::mlir::FailureOr<::llvm::SmallVector<::mlir::Value>>",
+      /*retTy=*/"::mlir::FailureOr<::mlir::Value>",
       /*methodName=*/"promoteInParallelLoop",
       /*args=*/(ins "::mlir::RewriterBase &":$rewriter)
     >,
@@ -124,7 +124,8 @@ def ParallelCombiningOpInterface : OpInterface<"ParallelCombiningOpInterface"> {
       /*retTy=*/"bool",
       /*methodName=*/"canPromoteInParallelLoop",
      /*args=*/(ins "::mlir::RewriterBase &":$rewriter),
-      /*methodBody=*/[{ return true; }]
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{ return false; }]
     >,
   ];
 }
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 4115ca00f64b5..04737738d8593 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -672,14 +672,12 @@ LogicalResult mlir::scf::promote(RewriterBase &rewriter, scf::ForallOp forallOp)
     auto parallelCombiningOp =
         dyn_cast<ParallelCombiningOpInterface>(&yieldingOp);
     if (!parallelCombiningOp)
-      return rewriter.notifyMatchFailure(
-          forallOp, "terminator has non-parallel-combining op");
+      continue;
     if (!parallelCombiningOp.canPromoteInParallelLoop(rewriter))
       return rewriter.notifyMatchFailure(
           forallOp, "parallel combining op cannot be promoted");
   }

-
   // Replace block arguments with lower bounds (replacements for IVs) and
   // outputs.
   SmallVector<Value> bbArgReplacements = forallOp.getLowerBound(rewriter);
@@ -702,12 +700,12 @@ LogicalResult mlir::scf::promote(RewriterBase &rewriter, scf::ForallOp forallOp)
     assert(parallelCombiningOp.canPromoteInParallelLoop(rewriter));

-    FailureOr<SmallVector<Value>> promotedValues =
+    FailureOr<Value> promotedValue =
         parallelCombiningOp.promoteInParallelLoop(rewriter);
-    if (failed(promotedValues))
+    if (failed(promotedValue))
       return failure();

-    results.append(promotedValues->begin(), promotedValues->end());
+    results.push_back(*promotedValue);
   }
   if (results.size() != forallOp.getResults().size())
     return rewriter.notifyMatchFailure(
         forallOp, "failed to materialize replacements for all results");
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 2932000b85b3b..f05c58a40fde0 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -3947,19 +3947,23 @@ Operation *ParallelInsertSliceOp::getIteratingParent() {
   return nullptr;
 }

-FailureOr<SmallVector<Value>>
+FailureOr<Value>
 ParallelInsertSliceOp::promoteInParallelLoop(RewriterBase &rewriter) {
   Value dst = getDest();
   Value src = getSource();
   if (!isa<TensorType>(src.getType()))
-    return rewriter.notifyMatchFailure(getOperation(),
-                                       "expected tensor source");
+    return failure();

   Value inserted = tensor::InsertSliceOp::create(
       rewriter, getLoc(), dst.getType(), src, dst, getOffsets(), getSizes(),
       getStrides(), getStaticOffsets(), getStaticSizes(), getStaticStrides());

-  return SmallVector<Value>{inserted};
+  return inserted;
 }
+
+bool ParallelInsertSliceOp::canPromoteInParallelLoop(RewriterBase &) {
+  return isa<TensorType>(getSource().getType()) &&
+         isa<TensorType>(getDest().getType());
+}

 //===----------------------------------------------------------------------===//
 // ScatterOp
 //===----------------------------------------------------------------------===//
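
Note (editor's sketch, not part of the patch series above): an op implementing
ParallelCombiningOpInterface opts into the new hooks by listing
"promoteInParallelLoop" and "canPromoteInParallelLoop" in its
DeclareOpInterfaceMethods<...> entry, as done for tensor.parallel_insert_slice
in PATCH 2/2, and then defining the two methods. The sketch below shows what a
minimal implementation could look like for a hypothetical combining op
"MyOverwriteOp" whose combine step simply forwards its source tensor; the op
name and its getSource() accessor are assumptions, while the signatures follow
the interface added above.

  // Returns the single SSA value that replaces the corresponding scf.forall
  // result once the loop body has been inlined by scf::promote().
  FailureOr<Value> MyOverwriteOp::promoteInParallelLoop(RewriterBase &rewriter) {
    if (!isa<TensorType>(getSource().getType()))
      return failure();
    // A plain overwrite needs no new IR: the source value itself is the
    // replacement. Ops that combine by slice or by region would instead build
    // the sequential equivalent with `rewriter` at the current insertion point.
    return getSource();
  }

  // Queried by scf::promote() before any rewriting happens; returning false
  // makes promotion of the enclosing scf.forall fail cleanly.
  bool MyOverwriteOp::canPromoteInParallelLoop(RewriterBase &) {
    return isa<TensorType>(getSource().getType());
  }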