4 changes: 1 addition & 3 deletions mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -654,7 +654,7 @@ def ForallOp : SCF_Op<"forall", [
def InParallelOp : SCF_Op<"forall.in_parallel", [
Pure,
Terminator,
DeclareOpInterfaceMethods<ParallelCombiningOpInterface>,
DeclareOpInterfaceMethods<InParallelOpInterface>,
HasParent<"ForallOp">,
] # GraphRegionNoTerminator.traits> {
let summary = "terminates a `forall` block";
@@ -679,8 +679,6 @@ def InParallelOp : SCF_Op<"forall.in_parallel", [
OpBuilder<(ins)>,
];

// TODO: Add a `InParallelOpInterface` interface for ops that can
// appear inside in_parallel.
let extraClassDeclaration = [{
::llvm::SmallVector<::mlir::BlockArgument> getDests();
::llvm::iterator_range<::mlir::Block::iterator> getYieldingOps();
15 changes: 8 additions & 7 deletions mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1470,24 +1470,25 @@ def Tensor_PadOp : Tensor_Op<"pad", [
// ParallelInsertSliceOp
//===----------------------------------------------------------------------===//

// TODO: Implement InParallelOpInterface.
def Tensor_ParallelInsertSliceOp : Tensor_Op<"parallel_insert_slice", [
AttrSizedOperandSegments,
OffsetSizeAndStrideOpInterface,
DeclareOpInterfaceMethods<ParallelCombiningOpInterface,
["getUpdatedDestinations", "getIteratingParent"]>,
// TODO: Cannot use an interface here atm, verify this manually for now.
// HasParent<"ParallelCombiningOpInterface">
// HasParent<"InParallelOpInterface">
]> {
let summary = [{
Specify the tensor slice update of a single thread of a parent
ParallelCombiningOpInterface op.
InParallelOpInterface op.
}];
let description = [{
The `parallel_insert_slice` yields a subset tensor value to its parent
ParallelCombiningOpInterface. These subset tensor values are aggregated to
InParallelOpInterface. These subset tensor values are aggregated to
in some unspecified order into a full tensor value returned by the parent
parallel iterating op.
The `parallel_insert_slice` is one such op allowed in the
ParallelCombiningOpInterface op.
InParallelOpInterface op.

Conflicting writes result in undefined semantics, in that the indices written
to by multiple parallel updates might contain data from any of the updates,
@@ -1569,8 +1570,8 @@ def Tensor_ParallelInsertSliceOp : Tensor_Op<"parallel_insert_slice", [
return ::llvm::cast<RankedTensorType>(getDest().getType());
}

ParallelCombiningOpInterface getParallelCombiningParent() {
return dyn_cast<ParallelCombiningOpInterface>(
InParallelOpInterface getParallelCombiningParent() {
return dyn_cast<InParallelOpInterface>(
getOperation()->getParentOp());
}

mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.h
@@ -19,7 +19,7 @@
namespace mlir {
namespace detail {
// TODO: Single region single block interface on interfaces ?
LogicalResult verifyParallelCombiningOpInterface(Operation *op);
LogicalResult verifyInParallelOpInterface(Operation *op);
} // namespace detail
} // namespace mlir

61 changes: 57 additions & 4 deletions mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.td
@@ -6,7 +6,8 @@
//
//===----------------------------------------------------------------------===//
//
// Defines the interface for ops that perform parallel combining operations.
// Defines the interface for ops that perform in parallel combining
// operations.
//
//===----------------------------------------------------------------------===//

@@ -15,9 +16,9 @@

include "mlir/IR/OpBase.td"

def ParallelCombiningOpInterface : OpInterface<"ParallelCombiningOpInterface"> {
def InParallelOpInterface : OpInterface<"InParallelOpInterface"> {
let description = [{
A parallel combining op is an op with a region.
An in parallel op is an op with a region.

This is useful as a terminator to parallel operations that iterate over
some set and return tensors while avoiding tight coupling between the
@@ -52,8 +53,60 @@ def ParallelCombiningOpInterface : OpInterface<"ParallelCombiningOpInterface"> {
];
// TODO: Single region single block interface on interfaces ?
let verify = [{
return verifyParallelCombiningOpInterface($_op);
return verifyInParallelOpInterface($_op);
}];
}

def ParallelCombiningOpInterface : OpInterface<"ParallelCombiningOpInterface"> {
let description = [{
A parallel combining op is an operation that models parallel contributions
to result tensors within the context of a parent iterating operation.

This interface is designed for operations that need to coordinate parallel
insertions or contributions to tensors that are being constructed across
multiple parallel iterations. The destination refers to a tensor value that
is assembled by aggregating results from parallel computations; each
parallel iteration may contribute a slice, element, or region to the final
result. No in-place mutation of tensors is implied.

One significant use case for this interface is `tensor.parallel_insert_slice`
which allows parallel insertion of slices that are aggregated into a
destination tensor. With this interface, other operations that express
similar parallel contributions can also be defined.

This op works within an op implementing the `InParallelOpInterface` that
specifies how the parallel results are combined.

Key semantics:
- The operation identifies destination tensors to which iterations
contribute through the `getUpdatedDestinations` method
- Each parallel iteration may produce elements or regions that are
incorporated into the destination tensor
- The parent iterating operation manages the coordination and ensures
proper synchronization of these contributions

Note: This interface does not verify itself, it is up to the implementing operation
to verify the correctness of the op.
}];
let cppNamespace = "::mlir";

let methods = [
InterfaceMethod<[{
Returns the list of destination values this op contributes to.
}],
/*retTy=*/"::mlir::MutableOperandRange",
/*methodName=*/"getUpdatedDestinations",
/*args=*/(ins)
>,
InterfaceMethod<
/*desc=*/[{
Returns the iterating parent for this op.
}],
/*retTy=*/"::mlir::Operation*",
/*methodName=*/"getIteratingParent",
/*args=*/(ins)
>,
];
}

#endif // MLIR_INTERFACES_PARALLELCOMBININGOPINTERFACE
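To make the new split concrete: `InParallelOpInterface` marks the aggregating terminator (e.g. `scf.forall.in_parallel`), while `ParallelCombiningOpInterface` marks the ops inside it that contribute to destinations (e.g. `tensor.parallel_insert_slice`). Below is a hedged C++ sketch of what a downstream dialect might add for its own combining op, modeled on the `ParallelInsertSliceOp` changes in TensorOps.cpp later in this diff; the op name `ParallelScatterOp` and its `getDestMutable()` accessor are hypothetical, and the ODS declaration of the interface methods is assumed to already exist.

```cpp
#include "mlir/IR/Operation.h"
#include "mlir/Interfaces/ParallelCombiningOpInterface.h"

using namespace mlir;

// Hypothetical downstream op, assumed to be declared in ODS with
//   DeclareOpInterfaceMethods<ParallelCombiningOpInterface,
//       ["getUpdatedDestinations", "getIteratingParent"]>
// exactly like tensor.parallel_insert_slice in this patch.

// The destinations this op contributes to: for a single-destination op,
// simply the mutable range wrapping its `dest` operand.
MutableOperandRange ParallelScatterOp::getUpdatedDestinations() {
  return getDestMutable();
}

// The iterating parent: the op enclosing the InParallelOpInterface
// terminator this op is nested in (e.g. the scf.forall around
// scf.forall.in_parallel), or nullptr if there is none.
Operation *ParallelScatterOp::getIteratingParent() {
  if (auto inParallel =
          dyn_cast<InParallelOpInterface>(getOperation()->getParentOp()))
    return inParallel->getParentOp();
  return nullptr;
}
```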
12 changes: 6 additions & 6 deletions mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -36,6 +36,7 @@
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/ParallelCombiningOpInterface.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -4140,12 +4141,11 @@ DiagnosedSilenceableFailure doit(RewriterBase &rewriter, OpTy target,
return DiagnosedSilenceableFailure::success();
}

// If we are inside an InParallel region, temporarily set the insertion point
// outside: only tensor.parallel_insert_slice ops are allowed in there.
if constexpr (std::is_same_v<OpTy, tensor::ParallelInsertSliceOp>) {
rewriter.setInsertionPoint(
target->template getParentOfType<scf::InParallelOp>());
}
// If we are inside a `ParallelCombiningOp` region, temporarily set the
// insertion point outside: only ops implementing ParallelCombiningOpInterface
// are allowed in there.
if (isa<mlir::ParallelCombiningOpInterface>(target.getOperation()))
rewriter.setInsertionPoint(target->getParentOp());

Value extracted = tensor::ExtractSliceOp::create(
rewriter, target.getLoc(), target.getDest(), target.getMixedOffsets(),
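The change above replaces the `scf::InParallelOp`-specific insertion-point fix-up with an interface check. A hedged, self-contained sketch of that pattern as it might appear in downstream rewrites (the helper name is made up):

```cpp
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/ParallelCombiningOpInterface.h"

using namespace mlir;

// Combining regions (e.g. the body of scf.forall.in_parallel) may only
// contain ParallelCombiningOpInterface ops, so any replacement IR for a
// combining op has to be created outside that region; otherwise insert
// right before the target as usual.
static void setInsertionPointForReplacement(RewriterBase &rewriter,
                                            Operation *target) {
  if (isa<ParallelCombiningOpInterface>(target))
    rewriter.setInsertionPoint(target->getParentOp());
  else
    rewriter.setInsertionPoint(target);
}
```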
52 changes: 30 additions & 22 deletions mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -21,6 +21,7 @@
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Interfaces/ParallelCombiningOpInterface.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/MapVector.h"
@@ -681,7 +682,9 @@ void mlir::scf::promote(RewriterBase &rewriter, scf::ForallOp forallOp) {
results.reserve(forallOp.getResults().size());
for (auto &yieldingOp : terminator.getYieldingOps()) {
auto parallelInsertSliceOp =
cast<tensor::ParallelInsertSliceOp>(yieldingOp);
dyn_cast<tensor::ParallelInsertSliceOp>(yieldingOp);
if (!parallelInsertSliceOp)
continue;

Value dst = parallelInsertSliceOp.getDest();
Value src = parallelInsertSliceOp.getSource();
@@ -1439,12 +1442,9 @@ InParallelOp ForallOp::getTerminator() {

SmallVector<Operation *> ForallOp::getCombiningOps(BlockArgument bbArg) {
SmallVector<Operation *> storeOps;
InParallelOp inParallelOp = getTerminator();
for (Operation &yieldOp : inParallelOp.getYieldingOps()) {
if (auto parallelInsertSliceOp =
dyn_cast<tensor::ParallelInsertSliceOp>(yieldOp);
parallelInsertSliceOp && parallelInsertSliceOp.getDest() == bbArg) {
storeOps.push_back(parallelInsertSliceOp);
for (Operation *user : bbArg.getUsers()) {
if (auto parallelOp = dyn_cast<ParallelCombiningOpInterface>(user)) {
storeOps.push_back(parallelOp);
}
}
return storeOps;
@@ -1911,8 +1911,10 @@ struct FoldTensorCastOfOutputIntoForallOp
auto terminator = newForallOp.getTerminator();
for (auto [yieldingOp, outputBlockArg] : llvm::zip(
terminator.getYieldingOps(), newForallOp.getRegionIterArgs())) {
auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(yieldingOp);
insertSliceOp.getDestMutable().assign(outputBlockArg);
if (auto inParallelOp =
dyn_cast<ParallelCombiningOpInterface>(yieldingOp)) {
inParallelOp.getUpdatedDestinations().assign(outputBlockArg);
}
}

// Cast results back to the original types.
Expand Down Expand Up @@ -1971,19 +1973,22 @@ LogicalResult InParallelOp::verify() {
if (!forallOp)
return this->emitOpError("expected forall op parent");

// TODO: InParallelOpInterface.
for (Operation &op : getRegion().front().getOperations()) {
if (!isa<tensor::ParallelInsertSliceOp>(op)) {
return this->emitOpError("expected only ")
<< tensor::ParallelInsertSliceOp::getOperationName() << " ops";
auto inParallelOp = dyn_cast<ParallelCombiningOpInterface>(&op);
if (!inParallelOp) {
return this->emitOpError("expected only ParallelCombiningOpInterface")
<< " ops";
}

// Verify that inserts are into out block arguments.
Value dest = cast<tensor::ParallelInsertSliceOp>(op).getDest();
MutableOperandRange dests = inParallelOp.getUpdatedDestinations();
ArrayRef<BlockArgument> regionOutArgs = forallOp.getRegionOutArgs();
if (!llvm::is_contained(regionOutArgs, dest))
return op.emitOpError("may only insert into an output block argument");
for (OpOperand &dest : dests) {
if (!llvm::is_contained(regionOutArgs, dest.get()))
return op.emitOpError("may only insert into an output block argument");
}
}

return success();
}

@@ -2018,12 +2023,15 @@ OpResult InParallelOp::getParentResult(int64_t idx) {
}

SmallVector<BlockArgument> InParallelOp::getDests() {
return llvm::to_vector<4>(
llvm::map_range(getYieldingOps(), [](Operation &op) {
// Add new ops here as needed.
auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(&op);
return llvm::cast<BlockArgument>(insertSliceOp.getDest());
}));
SmallVector<BlockArgument> updatedDests;
for (Operation &yieldingOp : getYieldingOps()) {
auto inParallelOp = dyn_cast<ParallelCombiningOpInterface>(&yieldingOp);
if (!inParallelOp)
continue;
for (OpOperand &updatedOperand : inParallelOp.getUpdatedDestinations())
updatedDests.push_back(cast<BlockArgument>(updatedOperand.get()));
}
return updatedDests;
}

llvm::iterator_range<Block::iterator> InParallelOp::getYieldingOps() {
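The SCF.cpp changes above drop the hard-coded `tensor::ParallelInsertSliceOp` casts in favor of interface queries. A hedged sketch of that pattern as a standalone helper, close to what the reworked `getCombiningOps`/`getDests` now compute (the helper name is made up):

```cpp
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Interfaces/ParallelCombiningOpInterface.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;

// Returns every op in the forall.in_parallel terminator that declares
// `outArg` (a shared_outs block argument of `forallOp`) as one of its
// updated destinations, using only ParallelCombiningOpInterface.
static SmallVector<Operation *> collectCombiningOps(scf::ForallOp forallOp,
                                                    BlockArgument outArg) {
  SmallVector<Operation *> combiningOps;
  scf::InParallelOp terminator = forallOp.getTerminator();
  for (Operation &op : terminator.getYieldingOps()) {
    auto combining = dyn_cast<ParallelCombiningOpInterface>(&op);
    if (!combining)
      continue;
    // A combining op may update several destinations; keep it if any of
    // them is the requested block argument.
    for (OpOperand &dest : combining.getUpdatedDestinations()) {
      if (dest.get() == outArg) {
        combiningOps.push_back(&op);
        break;
      }
    }
  }
  return combiningOps;
}
```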
mlir/lib/Dialect/SCF/Transforms/BufferDeallocationOpInterfaceImpl.cpp
@@ -16,7 +16,7 @@ using namespace mlir::bufferization;
namespace {
/// The `scf.forall.in_parallel` terminator is special in a few ways:
/// * It does not implement the BranchOpInterface or
/// RegionBranchTerminatorOpInterface, but the ParallelCombiningOpInterface
/// RegionBranchTerminatorOpInterface, but the InParallelOpInterface
/// which is not supported by BufferDeallocation.
/// * It has a graph-like region which only allows one specific tensor op
/// * After bufferization the nested region is always empty
@@ -40,9 +40,9 @@ namespace {
/// <implicit in_parallel terminator here>
/// }
/// ```
struct InParallelOpInterface
: public BufferDeallocationOpInterface::ExternalModel<InParallelOpInterface,
scf::InParallelOp> {
struct InParallelDeallocOpInterface
: public BufferDeallocationOpInterface::ExternalModel<
InParallelDeallocOpInterface, scf::InParallelOp> {
FailureOr<Operation *> process(Operation *op, DeallocationState &state,
const DeallocationOptions &options) const {
auto inParallelOp = cast<scf::InParallelOp>(op);
@@ -75,7 +75,7 @@ struct ReduceReturnOpInterface
void mlir::scf::registerBufferDeallocationOpInterfaceExternalModels(
DialectRegistry &registry) {
registry.addExtension(+[](MLIRContext *ctx, SCFDialect *dialect) {
InParallelOp::attachInterface<InParallelOpInterface>(*ctx);
InParallelOp::attachInterface<InParallelDeallocOpInterface>(*ctx);
ReduceReturnOp::attachInterface<ReduceReturnOpInterface>(*ctx);
});
}
29 changes: 21 additions & 8 deletions mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2976,9 +2976,9 @@ class InsertSliceOpConstantArgumentFolder final
if (sourceType != insertSliceOp.getSourceType()) {
OpBuilder::InsertionGuard g(rewriter);
// The only difference between InsertSliceOp and ParallelInsertSliceOp
// is that the insertion point is just before the ParallelCombiningOp in
// is that the insertion point is just before the InParallelOp in
// the parallel case.
if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
if (isa<InParallelOpInterface>(insertSliceOp->getParentOp()))
rewriter.setInsertionPoint(insertSliceOp->getParentOp());
toInsert = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(),
sourceType, toInsert);
@@ -3153,9 +3153,9 @@ struct InsertSliceOpSourceCastInserter final
// Insert the cast.
OpBuilder::InsertionGuard g(rewriter);
// The only difference between InsertSliceOp and ParallelInsertSliceOp is
// that the insertion point is just before the ParallelCombiningOp in the
// that the insertion point is just before the InParallelOp in the
// parallel case.
if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
if (isa<ParallelCombiningOpInterface>(insertSliceOp->getParentOp()))
rewriter.setInsertionPoint(insertSliceOp->getParentOp());
Value cast = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(),
newSrcType, insertSliceOp.getSource());
@@ -3846,8 +3846,7 @@ OpFoldResult PadOp::fold(FoldAdaptor) {
//===----------------------------------------------------------------------===//

OpResult ParallelInsertSliceOp::getTiedOpResult() {
ParallelCombiningOpInterface parallelCombiningParent =
getParallelCombiningParent();
InParallelOpInterface parallelCombiningParent = getParallelCombiningParent();
for (const auto &it :
llvm::enumerate(parallelCombiningParent.getYieldingOps())) {
Operation &nextOp = it.value();
@@ -3901,8 +3900,8 @@ void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
}

LogicalResult ParallelInsertSliceOp::verify() {
if (!isa<ParallelCombiningOpInterface>(getOperation()->getParentOp()))
return this->emitError("expected ParallelCombiningOpInterface parent, got:")
if (!isa<InParallelOpInterface>(getOperation()->getParentOp()))
return this->emitError("expected InParallelOpInterface parent, got:")
<< *(getOperation()->getParentOp());

// Verify result type against inferred type.
@@ -3935,6 +3934,20 @@ llvm::SmallBitVector ParallelInsertSliceOp::getDroppedDims() {
return ::getDroppedDims(getSourceType().getShape(), getMixedSizes());
}

// ParallelCombiningOpInterface implementation
MutableOperandRange ParallelInsertSliceOp::getUpdatedDestinations() {
return getDestMutable();
}

Operation *ParallelInsertSliceOp::getIteratingParent() {
// Return the parent InParallelOpInterface's parent
if (auto combiningOp =
dyn_cast<InParallelOpInterface>(getOperation()->getParentOp())) {
return combiningOp->getParentOp();
}
return nullptr;
}

//===----------------------------------------------------------------------===//
// ScatterOp
//===----------------------------------------------------------------------===//
mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -970,10 +970,10 @@ struct ParallelInsertSliceOpInterface
BufferizationState &state) const {
OpBuilder::InsertionGuard g(rewriter);
auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
ParallelCombiningOpInterface parallelCombiningParent =
InParallelOpInterface parallelCombiningParent =
parallelInsertSliceOp.getParallelCombiningParent();

// Bufferize the op outside of the parallel combining terminator.
// Bufferize the op outside of the in parallel terminator.
rewriter.setInsertionPoint(parallelCombiningParent);

// Get source and destination buffers.