Skip to content

Commit feb5939

Browse files
committed
Address comments.
1 parent 5e92370 commit feb5939

File tree

5 files changed

+33
-39
lines changed

5 files changed

+33
-39
lines changed

mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.td

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -59,40 +59,40 @@ def InParallelOpInterface : OpInterface<"InParallelOpInterface"> {
5959

6060
def ParallelCombiningOpInterface : OpInterface<"ParallelCombiningOpInterface"> {
6161
let description = [{
62-
A parallel combining op is an operation performs parallel updates to
63-
destination tensors within the context of a parent iterating operation.
62+
A parallel combining op is an operation that models parallel contributions
63+
to result tensors within the context of a parent iterating operation.
6464

6565
This interface is designed for operations that need to coordinate parallel
66-
insertions or updates to tensors that are being constructed or modified
67-
across multiple parallel iterations. The "updated destination" refers to a
68-
destination tensor that accumulates results from parallel computations,
69-
where each parallel iteration may contribute a slice, element, or region
70-
to the final result.
66+
insertions or contributions to tensors that are being constructed across
67+
multiple parallel iterations. The destination refers to a tensor value that
68+
is assembled by aggregating results from parallel computations; each
69+
parallel iteration may contribute a slice, element, or region to the final
70+
result. No in-place mutation of tensors is implied.
7171

7272
One significant use case for this interface is `tensor.parallel_insert_slice`
73-
which allows parallel insertion of slices into a destination tensor. But with
74-
this interface, other operations that perform similar parallel updates can
75-
also be defined.
73+
which allows parallel insertion of slices that are aggregated into a
74+
destination tensor. With this interface, other operations that express
75+
similar parallel contributions can also be defined.
7676

77-
This op works within an op implementing the
78-
`InParallelOpInterface` that specifies how the parallel results are combined.
77+
This op works within an op implementing the `InParallelOpInterface` that
78+
specifies how the parallel results are combined.
7979

8080
Key semantics:
81-
- The operation identifies destination tensors that will be updated
82-
through the `getUpdatedDestinations` method
83-
- Each parallel iteration may update elements or regions of the
84-
destination tensor
81+
- The operation identifies destination tensors to which iterations
82+
contribute through the `getUpdatedDestinations` method
83+
- Each parallel iteration may produce elements or regions that are
84+
incorporated into the destination tensor
8585
- The parent iterating operation manages the coordination and ensures
86-
proper synchronization of these updates
86+
proper synchronization of these contributions
8787

8888
Note: This interface does not verify itself, it is up to the implementing operation
8989
to verify the correctness of the op.
9090
}];
9191
let cppNamespace = "::mlir";
9292

9393
let methods = [
94-
InterfaceMethod<[{
95-
Returns the list of values updated by this op.
94+
InterfaceMethod<[{
95+
Returns the list of destination values this op contributes to.
9696
}],
9797
/*retTy=*/"::mlir::MutableOperandRange",
9898
/*methodName=*/"getUpdatedDestinations",

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4142,11 +4142,10 @@ DiagnosedSilenceableFailure doit(RewriterBase &rewriter, OpTy target,
41424142
}
41434143

41444144
// If we are inside a `ParallelCombiningOp` region, temporarily set the
4145-
// insertion point outside: only ops implementing ParallelCombiningOpInterface are
4146-
// allowed in there.
4147-
if (isa<mlir::ParallelCombiningOpInterface>(target.getOperation())) {
4145+
// insertion point outside: only ops implementing ParallelCombiningOpInterface
4146+
// are allowed in there.
4147+
if (isa<mlir::ParallelCombiningOpInterface>(target.getOperation()))
41484148
rewriter.setInsertionPoint(target->getParentOp());
4149-
}
41504149

41514150
Value extracted = tensor::ExtractSliceOp::create(
41524151
rewriter, target.getLoc(), target.getDest(), target.getMixedOffsets(),

mlir/lib/Dialect/SCF/IR/SCF.cpp

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1674,12 +1674,7 @@ struct ForallOpIterArgsFolder : public OpRewritePattern<ForallOp> {
16741674
for (OpResult result : forallOp.getResults()) {
16751675
OpOperand *opOperand = forallOp.getTiedOpOperand(result);
16761676
BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
1677-
SmallVector<Operation *> combiningOps =
1678-
forallOp.getCombiningOps(blockArg);
1679-
if ((result.use_empty() &&
1680-
llvm::all_of(combiningOps,
1681-
[](Operation *op) { return op->use_empty(); })) ||
1682-
combiningOps.empty()) {
1677+
if (result.use_empty() || forallOp.getCombiningOps(blockArg).empty()) {
16831678
resultToDelete.insert(result);
16841679
} else {
16851680
resultToReplace.push_back(result);
@@ -1981,7 +1976,8 @@ LogicalResult InParallelOp::verify() {
19811976
for (Operation &op : getRegion().front().getOperations()) {
19821977
auto inParallelOp = dyn_cast<ParallelCombiningOpInterface>(&op);
19831978
if (!inParallelOp) {
1984-
return this->emitOpError("expected only ParallelCombiningOpInterface") << " ops";
1979+
return this->emitOpError("expected only ParallelCombiningOpInterface")
1980+
<< " ops";
19851981
}
19861982

19871983
// Verify that inserts are into out block arguments.
@@ -2028,11 +2024,11 @@ OpResult InParallelOp::getParentResult(int64_t idx) {
20282024

20292025
SmallVector<BlockArgument> InParallelOp::getDests() {
20302026
SmallVector<BlockArgument> updatedDests;
2031-
for (auto &yieldingOp : getYieldingOps()) {
2027+
for (Operation &yieldingOp : getYieldingOps()) {
20322028
auto inParallelOp = dyn_cast<ParallelCombiningOpInterface>(&yieldingOp);
20332029
if (!inParallelOp)
20342030
continue;
2035-
for (auto &updatedOperand : inParallelOp.getUpdatedDestinations())
2031+
for (OpOperand &updatedOperand : inParallelOp.getUpdatedDestinations())
20362032
updatedDests.push_back(cast<BlockArgument>(updatedOperand.get()));
20372033
}
20382034
return updatedDests;

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3846,8 +3846,7 @@ OpFoldResult PadOp::fold(FoldAdaptor) {
38463846
//===----------------------------------------------------------------------===//
38473847

38483848
OpResult ParallelInsertSliceOp::getTiedOpResult() {
3849-
InParallelOpInterface parallelCombiningParent =
3850-
getParallelCombiningParent();
3849+
InParallelOpInterface parallelCombiningParent = getParallelCombiningParent();
38513850
for (const auto &it :
38523851
llvm::enumerate(parallelCombiningParent.getYieldingOps())) {
38533852
Operation &nextOp = it.value();
@@ -3942,8 +3941,8 @@ MutableOperandRange ParallelInsertSliceOp::getUpdatedDestinations() {
39423941

39433942
Operation *ParallelInsertSliceOp::getIteratingParent() {
39443943
// Return the parent InParallelOpInterface's parent
3945-
if (auto combiningOp = dyn_cast<InParallelOpInterface>(
3946-
getOperation()->getParentOp())) {
3944+
if (auto combiningOp =
3945+
dyn_cast<InParallelOpInterface>(getOperation()->getParentOp())) {
39473946
return combiningOp->getParentOp();
39483947
}
39493948
return nullptr;

mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -215,9 +215,9 @@ struct InsertSliceOfInsertSliceFolder : public OpRewritePattern<OpTy> {
215215
sourceInsertSliceOp.getMixedSizes(),
216216
droppedDims, resolvedSizes);
217217

218-
// If we are inside a ParallelCombining region, temporarily set the insertion
219-
// point outside: only ops of ParallelCombiningOpInterface are allowed in
220-
// there.
218+
// If we are inside a ParallelCombining region, temporarily set the
219+
// insertion point outside: only ops of ParallelCombiningOpInterface are
220+
// allowed in there.
221221
if (isa<mlir::ParallelCombiningOpInterface>(insertSliceOp.getOperation())) {
222222
rewriter.setInsertionPoint(insertSliceOp->getParentOp());
223223
}

0 commit comments

Comments
 (0)