Skip to content

Commit b87f1b2

Browse files
authored
[MLIR] Add InParallelOpInterface for parallel combining operations (llvm#157736)
This commit: - Introduces a new `InParallelOpInterface` which, together with the `ParallelCombiningOpInterface`, represents the parallel updating operations that appear in the parallel loop `scf.forall`. - Renames the existing `ParallelCombiningOpInterface` to `InParallelOpInterface`, as the old naming was quite confusing. - Reuses the name `ParallelCombiningOpInterface` for an interface that generalizes operations that insert into shared tensors within parallel combining regions. Previously, only `tensor.parallel_insert_slice` was supported directly in `scf.InParallelOp` regions. - Makes `tensor.parallel_insert_slice` implement `ParallelCombiningOpInterface`. This change enables future extensions to support additional parallel combining operations beyond `tensor.parallel_insert_slice` that have different update semantics, so that the `in_parallel` region can correctly and safely represent these kinds of operations without potential mistakes such as races. Author credits: @qedawkins
1 parent f645d20 commit b87f1b2

File tree

16 files changed

+280
-69
lines changed

16 files changed

+280
-69
lines changed

mlir/include/mlir/Dialect/SCF/IR/SCFOps.td

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -654,7 +654,7 @@ def ForallOp : SCF_Op<"forall", [
654654
def InParallelOp : SCF_Op<"forall.in_parallel", [
655655
Pure,
656656
Terminator,
657-
DeclareOpInterfaceMethods<ParallelCombiningOpInterface>,
657+
DeclareOpInterfaceMethods<InParallelOpInterface>,
658658
HasParent<"ForallOp">,
659659
] # GraphRegionNoTerminator.traits> {
660660
let summary = "terminates a `forall` block";
@@ -679,8 +679,6 @@ def InParallelOp : SCF_Op<"forall.in_parallel", [
679679
OpBuilder<(ins)>,
680680
];
681681

682-
// TODO: Add a `InParallelOpInterface` interface for ops that can
683-
// appear inside in_parallel.
684682
let extraClassDeclaration = [{
685683
::llvm::SmallVector<::mlir::BlockArgument> getDests();
686684
::llvm::iterator_range<::mlir::Block::iterator> getYieldingOps();

mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1470,24 +1470,25 @@ def Tensor_PadOp : Tensor_Op<"pad", [
14701470
// ParallelInsertSliceOp
14711471
//===----------------------------------------------------------------------===//
14721472

1473-
// TODO: Implement InParallelOpInterface.
14741473
def Tensor_ParallelInsertSliceOp : Tensor_Op<"parallel_insert_slice", [
14751474
AttrSizedOperandSegments,
14761475
OffsetSizeAndStrideOpInterface,
1476+
DeclareOpInterfaceMethods<ParallelCombiningOpInterface,
1477+
["getUpdatedDestinations", "getIteratingParent"]>,
14771478
// TODO: Cannot use an interface here atm, verify this manually for now.
1478-
// HasParent<"ParallelCombiningOpInterface">
1479+
// HasParent<"InParallelOpInterface">
14791480
]> {
14801481
let summary = [{
14811482
Specify the tensor slice update of a single thread of a parent
1482-
ParallelCombiningOpInterface op.
1483+
InParallelOpInterface op.
14831484
}];
14841485
let description = [{
14851486
The `parallel_insert_slice` yields a subset tensor value to its parent
1486-
ParallelCombiningOpInterface. These subset tensor values are aggregated to
1487+
InParallelOpInterface. These subset tensor values are aggregated
14871488
in some unspecified order into a full tensor value returned by the parent
14881489
parallel iterating op.
14891490
The `parallel_insert_slice` is one such op allowed in the
1490-
ParallelCombiningOpInterface op.
1491+
InParallelOpInterface op.
14911492

14921493
Conflicting writes result in undefined semantics, in that the indices written
14931494
to by multiple parallel updates might contain data from any of the updates,
@@ -1569,8 +1570,8 @@ def Tensor_ParallelInsertSliceOp : Tensor_Op<"parallel_insert_slice", [
15691570
return ::llvm::cast<RankedTensorType>(getDest().getType());
15701571
}
15711572

1572-
ParallelCombiningOpInterface getParallelCombiningParent() {
1573-
return dyn_cast<ParallelCombiningOpInterface>(
1573+
InParallelOpInterface getParallelCombiningParent() {
1574+
return dyn_cast<InParallelOpInterface>(
15741575
getOperation()->getParentOp());
15751576
}
15761577

mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
namespace mlir {
2020
namespace detail {
2121
// TODO: Single region single block interface on interfaces ?
22-
LogicalResult verifyParallelCombiningOpInterface(Operation *op);
22+
LogicalResult verifyInParallelOpInterface(Operation *op);
2323
} // namespace detail
2424
} // namespace mlir
2525

mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.td

Lines changed: 57 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
//
77
//===----------------------------------------------------------------------===//
88
//
9-
// Defines the interface for ops that perform parallel combining operations.
9+
// Defines the interface for ops that perform in parallel combining
10+
// operations.
1011
//
1112
//===----------------------------------------------------------------------===//
1213

@@ -15,9 +16,9 @@
1516

1617
include "mlir/IR/OpBase.td"
1718

18-
def ParallelCombiningOpInterface : OpInterface<"ParallelCombiningOpInterface"> {
19+
def InParallelOpInterface : OpInterface<"InParallelOpInterface"> {
1920
let description = [{
20-
A parallel combining op is an op with a region.
21+
An in parallel op is an op with a region.
2122

2223
This is useful as a terminator to parallel operations that iterate over
2324
some set and return tensors while avoiding tight coupling between the
@@ -52,8 +53,60 @@ def ParallelCombiningOpInterface : OpInterface<"ParallelCombiningOpInterface"> {
5253
];
5354
// TODO: Single region single block interface on interfaces ?
5455
let verify = [{
55-
return verifyParallelCombiningOpInterface($_op);
56+
return verifyInParallelOpInterface($_op);
57+
}];
58+
}
59+
60+
def ParallelCombiningOpInterface : OpInterface<"ParallelCombiningOpInterface"> {
61+
let description = [{
62+
A parallel combining op is an operation that models parallel contributions
63+
to result tensors within the context of a parent iterating operation.
64+
65+
This interface is designed for operations that need to coordinate parallel
66+
insertions or contributions to tensors that are being constructed across
67+
multiple parallel iterations. The destination refers to a tensor value that
68+
is assembled by aggregating results from parallel computations; each
69+
parallel iteration may contribute a slice, element, or region to the final
70+
result. No in-place mutation of tensors is implied.
71+
72+
One significant use case for this interface is `tensor.parallel_insert_slice`
73+
which allows parallel insertion of slices that are aggregated into a
74+
destination tensor. With this interface, other operations that express
75+
similar parallel contributions can also be defined.
76+
77+
This op works within an op implementing the `InParallelOpInterface` that
78+
specifies how the parallel results are combined.
79+
80+
Key semantics:
81+
- The operation identifies destination tensors to which iterations
82+
contribute through the `getUpdatedDestinations` method
83+
- Each parallel iteration may produce elements or regions that are
84+
incorporated into the destination tensor
85+
- The parent iterating operation manages the coordination and ensures
86+
proper synchronization of these contributions
87+
88+
Note: This interface does not verify itself; it is up to the implementing operation
89+
to verify the correctness of the op.
5690
}];
91+
let cppNamespace = "::mlir";
92+
93+
let methods = [
94+
InterfaceMethod<[{
95+
Returns the list of destination values this op contributes to.
96+
}],
97+
/*retTy=*/"::mlir::MutableOperandRange",
98+
/*methodName=*/"getUpdatedDestinations",
99+
/*args=*/(ins)
100+
>,
101+
InterfaceMethod<
102+
/*desc=*/[{
103+
Returns the iterating parent for this op.
104+
}],
105+
/*retTy=*/"::mlir::Operation*",
106+
/*methodName=*/"getIteratingParent",
107+
/*args=*/(ins)
108+
>,
109+
];
57110
}
58111

59112
#endif // MLIR_INTERFACES_PARALLELCOMBININGOPINTERFACE

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
#include "mlir/IR/BuiltinTypeInterfaces.h"
3737
#include "mlir/IR/PatternMatch.h"
3838
#include "mlir/IR/TypeUtilities.h"
39+
#include "mlir/Interfaces/ParallelCombiningOpInterface.h"
3940
#include "mlir/Interfaces/TilingInterface.h"
4041
#include "mlir/Support/LLVM.h"
4142
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -4147,12 +4148,11 @@ DiagnosedSilenceableFailure doit(RewriterBase &rewriter, OpTy target,
41474148
return DiagnosedSilenceableFailure::success();
41484149
}
41494150

4150-
// If we are inside an InParallel region, temporarily set the insertion point
4151-
// outside: only tensor.parallel_insert_slice ops are allowed in there.
4152-
if constexpr (std::is_same_v<OpTy, tensor::ParallelInsertSliceOp>) {
4153-
rewriter.setInsertionPoint(
4154-
target->template getParentOfType<scf::InParallelOp>());
4155-
}
4151+
// If we are inside an `InParallelOpInterface` region, temporarily set the
4152+
// insertion point outside: only ops implementing ParallelCombiningOpInterface
4153+
// are allowed in there.
4154+
if (isa<mlir::ParallelCombiningOpInterface>(target.getOperation()))
4155+
rewriter.setInsertionPoint(target->getParentOp());
41564156

41574157
Value extracted = tensor::ExtractSliceOp::create(
41584158
rewriter, target.getLoc(), target.getDest(), target.getMixedOffsets(),

mlir/lib/Dialect/SCF/IR/SCF.cpp

Lines changed: 32 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "mlir/IR/Matchers.h"
2222
#include "mlir/IR/PatternMatch.h"
2323
#include "mlir/Interfaces/FunctionInterfaces.h"
24+
#include "mlir/Interfaces/ParallelCombiningOpInterface.h"
2425
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
2526
#include "mlir/Transforms/InliningUtils.h"
2627
#include "llvm/ADT/MapVector.h"
@@ -681,7 +682,9 @@ void mlir::scf::promote(RewriterBase &rewriter, scf::ForallOp forallOp) {
681682
results.reserve(forallOp.getResults().size());
682683
for (auto &yieldingOp : terminator.getYieldingOps()) {
683684
auto parallelInsertSliceOp =
684-
cast<tensor::ParallelInsertSliceOp>(yieldingOp);
685+
dyn_cast<tensor::ParallelInsertSliceOp>(yieldingOp);
686+
if (!parallelInsertSliceOp)
687+
continue;
685688

686689
Value dst = parallelInsertSliceOp.getDest();
687690
Value src = parallelInsertSliceOp.getSource();
@@ -1439,12 +1442,9 @@ InParallelOp ForallOp::getTerminator() {
14391442

14401443
SmallVector<Operation *> ForallOp::getCombiningOps(BlockArgument bbArg) {
14411444
SmallVector<Operation *> storeOps;
1442-
InParallelOp inParallelOp = getTerminator();
1443-
for (Operation &yieldOp : inParallelOp.getYieldingOps()) {
1444-
if (auto parallelInsertSliceOp =
1445-
dyn_cast<tensor::ParallelInsertSliceOp>(yieldOp);
1446-
parallelInsertSliceOp && parallelInsertSliceOp.getDest() == bbArg) {
1447-
storeOps.push_back(parallelInsertSliceOp);
1445+
for (Operation *user : bbArg.getUsers()) {
1446+
if (auto parallelOp = dyn_cast<ParallelCombiningOpInterface>(user)) {
1447+
storeOps.push_back(parallelOp);
14481448
}
14491449
}
14501450
return storeOps;
@@ -1911,8 +1911,10 @@ struct FoldTensorCastOfOutputIntoForallOp
19111911
auto terminator = newForallOp.getTerminator();
19121912
for (auto [yieldingOp, outputBlockArg] : llvm::zip(
19131913
terminator.getYieldingOps(), newForallOp.getRegionIterArgs())) {
1914-
auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(yieldingOp);
1915-
insertSliceOp.getDestMutable().assign(outputBlockArg);
1914+
if (auto parallelCombingingOp =
1915+
dyn_cast<ParallelCombiningOpInterface>(yieldingOp)) {
1916+
parallelCombingingOp.getUpdatedDestinations().assign(outputBlockArg);
1917+
}
19161918
}
19171919

19181920
// Cast results back to the original types.
@@ -1971,19 +1973,22 @@ LogicalResult InParallelOp::verify() {
19711973
if (!forallOp)
19721974
return this->emitOpError("expected forall op parent");
19731975

1974-
// TODO: InParallelOpInterface.
19751976
for (Operation &op : getRegion().front().getOperations()) {
1976-
if (!isa<tensor::ParallelInsertSliceOp>(op)) {
1977-
return this->emitOpError("expected only ")
1978-
<< tensor::ParallelInsertSliceOp::getOperationName() << " ops";
1977+
auto parallelCombiningOp = dyn_cast<ParallelCombiningOpInterface>(&op);
1978+
if (!parallelCombiningOp) {
1979+
return this->emitOpError("expected only ParallelCombiningOpInterface")
1980+
<< " ops";
19791981
}
19801982

19811983
// Verify that inserts are into out block arguments.
1982-
Value dest = cast<tensor::ParallelInsertSliceOp>(op).getDest();
1984+
MutableOperandRange dests = parallelCombiningOp.getUpdatedDestinations();
19831985
ArrayRef<BlockArgument> regionOutArgs = forallOp.getRegionOutArgs();
1984-
if (!llvm::is_contained(regionOutArgs, dest))
1985-
return op.emitOpError("may only insert into an output block argument");
1986+
for (OpOperand &dest : dests) {
1987+
if (!llvm::is_contained(regionOutArgs, dest.get()))
1988+
return op.emitOpError("may only insert into an output block argument");
1989+
}
19861990
}
1991+
19871992
return success();
19881993
}
19891994

@@ -2018,12 +2023,17 @@ OpResult InParallelOp::getParentResult(int64_t idx) {
20182023
}
20192024

20202025
SmallVector<BlockArgument> InParallelOp::getDests() {
2021-
return llvm::to_vector<4>(
2022-
llvm::map_range(getYieldingOps(), [](Operation &op) {
2023-
// Add new ops here as needed.
2024-
auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(&op);
2025-
return llvm::cast<BlockArgument>(insertSliceOp.getDest());
2026-
}));
2026+
SmallVector<BlockArgument> updatedDests;
2027+
for (Operation &yieldingOp : getYieldingOps()) {
2028+
auto parallelCombiningOp =
2029+
dyn_cast<ParallelCombiningOpInterface>(&yieldingOp);
2030+
if (!parallelCombiningOp)
2031+
continue;
2032+
for (OpOperand &updatedOperand :
2033+
parallelCombiningOp.getUpdatedDestinations())
2034+
updatedDests.push_back(cast<BlockArgument>(updatedOperand.get()));
2035+
}
2036+
return updatedDests;
20272037
}
20282038

20292039
llvm::iterator_range<Block::iterator> InParallelOp::getYieldingOps() {

mlir/lib/Dialect/SCF/Transforms/BufferDeallocationOpInterfaceImpl.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ using namespace mlir::bufferization;
1616
namespace {
1717
/// The `scf.forall.in_parallel` terminator is special in a few ways:
1818
/// * It does not implement the BranchOpInterface or
19-
/// RegionBranchTerminatorOpInterface, but the ParallelCombiningOpInterface
19+
/// RegionBranchTerminatorOpInterface, but the InParallelOpInterface
2020
/// which is not supported by BufferDeallocation.
2121
/// * It has a graph-like region which only allows one specific tensor op
2222
/// * After bufferization the nested region is always empty
@@ -40,9 +40,9 @@ namespace {
4040
/// <implicit in_parallel terminator here>
4141
/// }
4242
/// ```
43-
struct InParallelOpInterface
44-
: public BufferDeallocationOpInterface::ExternalModel<InParallelOpInterface,
45-
scf::InParallelOp> {
43+
struct InParallelDeallocOpInterface
44+
: public BufferDeallocationOpInterface::ExternalModel<
45+
InParallelDeallocOpInterface, scf::InParallelOp> {
4646
FailureOr<Operation *> process(Operation *op, DeallocationState &state,
4747
const DeallocationOptions &options) const {
4848
auto inParallelOp = cast<scf::InParallelOp>(op);
@@ -75,7 +75,7 @@ struct ReduceReturnOpInterface
7575
void mlir::scf::registerBufferDeallocationOpInterfaceExternalModels(
7676
DialectRegistry &registry) {
7777
registry.addExtension(+[](MLIRContext *ctx, SCFDialect *dialect) {
78-
InParallelOp::attachInterface<InParallelOpInterface>(*ctx);
78+
InParallelOp::attachInterface<InParallelDeallocOpInterface>(*ctx);
7979
ReduceReturnOp::attachInterface<ReduceReturnOpInterface>(*ctx);
8080
});
8181
}

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2976,9 +2976,9 @@ class InsertSliceOpConstantArgumentFolder final
29762976
if (sourceType != insertSliceOp.getSourceType()) {
29772977
OpBuilder::InsertionGuard g(rewriter);
29782978
// The only difference between InsertSliceOp and ParallelInsertSliceOp
2979-
// is that the insertion point is just before the ParallelCombiningOp in
2979+
// is that the insertion point is just before the InParallelOp in
29802980
// the parallel case.
2981-
if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
2981+
if (isa<InParallelOpInterface>(insertSliceOp->getParentOp()))
29822982
rewriter.setInsertionPoint(insertSliceOp->getParentOp());
29832983
toInsert = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(),
29842984
sourceType, toInsert);
@@ -3153,9 +3153,9 @@ struct InsertSliceOpSourceCastInserter final
31533153
// Insert the cast.
31543154
OpBuilder::InsertionGuard g(rewriter);
31553155
// The only difference between InsertSliceOp and ParallelInsertSliceOp is
3156-
// that the insertion point is just before the ParallelCombiningOp in the
3156+
// that the insertion point is just before the InParallelOp in the
31573157
// parallel case.
3158-
if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
3158+
if (isa<ParallelCombiningOpInterface>(insertSliceOp->getParentOp()))
31593159
rewriter.setInsertionPoint(insertSliceOp->getParentOp());
31603160
Value cast = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(),
31613161
newSrcType, insertSliceOp.getSource());
@@ -3846,8 +3846,7 @@ OpFoldResult PadOp::fold(FoldAdaptor) {
38463846
//===----------------------------------------------------------------------===//
38473847

38483848
OpResult ParallelInsertSliceOp::getTiedOpResult() {
3849-
ParallelCombiningOpInterface parallelCombiningParent =
3850-
getParallelCombiningParent();
3849+
InParallelOpInterface parallelCombiningParent = getParallelCombiningParent();
38513850
for (const auto &it :
38523851
llvm::enumerate(parallelCombiningParent.getYieldingOps())) {
38533852
Operation &nextOp = it.value();
@@ -3901,8 +3900,8 @@ void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
39013900
}
39023901

39033902
LogicalResult ParallelInsertSliceOp::verify() {
3904-
if (!isa<ParallelCombiningOpInterface>(getOperation()->getParentOp()))
3905-
return this->emitError("expected ParallelCombiningOpInterface parent, got:")
3903+
if (!isa<InParallelOpInterface>(getOperation()->getParentOp()))
3904+
return this->emitError("expected InParallelOpInterface parent, got:")
39063905
<< *(getOperation()->getParentOp());
39073906

39083907
// Verify result type against inferred type.
@@ -3935,6 +3934,19 @@ llvm::SmallBitVector ParallelInsertSliceOp::getDroppedDims() {
39353934
return ::getDroppedDims(getSourceType().getShape(), getMixedSizes());
39363935
}
39373936

3937+
// ParallelCombiningOpInterface implementation.
3938+
MutableOperandRange ParallelInsertSliceOp::getUpdatedDestinations() {
3939+
return getDestMutable();
3940+
}
3941+
3942+
Operation *ParallelInsertSliceOp::getIteratingParent() {
3943+
// Return the parent InParallelOpInterface's parent.
3944+
if (auto combiningOp =
3945+
dyn_cast<InParallelOpInterface>(getOperation()->getParentOp()))
3946+
return combiningOp->getParentOp();
3947+
return nullptr;
3948+
}
3949+
39383950
//===----------------------------------------------------------------------===//
39393951
// ScatterOp
39403952
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -970,10 +970,10 @@ struct ParallelInsertSliceOpInterface
970970
BufferizationState &state) const {
971971
OpBuilder::InsertionGuard g(rewriter);
972972
auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
973-
ParallelCombiningOpInterface parallelCombiningParent =
973+
InParallelOpInterface parallelCombiningParent =
974974
parallelInsertSliceOp.getParallelCombiningParent();
975975

976-
// Bufferize the op outside of the parallel combining terminator.
976+
// Bufferize the op outside of the in parallel terminator.
977977
rewriter.setInsertionPoint(parallelCombiningParent);
978978

979979
// Get source and destination buffers.

0 commit comments

Comments
 (0)