Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1474,6 +1474,8 @@ def Tensor_PadOp : Tensor_Op<"pad", [
def Tensor_ParallelInsertSliceOp : Tensor_Op<"parallel_insert_slice", [
AttrSizedOperandSegments,
OffsetSizeAndStrideOpInterface,
DeclareOpInterfaceMethods<InParallelOpInterface,
["getUpdatedDestinations", "getIteratingParent"]>,
// TODO: Cannot use an interface here at the moment; verify this manually for now.
// HasParent<"ParallelCombiningOpInterface">
]> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ namespace mlir {
namespace detail {
// TODO: Single region single block interface on interfaces ?
LogicalResult verifyParallelCombiningOpInterface(Operation *op);
LogicalResult verifyInParallelOpInterface(Operation *op);
} // namespace detail
} // namespace mlir

Expand Down
29 changes: 29 additions & 0 deletions mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.td
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,33 @@ def ParallelCombiningOpInterface : OpInterface<"ParallelCombiningOpInterface"> {
}];
}

def InParallelOpInterface : OpInterface<"InParallelOpInterface"> {
  let description = [{
    An in_parallel op is an operation that inserts into a shared tensor in
    conjunction with a parent combining and iterating op.
  }];
  let cppNamespace = "::mlir";

  let methods = [
    InterfaceMethod<[{
        Returns the list of values updated by this op.
      }],
      /*retTy=*/"::mlir::MutableOperandRange",
      /*methodName=*/"getUpdatedDestinations",
      /*args=*/(ins)
    >,
    InterfaceMethod<
      /*desc=*/[{
        Returns the iterating parent for this op.
      }],
      /*retTy=*/"::mlir::Operation*",
      /*methodName=*/"getIteratingParent",
      /*args=*/(ins)
    >,
  ];
  // Shared verifier: checks the op has an iterating parent and that every
  // updated destination is a block argument.
  let verify = [{
    return ::mlir::detail::verifyInParallelOpInterface($_op);
  }];
}

#endif // MLIR_INTERFACES_PARALLELCOMBININGOPINTERFACE
14 changes: 9 additions & 5 deletions mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/ParallelCombiningOpInterface.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
Expand Down Expand Up @@ -4140,11 +4141,14 @@ DiagnosedSilenceableFailure doit(RewriterBase &rewriter, OpTy target,
return DiagnosedSilenceableFailure::success();
}

