Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -654,7 +654,7 @@ def ForallOp : SCF_Op<"forall", [
def InParallelOp : SCF_Op<"forall.in_parallel", [
Pure,
Terminator,
DeclareOpInterfaceMethods<ParallelCombiningOpInterface>,
DeclareOpInterfaceMethods<InParallelOpInterface>,
HasParent<"ForallOp">,
] # GraphRegionNoTerminator.traits> {
let summary = "terminates a `forall` block";
Expand All @@ -679,8 +679,6 @@ def InParallelOp : SCF_Op<"forall.in_parallel", [
OpBuilder<(ins)>,
];

// TODO: Add a `InParallelOpInterface` interface for ops that can
// appear inside in_parallel.
let extraClassDeclaration = [{
::llvm::SmallVector<::mlir::BlockArgument> getDests();
::llvm::iterator_range<::mlir::Block::iterator> getYieldingOps();
Expand Down
16 changes: 9 additions & 7 deletions mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1470,24 +1470,26 @@ def Tensor_PadOp : Tensor_Op<"pad", [
// ParallelInsertSliceOp
//===----------------------------------------------------------------------===//

// TODO: Implement InParallelOpInterface.
// TODO: Implement ParallelCombiningOpInterface.
def Tensor_ParallelInsertSliceOp : Tensor_Op<"parallel_insert_slice", [
AttrSizedOperandSegments,
OffsetSizeAndStrideOpInterface,
DeclareOpInterfaceMethods<ParallelCombiningOpInterface,
["getUpdatedDestinations", "getIteratingParent"]>,
// TODO: Cannot use an interface here atm, verify this manually for now.
// HasParent<"ParallelCombiningOpInterface">
// HasParent<"InParallelOpInterface">
]> {
let summary = [{
Specify the tensor slice update of a single thread of a parent
ParallelCombiningOpInterface op.
InParallelOpInterface op.
}];
let description = [{
The `parallel_insert_slice` yields a subset tensor value to its parent
ParallelCombiningOpInterface. These subset tensor values are aggregated to
InParallelOpInterface. These subset tensor values are aggregated in
some unspecified order into a full tensor value returned by the parent
parallel iterating op.
The `parallel_insert_slice` is one such op allowed in the
ParallelCombiningOpInterface op.
InParallelOpInterface op.

Conflicting writes result in undefined semantics, in that the indices written
to by multiple parallel updates might contain data from any of the updates,
Expand Down Expand Up @@ -1569,8 +1571,8 @@ def Tensor_ParallelInsertSliceOp : Tensor_Op<"parallel_insert_slice", [
return ::llvm::cast<RankedTensorType>(getDest().getType());
}

ParallelCombiningOpInterface getParallelCombiningParent() {
return dyn_cast<ParallelCombiningOpInterface>(
InParallelOpInterface getParallelCombiningParent() {
return dyn_cast<InParallelOpInterface>(
getOperation()->getParentOp());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
namespace mlir {
namespace detail {
// TODO: Single region single block interface on interfaces ?
LogicalResult verifyParallelCombiningOpInterface(Operation *op);
LogicalResult verifyInParallelOpInterface(Operation *op);
} // namespace detail
} // namespace mlir

Expand Down
61 changes: 57 additions & 4 deletions mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.td
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
//
//===----------------------------------------------------------------------===//
//
// Defines the interface for ops that perform parallel combining operations.
// Defines interfaces for ops that perform in-parallel combining
// operations.
//
//===----------------------------------------------------------------------===//

Expand All @@ -15,9 +16,9 @@

include "mlir/IR/OpBase.td"

def ParallelCombiningOpInterface : OpInterface<"ParallelCombiningOpInterface"> {
def InParallelOpInterface : OpInterface<"InParallelOpInterface"> {
let description = [{
A parallel combining op is an op with a region.
An in parallel op is an op with a region.

This is useful as a terminator to parallel operations that iterate over
some set and return tensors while avoiding tight coupling between the
Expand Down Expand Up @@ -52,8 +53,60 @@ def ParallelCombiningOpInterface : OpInterface<"ParallelCombiningOpInterface"> {
];
// TODO: Single region single block interface on interfaces ?
let verify = [{
return verifyParallelCombiningOpInterface($_op);
return verifyInParallelOpInterface($_op);
}];
}

def ParallelCombiningOpInterface : OpInterface<"ParallelCombiningOpInterface"> {
let description = [{
    A parallel combining op is an operation that performs parallel updates to
destination tensors within the context of a parent iterating operation.

This interface is designed for operations that need to coordinate parallel
insertions or updates to tensors that are being constructed or modified
across multiple parallel iterations. The "updated destination" refers to a
destination tensor that accumulates results from parallel computations,
where each parallel iteration may contribute a slice, element, or region
to the final result.

One significant use case for this interface is `tensor.parallel_insert_slice`
which allows parallel insertion of slices into a destination tensor. But with
this interface, other operations that perform similar parallel updates can
also be defined.

This op works within an op implementing the
`InParallelOpInterface` that specifies how the parallel results are combined.

Key semantics:
- The operation identifies destination tensors that will be updated
through the `getUpdatedDestinations` method
- Each parallel iteration may update elements or regions of the
destination tensor
- The parent iterating operation manages the coordination and ensures
proper synchronization of these updates

    Note: This interface does not verify itself; it is up to the implementing
    operation to verify its own correctness.
}];
let cppNamespace = "::mlir";

let methods = [
InterfaceMethod<[{
Returns the list of values updated by this op.
}],
/*retTy=*/"::mlir::MutableOperandRange",
/*methodName=*/"getUpdatedDestinations",
/*args=*/(ins)
>,
InterfaceMethod<
/*desc=*/[{
Returns the iterating parent for this op.
}],
/*retTy=*/"::mlir::Operation*",
/*methodName=*/"getIteratingParent",
/*args=*/(ins)
>,
];
}

#endif // MLIR_INTERFACES_PARALLELCOMBININGOPINTERFACE
11 changes: 6 additions & 5 deletions mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/ParallelCombiningOpInterface.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
Expand Down Expand Up @@ -4140,11 +4141,11 @@ DiagnosedSilenceableFailure doit(RewriterBase &rewriter, OpTy target,
return DiagnosedSilenceableFailure::success();
}

// If we are inside an InParallel region, temporarily set the insertion point
// outside: only tensor.parallel_insert_slice ops are allowed in there.
if constexpr (std::is_same_v<OpTy, tensor::ParallelInsertSliceOp>) {
rewriter.setInsertionPoint(
target->template getParentOfType<scf::InParallelOp>());
  // If we are inside a `ParallelCombiningOp` region, temporarily set the
  // insertion point outside: only ops implementing
  // ParallelCombiningOpInterface are allowed in there.
if (isa<mlir::ParallelCombiningOpInterface>(target.getOperation())) {
rewriter.setInsertionPoint(target->getParentOp());
}

Value extracted = tensor::ExtractSliceOp::create(
Expand Down
55 changes: 32 additions & 23 deletions mlir/lib/Dialect/SCF/IR/SCF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Interfaces/ParallelCombiningOpInterface.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/MapVector.h"
Expand Down Expand Up @@ -680,8 +681,11 @@ void mlir::scf::promote(RewriterBase &rewriter, scf::ForallOp forallOp) {
SmallVector<Value> results;
results.reserve(forallOp.getResults().size());
for (auto &yieldingOp : terminator.getYieldingOps()) {
// Skip non-ParallelInsertSliceOp operations
auto parallelInsertSliceOp =
cast<tensor::ParallelInsertSliceOp>(yieldingOp);
dyn_cast<tensor::ParallelInsertSliceOp>(yieldingOp);
if (!parallelInsertSliceOp)
continue;

Value dst = parallelInsertSliceOp.getDest();
Value src = parallelInsertSliceOp.getSource();
Expand Down Expand Up @@ -1439,12 +1443,9 @@ InParallelOp ForallOp::getTerminator() {

SmallVector<Operation *> ForallOp::getCombiningOps(BlockArgument bbArg) {
SmallVector<Operation *> storeOps;
InParallelOp inParallelOp = getTerminator();
for (Operation &yieldOp : inParallelOp.getYieldingOps()) {
if (auto parallelInsertSliceOp =
dyn_cast<tensor::ParallelInsertSliceOp>(yieldOp);
parallelInsertSliceOp && parallelInsertSliceOp.getDest() == bbArg) {
storeOps.push_back(parallelInsertSliceOp);
for (Operation *user : bbArg.getUsers()) {
if (auto parallelOp = dyn_cast<ParallelCombiningOpInterface>(user)) {
storeOps.push_back(parallelOp);
}
}
return storeOps;
Expand Down Expand Up @@ -1673,7 +1674,9 @@ struct ForallOpIterArgsFolder : public OpRewritePattern<ForallOp> {
for (OpResult result : forallOp.getResults()) {
OpOperand *opOperand = forallOp.getTiedOpOperand(result);
BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
if (result.use_empty() || forallOp.getCombiningOps(blockArg).empty()) {
if ((result.use_empty() &&
llvm::all_of(forallOp.getCombiningOps(blockArg),
[](Operation *op) { return op->use_empty(); }))) {
resultToDelete.insert(result);
} else {
resultToReplace.push_back(result);
Expand Down Expand Up @@ -1911,8 +1914,9 @@ struct FoldTensorCastOfOutputIntoForallOp
auto terminator = newForallOp.getTerminator();
for (auto [yieldingOp, outputBlockArg] : llvm::zip(
terminator.getYieldingOps(), newForallOp.getRegionIterArgs())) {
auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(yieldingOp);
insertSliceOp.getDestMutable().assign(outputBlockArg);
auto inParallelOp = dyn_cast<ParallelCombiningOpInterface>(yieldingOp);
if (inParallelOp)
inParallelOp.getUpdatedDestinations().assign(outputBlockArg);
}

// Cast results back to the original types.
Expand Down Expand Up @@ -1971,19 +1975,21 @@ LogicalResult InParallelOp::verify() {
if (!forallOp)
return this->emitOpError("expected forall op parent");

// TODO: InParallelOpInterface.
for (Operation &op : getRegion().front().getOperations()) {
if (!isa<tensor::ParallelInsertSliceOp>(op)) {
return this->emitOpError("expected only ")
<< tensor::ParallelInsertSliceOp::getOperationName() << " ops";
auto inParallelOp = dyn_cast<ParallelCombiningOpInterface>(&op);
if (!inParallelOp) {
return this->emitOpError("expected only ParallelCombiningOpInterface") << " ops";
}

// Verify that inserts are into out block arguments.
Value dest = cast<tensor::ParallelInsertSliceOp>(op).getDest();
MutableOperandRange dests = inParallelOp.getUpdatedDestinations();
ArrayRef<BlockArgument> regionOutArgs = forallOp.getRegionOutArgs();
if (!llvm::is_contained(regionOutArgs, dest))
return op.emitOpError("may only insert into an output block argument");
for (OpOperand &dest : dests) {
if (!llvm::is_contained(regionOutArgs, dest.get()))
return op.emitOpError("may only insert into an output block argument");
}
}

return success();
}

Expand Down Expand Up @@ -2018,12 +2024,15 @@ OpResult InParallelOp::getParentResult(int64_t idx) {
}

SmallVector<BlockArgument> InParallelOp::getDests() {
return llvm::to_vector<4>(
llvm::map_range(getYieldingOps(), [](Operation &op) {
// Add new ops here as needed.
auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(&op);
return llvm::cast<BlockArgument>(insertSliceOp.getDest());
}));
SmallVector<BlockArgument> updatedDests;
for (auto &yieldingOp : getYieldingOps()) {
auto inParallelOp = dyn_cast<ParallelCombiningOpInterface>(&yieldingOp);
if (!inParallelOp)
continue;
for (auto &updatedOperand : inParallelOp.getUpdatedDestinations())
updatedDests.push_back(cast<BlockArgument>(updatedOperand.get()));
}
return updatedDests;
}

llvm::iterator_range<Block::iterator> InParallelOp::getYieldingOps() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ using namespace mlir::bufferization;
namespace {
/// The `scf.forall.in_parallel` terminator is special in a few ways:
/// * It does not implement the BranchOpInterface or
/// RegionBranchTerminatorOpInterface, but the ParallelCombiningOpInterface
/// RegionBranchTerminatorOpInterface, but the InParallelOpInterface
/// which is not supported by BufferDeallocation.
/// * It has a graph-like region which only allows one specific tensor op
/// * After bufferization the nested region is always empty
Expand All @@ -40,9 +40,9 @@ namespace {
/// <implicit in_parallel terminator here>
/// }
/// ```
struct InParallelOpInterface
: public BufferDeallocationOpInterface::ExternalModel<InParallelOpInterface,
scf::InParallelOp> {
struct InParallelDeallocOpInterface
: public BufferDeallocationOpInterface::ExternalModel<
InParallelDeallocOpInterface, scf::InParallelOp> {
FailureOr<Operation *> process(Operation *op, DeallocationState &state,
const DeallocationOptions &options) const {
auto inParallelOp = cast<scf::InParallelOp>(op);
Expand Down Expand Up @@ -75,7 +75,7 @@ struct ReduceReturnOpInterface
void mlir::scf::registerBufferDeallocationOpInterfaceExternalModels(
DialectRegistry &registry) {
registry.addExtension(+[](MLIRContext *ctx, SCFDialect *dialect) {
InParallelOp::attachInterface<InParallelOpInterface>(*ctx);
InParallelOp::attachInterface<InParallelDeallocOpInterface>(*ctx);
ReduceReturnOp::attachInterface<ReduceReturnOpInterface>(*ctx);
});
}
28 changes: 21 additions & 7 deletions mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2976,9 +2976,9 @@ class InsertSliceOpConstantArgumentFolder final
if (sourceType != insertSliceOp.getSourceType()) {
OpBuilder::InsertionGuard g(rewriter);
// The only difference between InsertSliceOp and ParallelInsertSliceOp
// is that the insertion point is just before the ParallelCombiningOp in
// is that the insertion point is just before the InParallelOp in
// the parallel case.
if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
if (isa<InParallelOpInterface>(insertSliceOp->getParentOp()))
rewriter.setInsertionPoint(insertSliceOp->getParentOp());
toInsert = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(),
sourceType, toInsert);
Expand Down Expand Up @@ -3153,9 +3153,9 @@ struct InsertSliceOpSourceCastInserter final
// Insert the cast.
OpBuilder::InsertionGuard g(rewriter);
// The only difference between InsertSliceOp and ParallelInsertSliceOp is
// that the insertion point is just before the ParallelCombiningOp in the
// that the insertion point is just before the InParallelOp in the
// parallel case.
if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
if (isa<ParallelCombiningOpInterface>(insertSliceOp->getParentOp()))
rewriter.setInsertionPoint(insertSliceOp->getParentOp());
Value cast = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(),
newSrcType, insertSliceOp.getSource());
Expand Down Expand Up @@ -3846,7 +3846,7 @@ OpFoldResult PadOp::fold(FoldAdaptor) {
//===----------------------------------------------------------------------===//

OpResult ParallelInsertSliceOp::getTiedOpResult() {
ParallelCombiningOpInterface parallelCombiningParent =
InParallelOpInterface parallelCombiningParent =
getParallelCombiningParent();
for (const auto &it :
llvm::enumerate(parallelCombiningParent.getYieldingOps())) {
Expand Down Expand Up @@ -3901,8 +3901,8 @@ void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
}

LogicalResult ParallelInsertSliceOp::verify() {
if (!isa<ParallelCombiningOpInterface>(getOperation()->getParentOp()))
return this->emitError("expected ParallelCombiningOpInterface parent, got:")
if (!isa<InParallelOpInterface>(getOperation()->getParentOp()))
return this->emitError("expected InParallelOpInterface parent, got:")
<< *(getOperation()->getParentOp());

// Verify result type against inferred type.
Expand Down Expand Up @@ -3935,6 +3935,20 @@ llvm::SmallBitVector ParallelInsertSliceOp::getDroppedDims() {
return ::getDroppedDims(getSourceType().getShape(), getMixedSizes());
}

// ParallelCombiningOpInterface implementation
MutableOperandRange ParallelInsertSliceOp::getUpdatedDestinations() {
return getDestMutable();
}

Operation *ParallelInsertSliceOp::getIteratingParent() {
// Return the parent InParallelOpInterface's parent
if (auto combiningOp = dyn_cast<InParallelOpInterface>(
getOperation()->getParentOp())) {
return combiningOp->getParentOp();
}
return nullptr;
}

//===----------------------------------------------------------------------===//
// ScatterOp
//===----------------------------------------------------------------------===//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -970,10 +970,10 @@ struct ParallelInsertSliceOpInterface
BufferizationState &state) const {
OpBuilder::InsertionGuard g(rewriter);
auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
ParallelCombiningOpInterface parallelCombiningParent =
InParallelOpInterface parallelCombiningParent =
parallelInsertSliceOp.getParallelCombiningParent();

// Bufferize the op outside of the parallel combining terminator.
// Bufferize the op outside of the in parallel terminator.
rewriter.setInsertionPoint(parallelCombiningParent);

// Get source and destination buffers.
Expand Down
Loading
Loading