diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 88df54174da24..d3c01c31636a7 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -654,7 +654,7 @@ def ForallOp : SCF_Op<"forall", [
 def InParallelOp : SCF_Op<"forall.in_parallel", [
       Pure,
       Terminator,
-      DeclareOpInterfaceMethods<ParallelCombiningOpInterface>,
+      DeclareOpInterfaceMethods<InParallelOpInterface>,
       HasParent<"ForallOp">,
      ] # GraphRegionNoTerminator.traits> {
   let summary = "terminates a `forall` block";
@@ -679,8 +679,6 @@ def InParallelOp : SCF_Op<"forall.in_parallel", [
     OpBuilder<(ins)>,
   ];

-  // TODO: Add a `InParallelOpInterface` interface for ops that can
-  // appear inside in_parallel.
   let extraClassDeclaration = [{
     ::llvm::SmallVector<::mlir::BlockArgument> getDests();
     ::llvm::iterator_range<::mlir::Block::iterator> getYieldingOps();
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 7d396e5c64c28..2453cf5b5b5a4 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1470,24 +1470,25 @@ def Tensor_PadOp : Tensor_Op<"pad", [
 // ParallelInsertSliceOp
 //===----------------------------------------------------------------------===//

-// TODO: Implement InParallelOpInterface.
 def Tensor_ParallelInsertSliceOp : Tensor_Op<"parallel_insert_slice", [
     AttrSizedOperandSegments,
     OffsetSizeAndStrideOpInterface,
+    DeclareOpInterfaceMethods<ParallelCombiningOpInterface>,
     // TODO: Cannot use an interface here atm, verify this manually for now.
-    // HasParent<"ParallelCombiningOpInterface">
+    // HasParent<"InParallelOpInterface">
   ]> {
   let summary = [{
     Specify the tensor slice update of a single thread of a parent
-    ParallelCombiningOpInterface op.
+    InParallelOpInterface op.
   }];
   let description = [{
     The `parallel_insert_slice` yields a subset tensor value to its parent
-    ParallelCombiningOpInterface. These subset tensor values are aggregated to
+    InParallelOpInterface. These subset tensor values are aggregated
     in some unspecified order into a full tensor value returned by the parent
     parallel iterating op.
     The `parallel_insert_slice` is one such op allowed in the
-    ParallelCombiningOpInterface op.
+    InParallelOpInterface op.
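+
+    Example (an illustrative sketch; `%t`, `%o`, and `%idx` are placeholder
+    names):
+
+    ```mlir
+    // Each thread updates a disjoint 1-element slice of the shared output
+    // %o; the parent parallel iterating op returns the assembled tensor.
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %t into %o[%idx] [1] [1]
+        : tensor<1xf32> into tensor<100xf32>
+    }
+    ```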

     Conflicting writes result in undefined semantics, in that the indices written
     to by multiple parallel updates might contain data from any of the updates,
@@ -1569,8 +1570,8 @@ def Tensor_ParallelInsertSliceOp : Tensor_Op<"parallel_insert_slice", [
       return ::llvm::cast<RankedTensorType>(getDest().getType());
     }

-    ParallelCombiningOpInterface getParallelCombiningParent() {
-      return dyn_cast<ParallelCombiningOpInterface>(
+    InParallelOpInterface getParallelCombiningParent() {
+      return dyn_cast<InParallelOpInterface>(
           getOperation()->getParentOp());
     }
diff --git a/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.h b/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.h
index 72db06163df37..82ab427699f64 100644
--- a/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.h
+++ b/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.h
@@ -19,7 +19,7 @@ namespace mlir {
 namespace detail {
 // TODO: Single region single block interface on interfaces ?
-LogicalResult verifyParallelCombiningOpInterface(Operation *op);
+LogicalResult verifyInParallelOpInterface(Operation *op);
 } // namespace detail
 } // namespace mlir
diff --git a/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.td b/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.td
index 424b4cf0a0a58..ace26f723ef53 100644
--- a/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.td
+++ b/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.td
@@ -6,7 +6,8 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// Defines the interface for ops that perform parallel combining operations.
+// Defines the interface for ops that perform in-parallel combining
+// operations.
 //
 //===----------------------------------------------------------------------===//

@@ -15,9 +16,9 @@

 include "mlir/IR/OpBase.td"

-def ParallelCombiningOpInterface : OpInterface<"ParallelCombiningOpInterface"> {
+def InParallelOpInterface : OpInterface<"InParallelOpInterface"> {
   let description = [{
-    A parallel combining op is an op with a region.
+    An in-parallel op is an op with a region.

     This is useful as a terminator to parallel operations that iterate over
     some set and return tensors while avoiding tight coupling between the
@@ -52,8 +53,60 @@ def ParallelCombiningOpInterface : OpInterface<"ParallelCombiningOpInterface"> {
   ];
   // TODO: Single region single block interface on interfaces ?
   let verify = [{
-    return verifyParallelCombiningOpInterface($_op);
+    return verifyInParallelOpInterface($_op);
+  }];
+}
+
+def ParallelCombiningOpInterface : OpInterface<"ParallelCombiningOpInterface"> {
+  let description = [{
+    A parallel combining op is an operation that models parallel contributions
+    to result tensors within the context of a parent iterating operation.
+
+    This interface is designed for operations that need to coordinate parallel
+    insertions or contributions to tensors that are being constructed across
+    multiple parallel iterations. The destination refers to a tensor value that
+    is assembled by aggregating results from parallel computations; each
+    parallel iteration may contribute a slice, element, or region to the final
+    result. No in-place mutation of tensors is implied.
+
+    One significant use case for this interface is `tensor.parallel_insert_slice`,
+    which allows parallel insertion of slices that are aggregated into a
+    destination tensor. With this interface, other operations that express
+    similar parallel contributions can also be defined.
+
+    Ops implementing this interface work within an op implementing the
+    `InParallelOpInterface`, which specifies how the parallel results are
+    combined.
+
+    Key semantics:
+    - The operation identifies the destination tensors to which iterations
+      contribute through the `getUpdatedDestinations` method.
+    - Each parallel iteration may produce elements or regions that are
+      incorporated into the destination tensor.
+    - The parent iterating operation manages the coordination and ensures
+      proper synchronization of these contributions.
+
+    Note: This interface does not verify itself; it is up to the implementing
+    operation to verify its own correctness.
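+
+    Example (a sketch using `scf.forall` with `tensor.parallel_insert_slice`,
+    the in-tree implementation of this interface; SSA names are illustrative):
+
+    ```mlir
+    %r = scf.forall (%i) in (%c320) shared_outs(%o = %t) -> (tensor<320xf32>) {
+      %slice = tensor.extract_slice %o[%i] [1] [1]
+        : tensor<320xf32> to tensor<1xf32>
+      // The terminator implements `InParallelOpInterface`; each op inside it
+      // implements `ParallelCombiningOpInterface`: `getUpdatedDestinations`
+      // returns %o and `getIteratingParent` returns the `scf.forall` op.
+      scf.forall.in_parallel {
+        tensor.parallel_insert_slice %slice into %o[%i] [1] [1]
+          : tensor<1xf32> into tensor<320xf32>
+      }
+    }
+    ```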
   }];
+  let cppNamespace = "::mlir";
+
+  let methods = [
+    InterfaceMethod<[{
+        Returns the list of destination values this op contributes to.
+      }],
+      /*retTy=*/"::mlir::MutableOperandRange",
+      /*methodName=*/"getUpdatedDestinations",
+      /*args=*/(ins)
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Returns the iterating parent for this op.
+      }],
+      /*retTy=*/"::mlir::Operation*",
+      /*methodName=*/"getIteratingParent",
+      /*args=*/(ins)
+    >,
+  ];
 }

 #endif // MLIR_INTERFACES_PARALLELCOMBININGOPINTERFACE
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index f0c1f4485b054..7f8d45c237765 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -36,6 +36,7 @@
 #include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
+#include "mlir/Interfaces/ParallelCombiningOpInterface.h"
 #include "mlir/Interfaces/TilingInterface.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -4140,12 +4141,11 @@ DiagnosedSilenceableFailure doit(RewriterBase &rewriter, OpTy target,
     return DiagnosedSilenceableFailure::success();
   }

-  // If we are inside an InParallel region, temporarily set the insertion point
-  // outside: only tensor.parallel_insert_slice ops are allowed in there.
-  if constexpr (std::is_same_v<OpTy, tensor::ParallelInsertSliceOp>) {
-    rewriter.setInsertionPoint(
-        target->template getParentOfType<scf::InParallelOp>());
-  }
+  // If we are inside a `ParallelCombiningOp` region, temporarily set the
+  // insertion point outside: only ops implementing ParallelCombiningOpInterface
+  // are allowed in there.
+  if (isa<ParallelCombiningOpInterface>(target.getOperation()))
+    rewriter.setInsertionPoint(target->getParentOp());

   Value extracted = tensor::ExtractSliceOp::create(
       rewriter, target.getLoc(), target.getDest(), target.getMixedOffsets(),
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 84f9777a443fd..45b14fcf8aadd 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -21,6 +21,7 @@
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/FunctionInterfaces.h"
+#include "mlir/Interfaces/ParallelCombiningOpInterface.h"
 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
 #include "mlir/Transforms/InliningUtils.h"
 #include "llvm/ADT/MapVector.h"
@@ -681,7 +682,9 @@ void mlir::scf::promote(RewriterBase &rewriter, scf::ForallOp forallOp) {
   results.reserve(forallOp.getResults().size());
   for (auto &yieldingOp : terminator.getYieldingOps()) {
     auto parallelInsertSliceOp =
-        cast<tensor::ParallelInsertSliceOp>(yieldingOp);
+        dyn_cast<tensor::ParallelInsertSliceOp>(yieldingOp);
+    if (!parallelInsertSliceOp)
+      continue;
     Value dst = parallelInsertSliceOp.getDest();
     Value src = parallelInsertSliceOp.getSource();
@@ -1439,12 +1442,9 @@ InParallelOp ForallOp::getTerminator() {

 SmallVector<Operation *> ForallOp::getCombiningOps(BlockArgument bbArg) {
   SmallVector<Operation *> storeOps;
-  InParallelOp inParallelOp = getTerminator();
-  for (Operation &yieldOp : inParallelOp.getYieldingOps()) {
-    if (auto parallelInsertSliceOp =
-            dyn_cast<tensor::ParallelInsertSliceOp>(yieldOp);
-        parallelInsertSliceOp && parallelInsertSliceOp.getDest() == bbArg) {
-      storeOps.push_back(parallelInsertSliceOp);
+  for (Operation *user : bbArg.getUsers()) {
+    if (auto parallelOp = dyn_cast<ParallelCombiningOpInterface>(user)) {
+      storeOps.push_back(parallelOp);
     }
   }
   return storeOps;
@@ -1911,8 +1911,10 @@ struct FoldTensorCastOfOutputIntoForallOp
     auto terminator = newForallOp.getTerminator();
     for (auto [yieldingOp, outputBlockArg] : llvm::zip(
              terminator.getYieldingOps(), newForallOp.getRegionIterArgs())) {
-      auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(yieldingOp);
-      insertSliceOp.getDestMutable().assign(outputBlockArg);
+      if (auto parallelCombiningOp =
+              dyn_cast<ParallelCombiningOpInterface>(yieldingOp)) {
+        parallelCombiningOp.getUpdatedDestinations().assign(outputBlockArg);
+      }
     }

     // Cast results back to the original types.
@@ -1971,19 +1973,22 @@ LogicalResult InParallelOp::verify() {
   if (!forallOp)
     return this->emitOpError("expected forall op parent");

-  // TODO: InParallelOpInterface.
   for (Operation &op : getRegion().front().getOperations()) {
-    if (!isa<tensor::ParallelInsertSliceOp>(op)) {
-      return this->emitOpError("expected only ")
-             << tensor::ParallelInsertSliceOp::getOperationName() << " ops";
+    auto parallelCombiningOp = dyn_cast<ParallelCombiningOpInterface>(&op);
+    if (!parallelCombiningOp) {
+      return this->emitOpError(
+          "expected only ParallelCombiningOpInterface ops");
     }

     // Verify that inserts are into out block arguments.
-    Value dest = cast<tensor::ParallelInsertSliceOp>(op).getDest();
+    MutableOperandRange dests = parallelCombiningOp.getUpdatedDestinations();
     ArrayRef<BlockArgument> regionOutArgs = forallOp.getRegionOutArgs();
-    if (!llvm::is_contained(regionOutArgs, dest))
-      return op.emitOpError("may only insert into an output block argument");
+    for (OpOperand &dest : dests) {
+      if (!llvm::is_contained(regionOutArgs, dest.get()))
+        return op.emitOpError("may only insert into an output block argument");
+    }
   }
+
   return success();
 }
@@ -2018,12 +2023,17 @@ OpResult InParallelOp::getParentResult(int64_t idx) {
 }

 SmallVector<BlockArgument> InParallelOp::getDests() {
-  return llvm::to_vector<4>(
-      llvm::map_range(getYieldingOps(), [](Operation &op) {
-        // Add new ops here as needed.
-        auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(&op);
-        return llvm::cast<BlockArgument>(insertSliceOp.getDest());
-      }));
+  SmallVector<BlockArgument> updatedDests;
+  for (Operation &yieldingOp : getYieldingOps()) {
+    auto parallelCombiningOp =
+        dyn_cast<ParallelCombiningOpInterface>(&yieldingOp);
+    if (!parallelCombiningOp)
+      continue;
+    for (OpOperand &updatedOperand :
+         parallelCombiningOp.getUpdatedDestinations())
+      updatedDests.push_back(cast<BlockArgument>(updatedOperand.get()));
+  }
+  return updatedDests;
 }

 llvm::iterator_range<Block::iterator> InParallelOp::getYieldingOps() {
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferDeallocationOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferDeallocationOpInterfaceImpl.cpp
index a44612410bdee..63216e7cc7fba 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferDeallocationOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferDeallocationOpInterfaceImpl.cpp
@@ -16,7 +16,7 @@ using namespace mlir::bufferization;
 namespace {
 /// The `scf.forall.in_parallel` terminator is special in a few ways:
 /// * It does not implement the BranchOpInterface or
-///   RegionBranchTerminatorOpInterface, but the ParallelCombiningOpInterface
+///   RegionBranchTerminatorOpInterface, but the InParallelOpInterface
 ///   which is not supported by BufferDeallocation.
 /// * It has a graph-like region which only allows one specific tensor op
 /// * After bufferization the nested region is always empty
@@ -40,9 +40,9 @@ namespace {
 ///
 ///   }
 /// ```
-struct InParallelOpInterface
-    : public BufferDeallocationOpInterface::ExternalModel<InParallelOpInterface,
-                                                          scf::InParallelOp> {
+struct InParallelDeallocOpInterface
+    : public BufferDeallocationOpInterface::ExternalModel<
+          InParallelDeallocOpInterface, scf::InParallelOp> {
   FailureOr<Operation *> process(Operation *op, DeallocationState &state,
                                  const DeallocationOptions &options) const {
     auto inParallelOp = cast<scf::InParallelOp>(op);
@@ -75,7 +75,7 @@ struct ReduceReturnOpInterface
 void mlir::scf::registerBufferDeallocationOpInterfaceExternalModels(
     DialectRegistry &registry) {
   registry.addExtension(+[](MLIRContext *ctx, SCFDialect *dialect) {
-    InParallelOp::attachInterface<InParallelOpInterface>(*ctx);
+    InParallelOp::attachInterface<InParallelDeallocOpInterface>(*ctx);
     ReduceReturnOp::attachInterface<ReduceReturnOpInterface>(*ctx);
   });
 }
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 68584ec4fd814..fa97b49a41d97 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2976,9 +2976,9 @@ class InsertSliceOpConstantArgumentFolder final
     if (sourceType != insertSliceOp.getSourceType()) {
       OpBuilder::InsertionGuard g(rewriter);
       // The only difference between InsertSliceOp and ParallelInsertSliceOp
-      // is that the insertion point is just before the ParallelCombiningOp in
+      // is that the insertion point is just before the InParallelOp in
       // the parallel case.
-      if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
+      if (isa<InParallelOpInterface>(insertSliceOp->getParentOp()))
         rewriter.setInsertionPoint(insertSliceOp->getParentOp());
       toInsert = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(),
                                         sourceType, toInsert);
@@ -3153,9 +3153,9 @@ struct InsertSliceOpSourceCastInserter final
     // Insert the cast.
     OpBuilder::InsertionGuard g(rewriter);
     // The only difference between InsertSliceOp and ParallelInsertSliceOp is
-    // that the insertion point is just before the ParallelCombiningOp in the
+    // that the insertion point is just before the InParallelOp in the
     // parallel case.
-    if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
+    if (isa<InParallelOpInterface>(insertSliceOp->getParentOp()))
      rewriter.setInsertionPoint(insertSliceOp->getParentOp());
     Value cast = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(),
                                         newSrcType, insertSliceOp.getSource());
@@ -3846,8 +3846,7 @@ OpFoldResult PadOp::fold(FoldAdaptor) {
 //===----------------------------------------------------------------------===//

 OpResult ParallelInsertSliceOp::getTiedOpResult() {
-  ParallelCombiningOpInterface parallelCombiningParent =
-      getParallelCombiningParent();
+  InParallelOpInterface parallelCombiningParent = getParallelCombiningParent();
   for (const auto &it :
        llvm::enumerate(parallelCombiningParent.getYieldingOps())) {
     Operation &nextOp = it.value();
@@ -3901,8 +3900,8 @@ void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
 }

 LogicalResult ParallelInsertSliceOp::verify() {
-  if (!isa<ParallelCombiningOpInterface>(getOperation()->getParentOp()))
-    return this->emitError("expected ParallelCombiningOpInterface parent, got:")
+  if (!isa<InParallelOpInterface>(getOperation()->getParentOp()))
+    return this->emitError("expected InParallelOpInterface parent, got:")
            << *(getOperation()->getParentOp());

   // Verify result type against inferred type.
@@ -3935,6 +3934,19 @@ llvm::SmallBitVector ParallelInsertSliceOp::getDroppedDims() {
   return ::getDroppedDims(getSourceType().getShape(), getMixedSizes());
 }

+// ParallelCombiningOpInterface implementation.
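+// Note: for `tensor.parallel_insert_slice` the combining parent is the
+// `scf.forall.in_parallel` terminator and the iterating parent is the
+// enclosing `scf.forall`; `getUpdatedDestinations` returns the `dest`
+// operand that each thread updates.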
+MutableOperandRange ParallelInsertSliceOp::getUpdatedDestinations() {
+  return getDestMutable();
+}
+
+Operation *ParallelInsertSliceOp::getIteratingParent() {
+  // Return the parent InParallelOpInterface's parent.
+  if (auto combiningOp =
+          dyn_cast<InParallelOpInterface>(getOperation()->getParentOp()))
+    return combiningOp->getParentOp();
+  return nullptr;
+}
+
 //===----------------------------------------------------------------------===//
 // ScatterOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index c3356c1e4b9d8..bce964e47a3be 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -970,10 +970,10 @@ struct ParallelInsertSliceOpInterface
                           BufferizationState &state) const {
     OpBuilder::InsertionGuard g(rewriter);
     auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
-    ParallelCombiningOpInterface parallelCombiningParent =
+    InParallelOpInterface parallelCombiningParent =
         parallelInsertSliceOp.getParallelCombiningParent();

-    // Bufferize the op outside of the parallel combining terminator.
+    // Bufferize the op outside of the in-parallel terminator.
     rewriter.setInsertionPoint(parallelCombiningParent);

     // Get source and destination buffers.
diff --git a/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp b/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp
index d76c02af7ab16..b32faf481af80 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp
@@ -215,12 +215,11 @@ struct InsertSliceOfInsertSliceFolder : public OpRewritePattern<OpTy> {
         sourceInsertSliceOp.getMixedSizes(), droppedDims, resolvedSizes);

-    // If we are inside an InParallel region, temporarily set the insertion
-    // point outside: only tensor.parallel_insert_slice ops are allowed in
-    // there.
-    if (std::is_same_v<OpTy, tensor::ParallelInsertSliceOp>) {
-      rewriter.setInsertionPoint(
-          insertSliceOp->template getParentOfType<scf::InParallelOp>());
+    // If we are inside a ParallelCombining region, temporarily set the
+    // insertion point outside: only ops of ParallelCombiningOpInterface are
+    // allowed in there.
+    if (isa<ParallelCombiningOpInterface>(insertSliceOp.getOperation())) {
+      rewriter.setInsertionPoint(insertSliceOp->getParentOp());
     }

     // Resolve offsets according to source offsets and strides.
diff --git a/mlir/lib/Interfaces/ParallelCombiningOpInterface.cpp b/mlir/lib/Interfaces/ParallelCombiningOpInterface.cpp
index 2b6703543bbd3..30b8191bf34b0 100644
--- a/mlir/lib/Interfaces/ParallelCombiningOpInterface.cpp
+++ b/mlir/lib/Interfaces/ParallelCombiningOpInterface.cpp
@@ -11,11 +11,11 @@
 using namespace mlir;

 //===----------------------------------------------------------------------===//
-// ParallelCombiningOpInterface
+// InParallelOpInterface (formerly ParallelCombiningOpInterface)
 //===----------------------------------------------------------------------===//

 // TODO: Single region single block interface on interfaces ?
-LogicalResult mlir::detail::verifyParallelCombiningOpInterface(Operation *op) {
+LogicalResult mlir::detail::verifyInParallelOpInterface(Operation *op) {
   if (op->getNumRegions() != 1)
     return op->emitError("expected single region op");
   if (!op->getRegion(0).hasOneBlock())
diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
index 5f42938244db6..9005110205630 100644
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -915,7 +915,7 @@ func.func @sparse_case(%arg0: tensor<8x8xf32, #CSR>, %arg1: tensor<8xf32>) -> te

 // -----

-func.func @reduce_dispatch_0() -> tensor<4x2xf32> {
+func.func @parallel_insert_slice() -> tensor<4x2xf32> {
   %c2 = arith.constant 2 : index
   %c4 = arith.constant 4 : index
   %cst = arith.constant 0.000000e+00 : f32
@@ -923,6 +923,7 @@ func.func @reduce_dispatch_0() -> tensor<4x2xf32> {
   %res = scf.forall (%arg0, %arg1) in (%c4, %c2) shared_outs(%o = %0) -> (tensor<4x2xf32>) {
     %1 = tensor.empty() : tensor<1x1xf32>
     %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<1x1xf32>) -> tensor<1x1xf32>
+    // CHECK: scf.forall.in_parallel
     scf.forall.in_parallel {
       //      CHECK: tensor.parallel_insert_slice %{{[0-9a-z]*}} into %{{[0-9a-z]*}}
       // CHECK-SAME:   [%{{.*}}, %{{.*}}] [1, 1] [1, 1] : tensor<f32> into tensor<4x2xf32>
diff --git a/mlir/test/Dialect/SCF/invalid.mlir b/mlir/test/Dialect/SCF/invalid.mlir
index bb7958083e55c..37fc86b18e7f0 100644
--- a/mlir/test/Dialect/SCF/invalid.mlir
+++ b/mlir/test/Dialect/SCF/invalid.mlir
@@ -645,7 +645,7 @@ func.func @wrong_terminator_op(%in: tensor<100xf32>, %out: tensor<100xf32>) {
   %result = scf.forall (%thread_idx) in (%num_threads) shared_outs(%o = %out) -> (tensor<100xf32>) {
     %1 = tensor.extract_slice %in[%thread_idx][1][1] : tensor<100xf32> to tensor<1xf32>
-    // expected-error @+1 {{expected only ParallelCombiningOpInterface ops}}
     scf.forall.in_parallel {
       tensor.parallel_insert_slice %1 into %o[%thread_idx][1][1] :
         tensor<1xf32> into tensor<100xf32>
diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize-analysis.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize-analysis.mlir
index 9bb87ffbb2090..ed3685514dd0d 100644
--- a/mlir/test/Dialect/SCF/one-shot-bufferize-analysis.mlir
+++ b/mlir/test/Dialect/SCF/one-shot-bufferize-analysis.mlir
@@ -908,3 +908,111 @@ func.func @parallel_region_no_read()
   }
   return
 }
+
+// -----
+
+// CHECK-LABEL: func @in_order_multiple_parallel_writes
+func.func @in_order_multiple_parallel_writes(%2: tensor<320xf32> {bufferization.writable = true},
+                                             %3: tensor<320xf32> {bufferization.writable = true})
+  -> (tensor<320xf32>, tensor<320xf32>)
+{
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant -0.000000e+00 : f32
+  %c320 = arith.constant 320 : index
+  %4:2 = scf.forall (%arg0) in (%c320) shared_outs(%arg1 = %2, %arg2 = %3) -> (tensor<320xf32>, tensor<320xf32>) {
+    // CHECK: tensor.extract_slice {{.*}} {__inplace_operands_attr__ = ["true", "none"]}
+    %6 = tensor.extract_slice %arg1[%arg0] [1] [1] : tensor<320xf32> to tensor<1xf32>
+    // CHECK: tensor.extract_slice {{.*}} {__inplace_operands_attr__ = ["true", "none"]}
+    %7 = tensor.extract_slice %arg2[%arg0] [1] [1] : tensor<320xf32> to tensor<1xf32>
+    // CHECK: linalg.fill {__inplace_operands_attr__ = ["none", "true"]}
+    %8 = linalg.fill ins(%cst : f32) outs(%7 : tensor<1xf32>) -> tensor<1xf32>
+
+    // CHECK: tensor.parallel_insert_slice {{.*}} {__inplace_operands_attr__ = ["true", "true", "none"]}
+    // CHECK: tensor.parallel_insert_slice {{.*}} {__inplace_operands_attr__ = ["true", "true", "none"]}
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %6 into %arg2[%arg0] [1] [1] : tensor<1xf32> into tensor<320xf32>
+      tensor.parallel_insert_slice %8 into %arg1[%arg0] [1] [1] : tensor<1xf32> into tensor<320xf32>
+    }
+  }
+  return %4#0, %4#1 : tensor<320xf32>, tensor<320xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @out_of_order_parallel_write
+func.func @out_of_order_parallel_write(%2: tensor<320xf32> {bufferization.writable = true},
+                                       %3: tensor<320xf32> {bufferization.writable = true})
+  -> (tensor<320xf32>, tensor<320xf32>)
+{
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant -0.000000e+00 : f32
+  %c320 = arith.constant 320 : index
+  %4:2 = scf.forall (%arg0) in (%c320) shared_outs(%arg1 = %2, %arg2 = %3) -> (tensor<320xf32>, tensor<320xf32>) {
+    // The extract_slice cannot operate in place because it is used after the
+    // first write.
+    // CHECK: tensor.extract_slice {{.*}} {__inplace_operands_attr__ = ["true", "none"]}
+    %6 = tensor.extract_slice %arg1[%arg0] [1] [1] : tensor<320xf32> to tensor<1xf32>
+
+    // Additionally the fill aliases the thread-local slice.
+    // CHECK: linalg.fill {__inplace_operands_attr__ = ["none", "false"]}
+    %7 = linalg.fill ins(%cst : f32) outs(%6 : tensor<1xf32>) -> tensor<1xf32>
+
+    scf.forall.in_parallel {
+      // CHECK: tensor.parallel_insert_slice {{.*}} {__inplace_operands_attr__ = ["true", "true", "none"]}
+      tensor.parallel_insert_slice %7 into %arg1[%arg0] [1] [1] : tensor<1xf32> into tensor<320xf32>
+      // CHECK: tensor.parallel_insert_slice {{.*}} {__inplace_operands_attr__ = ["true", "true", "none"]}
+      tensor.parallel_insert_slice %6 into %arg2[%arg0] [1] [1] : tensor<1xf32> into tensor<320xf32>
+    }
+  }
+  return %4#0, %4#1 : tensor<320xf32>, tensor<320xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @out_of_order_parallel_write_multiple_reads
+func.func @out_of_order_parallel_write_multiple_reads(%2: tensor<320xf32> {bufferization.writable = true},
+                                                      %3: tensor<320xf32> {bufferization.writable = true})
+  -> (tensor<320xf32>, tensor<320xf32>)
+{
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant -0.000000e+00 : f32
+  %c320 = arith.constant 320 : index
+  %4:2 = scf.forall (%arg0) in (%c320) shared_outs(%arg1 = %2, %arg2 = %3) -> (tensor<320xf32>, tensor<320xf32>) {
+    // CHECK: tensor.extract_slice {{.*}} {__inplace_operands_attr__ = ["false", "none"]}
+    %6 = tensor.extract_slice %arg1[%arg0] [1] [1] : tensor<320xf32> to tensor<1xf32>
+    // CHECK: linalg.fill {__inplace_operands_attr__ = ["none", "true"]}
+    %7 = linalg.fill ins(%cst : f32) outs(%6 : tensor<1xf32>) -> tensor<1xf32>
+
+    %reverse = arith.subi %c320, %arg0 : index
+    // CHECK: tensor.extract_slice {{.*}} {__inplace_operands_attr__ = ["true", "none"]}
+    %8 = tensor.extract_slice %arg1[%reverse] [1] [1] : tensor<320xf32> to tensor<1xf32>
+    scf.forall.in_parallel {
+      // Also cannot operate in place due to subsequent conflicting reads.
+      // CHECK: tensor.parallel_insert_slice {{.*}} {__inplace_operands_attr__ = ["true", "true", "none"]}
+      tensor.parallel_insert_slice %7 into %arg1[%arg0] [1] [1] : tensor<1xf32> into tensor<320xf32>
+      // CHECK: tensor.parallel_insert_slice {{.*}} {__inplace_operands_attr__ = ["true", "true", "none"]}
+      tensor.parallel_insert_slice %8 into %arg2[%reverse] [1] [1] : tensor<1xf32> into tensor<320xf32>
+    }
+  }
+  return %4#0, %4#1 : tensor<320xf32>, tensor<320xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @in_order_multiple_parallel_writes
+func.func @in_order_multiple_parallel_writes(%2: tensor<320xf32> {bufferization.writable = true})
+  -> (tensor<320xf32>)
+{
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant -0.000000e+00 : f32
+  %c320 = arith.constant 320 : index
+  %4 = scf.forall (%arg0) in (%c320) shared_outs(%arg1 = %2) -> (tensor<320xf32>) {
+    // CHECK: tensor.extract_slice {{.*}} {__inplace_operands_attr__ = ["true", "none"]}
+    %6 = tensor.extract_slice %arg1[%arg0] [1] [1] : tensor<320xf32> to tensor<1xf32>
+    %reverse = arith.subi %c320, %arg0 : index
+    // CHECK: tensor.parallel_insert_slice {{.*}} {__inplace_operands_attr__ = ["true", "true", "none"]}
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %6 into %arg1[%reverse] [1] [1] : tensor<1xf32> into tensor<320xf32>
+    }
+  }
+  return %4 : tensor<320xf32>
+}
diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize-tensor-copy-insertion.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize-tensor-copy-insertion.mlir
index 8f4b924cfd3cc..92486b8ed7208 100644
--- a/mlir/test/Dialect/SCF/one-shot-bufferize-tensor-copy-insertion.mlir
+++ b/mlir/test/Dialect/SCF/one-shot-bufferize-tensor-copy-insertion.mlir
@@ -112,7 +112,7 @@ func.func @scf_while_non_equiv_condition_and_body(%A: tensor<5xi1>,
 // CHECK-SAME:      %[[arg0:.*]]: tensor<100xf32>, %[[arg1:.*]]: tensor<100xf32>
 // CHECK-FUNC-LABEL: func @scf_forall_out_of_place(
 func.func @scf_forall_out_of_place(%in: tensor<100xf32>,
-                                  %out: tensor<100xf32>) {
+                                   %out: tensor<100xf32>) {
   %c1 = arith.constant 1 : index
   %num_threads = arith.constant 100 : index

@@ -132,3 +132,31 @@ func.func @scf_forall_out_of_place(%in: tensor<100xf32>,
   } {mapping = [#gpu.thread<x>]}
   return
 }
+
+// -----
+
+// CHECK-LABEL: func @in_order_multiple_parallel_writes
+func.func @in_order_multiple_parallel_writes(%2: tensor<320xf32>,
+                                             %3: tensor<320xf32>)
+  -> (tensor<320xf32>, tensor<320xf32>)
+{
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant -0.000000e+00 : f32
+  %c320 = arith.constant 320 : index
+  %4:2 = scf.forall (%arg0) in (%c320) shared_outs(%arg1 = %2, %arg2 = %3) -> (tensor<320xf32>, tensor<320xf32>) {
+    // CHECK: tensor.extract_slice {{.*}}
+    %6 = tensor.extract_slice %arg1[%arg0] [1] [1] : tensor<320xf32> to tensor<1xf32>
+    // CHECK: tensor.extract_slice {{.*}}
+    %7 = tensor.extract_slice %arg2[%arg0] [1] [1] : tensor<320xf32> to tensor<1xf32>
+    // CHECK: linalg.fill {{.*}}
+    %8 = linalg.fill ins(%cst : f32) outs(%7 : tensor<1xf32>) -> tensor<1xf32>
+
+    // CHECK: tensor.parallel_insert_slice {{.*}}
+    // CHECK: tensor.parallel_insert_slice {{.*}}
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %6 into %arg2[%arg0] [1] [1] : tensor<1xf32> into tensor<320xf32>
+      tensor.parallel_insert_slice %8 into %arg1[%arg0] [1] [1] : tensor<1xf32> into tensor<320xf32>
+    }
+  }
+  return %4#0, %4#1 : tensor<320xf32>, tensor<320xf32>
+}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index cd32e98dac693..540c8b85ecaa4 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -10847,6 +10847,7 @@ cc_library(
         ":LinalgTransformOpsIncGen",
         ":LinalgTransforms",
         ":LinalgUtils",
+        ":ParallelCombiningOpInterface",
        ":SCFDialect",
        ":SCFTransforms",
        ":Support",