// If we are inside an InParallel region, temporarily set the insertion point
// outside: only tensor.parallel_insert_slice ops are allowed in there.
if constexpr (std::is_same_v<OpTy, tensor::ParallelInsertSliceOp>) {
rewriter.setInsertionPoint(
target->template getParentOfType<scf::InParallelOp>());
// If we are inside an ParallelCombiningOp region, temporarily set the
// insertion point outside: only ops implementing InParallelOpInterface are
// allowed in there.
if (isa<mlir::InParallelOpInterface>(target.getOperation())) {
if (auto combiningParent =
dyn_cast<ParallelCombiningOpInterface>(target->getParentOp())) {
rewriter.setInsertionPoint(target->getParentOp());
}
}

Value extracted = tensor::ExtractSliceOp::create(
Expand Down
8 changes: 6 additions & 2 deletions mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -784,8 +784,12 @@ struct RankReducedInsertSliceOp : public OpRewritePattern<InsertOpTy> {
// The only difference between InsertSliceOp and ParallelInsertSliceOp
// is the insertion point is just before the ParallelCombiningOp in the
// parallel case.
if (std::is_same<InsertOpTy, tensor::ParallelInsertSliceOp>::value)
rewriter.setInsertionPoint(insertSliceOp->getParentOp());
if (std::is_same<InsertOpTy, tensor::ParallelInsertSliceOp>::value) {
if (auto combiningParent = dyn_cast<ParallelCombiningOpInterface>(
insertSliceOp->getParentOp())) {
rewriter.setInsertionPoint(insertSliceOp->getParentOp());
}
}
reshapedSource = tensor::CollapseShapeOp::create(
rewriter, loc, insertSliceOp.getSource(), *reassociation);
}
Expand Down
55 changes: 26 additions & 29 deletions mlir/lib/Dialect/SCF/IR/SCF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -680,8 +680,11 @@ void mlir::scf::promote(RewriterBase &rewriter, scf::ForallOp forallOp) {
SmallVector<Value> results;
results.reserve(forallOp.getResults().size());
for (auto &yieldingOp : terminator.getYieldingOps()) {
// Skip non-ParallelInsertSliceOp operations
auto parallelInsertSliceOp =
cast<tensor::ParallelInsertSliceOp>(yieldingOp);
dyn_cast<tensor::ParallelInsertSliceOp>(yieldingOp);
if (!parallelInsertSliceOp)
continue;

Value dst = parallelInsertSliceOp.getDest();
Value src = parallelInsertSliceOp.getSource();
Expand Down Expand Up @@ -1437,14 +1440,12 @@ InParallelOp ForallOp::getTerminator() {
return cast<InParallelOp>(getBody()->getTerminator());
}


SmallVector<Operation *> ForallOp::getCombiningOps(BlockArgument bbArg) {
SmallVector<Operation *> storeOps;
InParallelOp inParallelOp = getTerminator();
for (Operation &yieldOp : inParallelOp.getYieldingOps()) {
if (auto parallelInsertSliceOp =
dyn_cast<tensor::ParallelInsertSliceOp>(yieldOp);
parallelInsertSliceOp && parallelInsertSliceOp.getDest() == bbArg) {
storeOps.push_back(parallelInsertSliceOp);
for (Operation *user : bbArg.getUsers()) {
if (auto parallelOp = dyn_cast<InParallelOpInterface>(user)) {
storeOps.push_back(parallelOp);
}
}
return storeOps;
Expand Down Expand Up @@ -1673,7 +1674,12 @@ struct ForallOpIterArgsFolder : public OpRewritePattern<ForallOp> {
for (OpResult result : forallOp.getResults()) {
OpOperand *opOperand = forallOp.getTiedOpOperand(result);
BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
if (result.use_empty() || forallOp.getCombiningOps(blockArg).empty()) {
SmallVector<Operation *> combiningOps =
forallOp.getCombiningOps(blockArg);
if ((result.use_empty() &&
llvm::all_of(combiningOps,
[](Operation *op) { return op->use_empty(); })) ||
combiningOps.empty()) {
resultToDelete.insert(result);
} else {
resultToReplace.push_back(result);
Expand Down Expand Up @@ -1911,8 +1917,9 @@ struct FoldTensorCastOfOutputIntoForallOp
auto terminator = newForallOp.getTerminator();
for (auto [yieldingOp, outputBlockArg] : llvm::zip(
terminator.getYieldingOps(), newForallOp.getRegionIterArgs())) {
auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(yieldingOp);
insertSliceOp.getDestMutable().assign(outputBlockArg);
auto insertSliceOp = dyn_cast<tensor::ParallelInsertSliceOp>(yieldingOp);
if (insertSliceOp)
insertSliceOp.getDestMutable().assign(outputBlockArg);
}

// Cast results back to the original types.
Expand Down Expand Up @@ -1971,19 +1978,6 @@ LogicalResult InParallelOp::verify() {
if (!forallOp)
return this->emitOpError("expected forall op parent");

// TODO: InParallelOpInterface.
for (Operation &op : getRegion().front().getOperations()) {
if (!isa<tensor::ParallelInsertSliceOp>(op)) {
return this->emitOpError("expected only ")
<< tensor::ParallelInsertSliceOp::getOperationName() << " ops";
}

// Verify that inserts are into out block arguments.
Value dest = cast<tensor::ParallelInsertSliceOp>(op).getDest();
ArrayRef<BlockArgument> regionOutArgs = forallOp.getRegionOutArgs();
if (!llvm::is_contained(regionOutArgs, dest))
return op.emitOpError("may only insert into an output block argument");
}
return success();
}

Expand Down Expand Up @@ -2018,12 +2012,15 @@ OpResult InParallelOp::getParentResult(int64_t idx) {
}

SmallVector<BlockArgument> InParallelOp::getDests() {
return llvm::to_vector<4>(
llvm::map_range(getYieldingOps(), [](Operation &op) {
// Add new ops here as needed.
auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(&op);
return llvm::cast<BlockArgument>(insertSliceOp.getDest());
}));
SmallVector<BlockArgument> updatedDests;
for (auto &yieldingOp : getYieldingOps()) {
auto inParallelOp = dyn_cast<InParallelOpInterface>(&yieldingOp);
if (!inParallelOp)
continue;
for (auto &updatedOperand : inParallelOp.getUpdatedDestinations())
updatedDests.push_back(cast<BlockArgument>(updatedOperand.get()));
}
return updatedDests;
}

llvm::iterator_range<Block::iterator> InParallelOp::getYieldingOps() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ namespace {
/// <implicit in_parallel terminator here>
/// }
/// ```
struct InParallelOpInterface
: public BufferDeallocationOpInterface::ExternalModel<InParallelOpInterface,
struct InParallelDeallocOpInterface
: public BufferDeallocationOpInterface::ExternalModel<InParallelDeallocOpInterface,
scf::InParallelOp> {
FailureOr<Operation *> process(Operation *op, DeallocationState &state,
const DeallocationOptions &options) const {
Expand Down Expand Up @@ -75,7 +75,7 @@ struct ReduceReturnOpInterface
void mlir::scf::registerBufferDeallocationOpInterfaceExternalModels(
DialectRegistry &registry) {
registry.addExtension(+[](MLIRContext *ctx, SCFDialect *dialect) {
InParallelOp::attachInterface<InParallelOpInterface>(*ctx);
InParallelOp::attachInterface<InParallelDeallocOpInterface>(*ctx);
ReduceReturnOp::attachInterface<ReduceReturnOpInterface>(*ctx);
});
}
22 changes: 16 additions & 6 deletions mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2978,7 +2978,7 @@ class InsertSliceOpConstantArgumentFolder final
// The only difference between InsertSliceOp and ParallelInsertSliceOp
// is that the insertion point is just before the ParallelCombiningOp in
// the parallel case.
if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
if (isa<ParallelCombiningOpInterface>(insertSliceOp->getParentOp()))
rewriter.setInsertionPoint(insertSliceOp->getParentOp());
toInsert = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(),
sourceType, toInsert);
Expand Down Expand Up @@ -3155,7 +3155,7 @@ struct InsertSliceOpSourceCastInserter final
// The only difference between InsertSliceOp and ParallelInsertSliceOp is
// that the insertion point is just before the ParallelCombiningOp in the
// parallel case.
if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
if (isa<ParallelCombiningOpInterface>(insertSliceOp->getParentOp()))
rewriter.setInsertionPoint(insertSliceOp->getParentOp());
Value cast = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(),
newSrcType, insertSliceOp.getSource());
Expand Down Expand Up @@ -3901,10 +3901,6 @@ void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
}

LogicalResult ParallelInsertSliceOp::verify() {
if (!isa<ParallelCombiningOpInterface>(getOperation()->getParentOp()))
return this->emitError("expected ParallelCombiningOpInterface parent, got:")
<< *(getOperation()->getParentOp());

// Verify result type against inferred type.
RankedTensorType expectedType;
SliceVerificationResult result =
Expand Down Expand Up @@ -3935,6 +3931,20 @@ llvm::SmallBitVector ParallelInsertSliceOp::getDroppedDims() {
return ::getDroppedDims(getSourceType().getShape(), getMixedSizes());
}

// InParallelOpInterface implementation
/// Returns the destinations this op updates: parallel_insert_slice writes
/// into exactly one shared tensor, its `dest` operand.
MutableOperandRange ParallelInsertSliceOp::getUpdatedDestinations() {
  return getDestMutable();
}

/// Returns the iterating parent for this op, i.e. the op enclosing the
/// ParallelCombiningOpInterface region this op lives in (e.g. the loop that
/// owns the combining terminator), or nullptr if there is none.
Operation *ParallelInsertSliceOp::getIteratingParent() {
  // Guard against a detached op: dyn_cast on a null parent would assert.
  Operation *parent = getOperation()->getParentOp();
  if (!parent)
    return nullptr;
  if (auto combiningOp = dyn_cast<ParallelCombiningOpInterface>(parent))
    return combiningOp->getParentOp();
  return nullptr;
}

//===----------------------------------------------------------------------===//
// ScatterOp
//===----------------------------------------------------------------------===//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -974,7 +974,9 @@ struct ParallelInsertSliceOpInterface
parallelInsertSliceOp.getParallelCombiningParent();

// Bufferize the op outside of the parallel combining terminator.
rewriter.setInsertionPoint(parallelCombiningParent);
if (parallelCombiningParent) {
rewriter.setInsertionPoint(parallelCombiningParent);
}

// Get source and destination buffers.
FailureOr<Value> destBuffer =
Expand Down
6 changes: 4 additions & 2 deletions mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -219,8 +219,10 @@ struct InsertSliceOfInsertSliceFolder : public OpRewritePattern<OpTy> {
// point outside: only tensor.parallel_insert_slice ops are allowed in
// there.
if (std::is_same_v<OpTy, tensor::ParallelInsertSliceOp>) {
rewriter.setInsertionPoint(
insertSliceOp->template getParentOfType<scf::InParallelOp>());
if (auto combiningParent = dyn_cast<ParallelCombiningOpInterface>(
insertSliceOp->getParentOp())) {
rewriter.setInsertionPoint(insertSliceOp->getParentOp());
}
}

// Resolve offsets according to source offsets and strides.
Expand Down
37 changes: 33 additions & 4 deletions mlir/lib/Interfaces/ParallelCombiningOpInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,47 @@

using namespace mlir;

/// Include the definitions of the interface.
#include "mlir/Interfaces/ParallelCombiningOpInterface.cpp.inc"

//===----------------------------------------------------------------------===//
// ParallelCombiningOpInterface
// InParallelOpInterface
//===----------------------------------------------------------------------===//

// TODO: Catch-22: interface methods are themselves used during verification,
// so they cannot assume the interface implementation is already valid.
LogicalResult mlir::detail::verifyInParallelOpInterface(Operation *op) {
  auto itf = cast<InParallelOpInterface>(op);
  // An in_parallel op is meaningless without an enclosing iterating op.
  if (!itf.getIteratingParent())
    return op->emitError(
        "in_parallel interface op must have an iterating parent");

  // Each updated destination must be a block argument, i.e. one of the shared
  // outputs threaded through the iterating parent's region.
  for (OpOperand &dest : itf.getUpdatedDestinations())
    if (!isa<BlockArgument>(dest.get()))
      return op->emitError("updating a non block argument");
  return success();
}


//===----------------------------------------------------------------------===//
// ParallelCombiningOpInterface
//===----------------------------------------------------------------------===//
// TODO: Single region single block interface on interfaces ?
LogicalResult mlir::detail::verifyParallelCombiningOpInterface(Operation *op) {
  // A combining op holds exactly one region with exactly one block.
  if (op->getNumRegions() != 1)
    return op->emitError("expected single region op");
  Region &region = op->getRegion(0);
  if (!region.hasOneBlock())
    return op->emitError("expected single block op region");
  // That block may only contain ops implementing InParallelOpInterface.
  for (Operation &nested : region.front())
    if (!isa<InParallelOpInterface>(nested))
      return op->emitError("expected only in_parallel interface ops");
  return success();
}

/// Include the definitions of the interface.
#include "mlir/Interfaces/ParallelCombiningOpInterface.cpp.inc"
26 changes: 25 additions & 1 deletion mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -915,14 +915,15 @@ func.func @sparse_case(%arg0: tensor<8x8xf32, #CSR>, %arg1: tensor<8xf32>) -> te

// -----

func.func @reduce_dispatch_0() -> tensor<4x2xf32> {
func.func @parallel_insert_slice() -> tensor<4x2xf32> {
%c2 = arith.constant 2 : index
%c4 = arith.constant 4 : index
%cst = arith.constant 0.000000e+00 : f32
%0 = tensor.empty() : tensor<4x2xf32>
%res = scf.forall (%arg0, %arg1) in (%c4, %c2) shared_outs(%o = %0) -> (tensor<4x2xf32>) {
%1 = tensor.empty() : tensor<1x1xf32>
%2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<1x1xf32>) -> tensor<1x1xf32>
// CHECK: scf.forall.in_parallel
scf.forall.in_parallel {
// CHECK: tensor.parallel_insert_slice %{{[0-9a-z]*}} into %{{[0-9a-z]*}}
// CHECK-SAME: [%{{.*}}, %{{.*}}] [1, 1] [1, 1] : tensor<f32> into tensor<4x2xf32>
Expand All @@ -935,6 +936,29 @@ func.func @reduce_dispatch_0() -> tensor<4x2xf32> {

// -----

// CHECK-LABEL: func @parallel_insert_slice_no_terminator
func.func @parallel_insert_slice_no_terminator() -> tensor<4x2xf32> {
%c2 = arith.constant 2 : index
%c4 = arith.constant 4 : index
%cst = arith.constant 0.000000e+00 : f32
%0 = tensor.empty() : tensor<4x2xf32>
// CHECK: scf.forall
%res = scf.forall (%arg0, %arg1) in (%c4, %c2) shared_outs(%o = %0) -> (tensor<4x2xf32>) {
%1 = tensor.empty() : tensor<1x1xf32>
%2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<1x1xf32>) -> tensor<1x1xf32>
// CHECK: scf.forall.in_parallel
// CHECK: tensor.parallel_insert_slice %{{[0-9a-z]*}} into %{{[0-9a-z]*}}
// CHECK-SAME: [%{{.*}}, %{{.*}}] [1, 1] [1, 1] : tensor<f32> into tensor<4x2xf32>
scf.forall.in_parallel {
tensor.parallel_insert_slice %2 into %o[%arg0, %arg1] [1, 1] [1, 1] :
tensor<1x1xf32> into tensor<4x2xf32>
}
}
return %res: tensor<4x2xf32>
}

// -----

#map0 = affine_map<(i, j) -> (i, j)>
#access = [#map0, #map0]
#trait = {
Expand Down
Loading
Loading