Skip to content

Commit a9073c0

Browse files
committed
Adding a missing op for ParallelCombiningOpInterface
1 parent 7c861bc commit a9073c0

File tree

6 files changed

+73
-22
lines changed

6 files changed

+73
-22
lines changed

mlir/include/mlir/Dialect/SCF/IR/SCF.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ ForallOp getForallOpThreadIndexOwner(Value val);
5858
bool insideMutuallyExclusiveBranches(Operation *a, Operation *b);
5959

6060
/// Promotes the loop body of a scf::ForallOp to its containing block.
/// Returns failure if an op in the loop terminator cannot be promoted.
61-
void promote(RewriterBase &rewriter, scf::ForallOp forallOp);
61+
LogicalResult promote(RewriterBase &rewriter, scf::ForallOp forallOp);
6262

6363
/// An owning vector of values, handy to return from functions.
6464
using ValueVector = SmallVector<Value>;

mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1474,7 +1474,8 @@ def Tensor_ParallelInsertSliceOp : Tensor_Op<"parallel_insert_slice", [
14741474
AttrSizedOperandSegments,
14751475
OffsetSizeAndStrideOpInterface,
14761476
DeclareOpInterfaceMethods<ParallelCombiningOpInterface,
1477-
["getUpdatedDestinations", "getIteratingParent"]>,
1477+
["getUpdatedDestinations", "getIteratingParent",
1478+
"promoteInParallelLoop"]>,
14781479
// TODO: Cannot use an interface here atm, verify this manually for now.
14791480
// HasParent<"InParallelOpInterface">
14801481
]> {

mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515
#define MLIR_INTERFACES_PARALLELCOMBININGOPINTERFACE_H_
1616

1717
#include "mlir/IR/OpDefinition.h"
18+
#include "mlir/IR/PatternMatch.h"
19+
#include "mlir/Support/LogicalResult.h"
20+
#include "llvm/ADT/SmallVector.h"
1821

1922
namespace mlir {
2023
namespace detail {

mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.td

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,26 @@ def ParallelCombiningOpInterface : OpInterface<"ParallelCombiningOpInterface"> {
106106
/*methodName=*/"getIteratingParent",
107107
/*args=*/(ins)
108108
>,
109+
// Promotion hook. No method body or default implementation is given for it in
// this entry; tensor.parallel_insert_slice lists it in its
// DeclareOpInterfaceMethods and defines it. The FailureOr return type lets an
// implementation report that the promotion failed.
InterfaceMethod<
110+
/*desc=*/[{
111+
Promotes this parallel combining op out of its enclosing parallel loop
112+
and returns the values that should replace the destinations updated by
113+
this op.
114+
}],
115+
/*retTy=*/"::mlir::FailureOr<::llvm::SmallVector<::mlir::Value>>",
116+
/*methodName=*/"promoteInParallelLoop",
117+
/*args=*/(ins "::mlir::RewriterBase &":$rewriter)
118+
>,
119+
// Query hook: reports whether the implementing op can be promoted out of its
// enclosing parallel loop. scf::promote checks it for every op in the
// terminator before promoting.
//
// `return true` is supplied as the *default implementation* rather than as the
// method body. A method body is emitted into the interface's Model class and
// is never forwarded to the op, so no op could ever answer "false". A default
// implementation lives in the op's trait and can be overridden by listing the
// method in DeclareOpInterfaceMethods.
InterfaceMethod<
  /*desc=*/[{
    Returns true if this op can be promoted out of its enclosing parallel
    loop.
  }],
  /*retTy=*/"bool",
  /*methodName=*/"canPromoteInParallelLoop",
  /*args=*/(ins "::mlir::RewriterBase &":$rewriter),
  /*methodBody=*/"",
  /*defaultImplementation=*/[{ return true; }]
>,
109129
];
110130
}
111131

mlir/lib/Dialect/SCF/IR/SCF.cpp

Lines changed: 32 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -651,8 +651,7 @@ LogicalResult scf::ForallOp::promoteIfSingleIteration(RewriterBase &rewriter) {
651651
return failure();
652652
}
653653

654-
promote(rewriter, *this);
655-
return success();
654+
return promote(rewriter, *this);
656655
}
657656

658657
Block::BlockArgListType ForallOp::getRegionIterArgs() {
@@ -664,10 +663,23 @@ MutableArrayRef<OpOperand> ForallOp::getInitsMutable() {
664663
}
665664

666665
/// Promotes the loop body of a scf::ForallOp to its containing block.
/// Returns failure if an op in the loop terminator cannot be promoted.
667-
void mlir::scf::promote(RewriterBase &rewriter, scf::ForallOp forallOp) {
666+
LogicalResult mlir::scf::promote(RewriterBase &rewriter, scf::ForallOp forallOp) {
668667
OpBuilder::InsertionGuard g(rewriter);
669668
scf::InParallelOp terminator = forallOp.getTerminator();
670669

670+
// Make sure we can promote all parallel combining ops in terminator:
671+
for (auto &yieldingOp : terminator.getYieldingOps()) {
672+
auto parallelCombiningOp =
673+
dyn_cast<ParallelCombiningOpInterface>(&yieldingOp);
674+
if (!parallelCombiningOp)
675+
return rewriter.notifyMatchFailure(
676+
forallOp, "terminator has non-parallel-combining op");
677+
if (!parallelCombiningOp.canPromoteInParallelLoop(rewriter))
678+
return rewriter.notifyMatchFailure(
679+
forallOp, "parallel combining op cannot be promoted");
680+
}
681+
682+
671683
// Replace block arguments with lower bounds (replacements for IVs) and
672684
// outputs.
673685
SmallVector<Value> bbArgReplacements = forallOp.getLowerBound(rewriter);
@@ -683,30 +695,29 @@ void mlir::scf::promote(RewriterBase &rewriter, scf::ForallOp forallOp) {
683695
SmallVector<Value> results;
684696
results.reserve(forallOp.getResults().size());
685697
for (auto &yieldingOp : terminator.getYieldingOps()) {
686-
auto parallelInsertSliceOp =
687-
dyn_cast<tensor::ParallelInsertSliceOp>(yieldingOp);
688-
if (!parallelInsertSliceOp)
698+
auto parallelCombiningOp =
699+
dyn_cast<ParallelCombiningOpInterface>(&yieldingOp);
700+
if (!parallelCombiningOp)
689701
continue;
690702

691-
Value dst = parallelInsertSliceOp.getDest();
692-
Value src = parallelInsertSliceOp.getSource();
693-
if (llvm::isa<TensorType>(src.getType())) {
694-
results.push_back(tensor::InsertSliceOp::create(
695-
rewriter, forallOp.getLoc(), dst.getType(), src, dst,
696-
parallelInsertSliceOp.getOffsets(), parallelInsertSliceOp.getSizes(),
697-
parallelInsertSliceOp.getStrides(),
698-
parallelInsertSliceOp.getStaticOffsets(),
699-
parallelInsertSliceOp.getStaticSizes(),
700-
parallelInsertSliceOp.getStaticStrides()));
701-
} else {
702-
llvm_unreachable("unsupported terminator");
703-
}
703+
assert(parallelCombiningOp.canPromoteInParallelLoop(rewriter));
704+
705+
FailureOr<SmallVector<Value>> promotedValues =
706+
parallelCombiningOp.promoteInParallelLoop(rewriter);
707+
if (failed(promotedValues))
708+
return failure();
709+
710+
results.append(promotedValues->begin(), promotedValues->end());
704711
}
712+
if (results.size() != forallOp.getResults().size())
713+
return rewriter.notifyMatchFailure(
714+
forallOp, "failed to materialize replacements for all results");
705715
rewriter.replaceAllUsesWith(forallOp.getResults(), results);
706716

707717
// Erase the old terminator and the loop.
708718
rewriter.eraseOp(terminator);
709719
rewriter.eraseOp(forallOp);
720+
return success();
710721
}
711722

712723
LoopNest mlir::scf::buildLoopNest(
@@ -1789,7 +1800,8 @@ struct ForallOpSingleOrZeroIterationDimsFolder
17891800

17901801
// All of the loop dimensions perform a single iteration. Inline loop body.
17911802
if (newMixedLowerBounds.empty()) {
1792-
promote(rewriter, op);
1803+
if (failed(promote(rewriter, op)))
1804+
return failure();
17931805
return success();
17941806
}
17951807

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3947,6 +3947,21 @@ Operation *ParallelInsertSliceOp::getIteratingParent() {
39473947
return nullptr;
39483948
}
39493949

3950+
/// Re-expresses this op as a tensor::InsertSliceOp created through `rewriter`,
/// reusing this op's source, destination, offsets, sizes and strides (dynamic
/// and static). On success the returned vector holds exactly one value: the
/// newly created insert. A match failure is reported when the source is not of
/// tensor type. This function does not erase the op itself.
FailureOr<SmallVector<Value>>
ParallelInsertSliceOp::promoteInParallelLoop(RewriterBase &rewriter) {
  Value sliceToInsert = getSource();
  // Only a tensor-typed source can be expressed as a tensor::InsertSliceOp.
  if (!isa<TensorType>(sliceToInsert.getType()))
    return rewriter.notifyMatchFailure(getOperation(),
                                       "expected tensor source");

  Value destination = getDest();
  Value insertedValue = tensor::InsertSliceOp::create(
      rewriter, getLoc(), destination.getType(), sliceToInsert, destination,
      getOffsets(), getSizes(), getStrides(), getStaticOffsets(),
      getStaticSizes(), getStaticStrides());

  SmallVector<Value> replacements;
  replacements.push_back(insertedValue);
  return replacements;
}
3964+
39503965
//===----------------------------------------------------------------------===//
39513966
// ScatterOp
39523967
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)