Skip to content

Commit 551f35e

Browse files
committed
Update. Related test: test/Dialect/SCF/transform-ops.mlir
1 parent a9073c0 commit 551f35e

File tree

5 files changed

+17
-15
lines changed

5 files changed

+17
-15
lines changed

mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1475,7 +1475,7 @@ def Tensor_ParallelInsertSliceOp : Tensor_Op<"parallel_insert_slice", [
14751475
OffsetSizeAndStrideOpInterface,
14761476
DeclareOpInterfaceMethods<ParallelCombiningOpInterface,
14771477
["getUpdatedDestinations", "getIteratingParent",
1478-
"promoteInParallelLoop"]>,
1478+
"promoteInParallelLoop", "canPromoteInParallelLoop"]>,
14791479
// TODO: Cannot use an interface here atm, verify this manually for now.
14801480
// HasParent<"InParallelOpInterface">
14811481
]> {

mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
#include "mlir/IR/OpDefinition.h"
1818
#include "mlir/IR/PatternMatch.h"
1919
#include "mlir/Support/LogicalResult.h"
20-
#include "llvm/ADT/SmallVector.h"
2120

2221
namespace mlir {
2322
namespace detail {

mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.td

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -109,10 +109,10 @@ def ParallelCombiningOpInterface : OpInterface<"ParallelCombiningOpInterface"> {
109109
InterfaceMethod<
110110
/*desc=*/[{
111111
Promotes this parallel combining op out of its enclosing parallel loop
112-
and returns the values that should replace the destinations updated by
112+
and returns the value that should replace the destination updated by
113113
this op.
114114
}],
115-
/*retTy=*/"::mlir::FailureOr<::llvm::SmallVector<::mlir::Value>>",
115+
/*retTy=*/"::mlir::FailureOr<::mlir::Value>",
116116
/*methodName=*/"promoteInParallelLoop",
117117
/*args=*/(ins "::mlir::RewriterBase &":$rewriter)
118118
>,
@@ -124,7 +124,8 @@ def ParallelCombiningOpInterface : OpInterface<"ParallelCombiningOpInterface"> {
124124
/*retTy=*/"bool",
125125
/*methodName=*/"canPromoteInParallelLoop",
126126
/*args=*/(ins "::mlir::RewriterBase &":$rewriter),
127-
/*methodBody=*/[{ return true; }]
127+
/*methodBody=*/"",
128+
/*defaultImplementation=*/[{ return false; }]
128129
>,
129130
];
130131
}

mlir/lib/Dialect/SCF/IR/SCF.cpp

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -672,14 +672,12 @@ LogicalResult mlir::scf::promote(RewriterBase &rewriter, scf::ForallOp forallOp)
672672
auto parallelCombiningOp =
673673
dyn_cast<ParallelCombiningOpInterface>(&yieldingOp);
674674
if (!parallelCombiningOp)
675-
return rewriter.notifyMatchFailure(
676-
forallOp, "terminator has non-parallel-combining op");
675+
continue;
677676
if (!parallelCombiningOp.canPromoteInParallelLoop(rewriter))
678677
return rewriter.notifyMatchFailure(
679678
forallOp, "parallel combining op cannot be promoted");
680679
}
681680

682-
683681
// Replace block arguments with lower bounds (replacements for IVs) and
684682
// outputs.
685683
SmallVector<Value> bbArgReplacements = forallOp.getLowerBound(rewriter);
@@ -702,12 +700,12 @@ LogicalResult mlir::scf::promote(RewriterBase &rewriter, scf::ForallOp forallOp)
702700

703701
assert(parallelCombiningOp.canPromoteInParallelLoop(rewriter));
704702

705-
FailureOr<SmallVector<Value>> promotedValues =
703+
FailureOr<Value> promotedValue =
706704
parallelCombiningOp.promoteInParallelLoop(rewriter);
707-
if (failed(promotedValues))
705+
if (failed(promotedValue))
708706
return failure();
709707

710-
results.append(promotedValues->begin(), promotedValues->end());
708+
results.push_back(*promotedValue);
711709
}
712710
if (results.size() != forallOp.getResults().size())
713711
return rewriter.notifyMatchFailure(

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3947,19 +3947,23 @@ Operation *ParallelInsertSliceOp::getIteratingParent() {
39473947
return nullptr;
39483948
}
39493949

3950-
FailureOr<SmallVector<Value>>
3950+
FailureOr<Value>
39513951
ParallelInsertSliceOp::promoteInParallelLoop(RewriterBase &rewriter) {
39523952
Value dst = getDest();
39533953
Value src = getSource();
39543954
if (!isa<TensorType>(src.getType()))
3955-
return rewriter.notifyMatchFailure(getOperation(),
3956-
"expected tensor source");
3955+
return failure();
39573956

39583957
Value inserted = tensor::InsertSliceOp::create(
39593958
rewriter, getLoc(), dst.getType(), src, dst, getOffsets(), getSizes(),
39603959
getStrides(), getStaticOffsets(), getStaticSizes(), getStaticStrides());
39613960

3962-
return SmallVector<Value>{inserted};
3961+
return inserted;
3962+
}
3963+
3964+
bool ParallelInsertSliceOp::canPromoteInParallelLoop(RewriterBase &) {
3965+
return isa<TensorType>(getSource().getType()) &&
3966+
isa<TensorType>(getDest().getType());
39633967
}
39643968

39653969
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)