-
Notifications
You must be signed in to change notification settings - Fork 15.3k
[mlir] Don't require extract_slice in fusion with transform op #112755
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
|
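@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-linalg Author: None (Max191) Changes: The `transform.structured.fuse_into_containing_op` transform op can fuse producers into loops through block arguments, but it relies on having an extract_slice user of the block argument. This PR extends the transform to allow cases where there is no extract_slice user, but there are other users. Full diff: https://github.com/llvm/llvm-project/pull/112755.diff 2 Files Affected: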
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index ad72b5d7beccde..2bc1d5dde6b5d9 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -818,27 +818,23 @@ tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
// Search the producer slices accessed within the containing operation.
// TODO: Generalize to more extract/insert/parallel_insert triples, maybe
// evolve into an interface.
+ if (bbArg.getUsers().empty()) {
+ diag.attachNote(containingOp->getLoc())
+ << "could not find fusion opportunity for bbArg: " << bbArg;
+ return {};
+ }
auto itBBArgUsers = llvm::find_if(bbArg.getUsers(), [&](Operation *user) {
auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
return sliceOp && containingOp->isProperAncestor(sliceOp);
});
-
- // Find a fusion opportunity.
+ OpBuilder::InsertionGuard guard(rewriter);
+ tensor::ExtractSliceOp sliceOpToTile;
if (itBBArgUsers == bbArg.getUsers().end()) {
- diag.attachNote(containingOp->getLoc())
- << "could not find fusion opportunity for bbArg: " << bbArg;
- return {};
+ rewriter.setInsertionPoint(&bbArg.getOwner()->front());
+ } else {
+ sliceOpToTile = llvm::cast<tensor::ExtractSliceOp>(*itBBArgUsers);
+ rewriter.setInsertionPoint(sliceOpToTile);
}
- auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*itBBArgUsers);
-
- // Try to fuse the producer in-place.
- OpBuilder::InsertionGuard guard(rewriter);
- rewriter.setInsertionPoint(sliceOpToTile);
-
- // Replace the use in the tileableProducer before tiling: clone, replace and
- // then tile.
- int64_t resultNumber = cast<OpResult>(pUse->get()).getResultNumber();
- LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n");
// Gather destination tensors.
SmallVector<Value> destinationTensors;
@@ -850,14 +846,38 @@ tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
return {};
}
+ // Replace the use in the tileableProducer before tiling: clone, replace and
+ // then tile.
+ SmallVector<Operation *> oldBbArgUsers(bbArg.getUsers());
+ int64_t resultNumber = cast<OpResult>(pUse->get()).getResultNumber();
+ LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n");
IRMapping bvm;
bvm.map(destinationTensors[resultNumber], bbArg);
auto tileableProducerClone =
cast<TilingInterface>(rewriter.clone(*tileableProducer, bvm));
- auto scopeGuard =
- llvm::make_scope_exit([&]() { rewriter.eraseOp(tileableProducerClone); });
+
+ // If there was no extract_slice user, then no need to tile.
+ if (!sliceOpToTile) {
+ LLVM_DEBUG(DBGS() << "No extract_slice user. No need to tile cloned op.\n");
+ // Replace the old uses of bbArg with the cloned op, except for any parallel
+ // insert ops.
+ rewriter.replaceUsesWithIf(
+ bbArg, tileableProducerClone->getResult(resultNumber),
+ [&](OpOperand &operand) {
+ return !isa<tensor::ParallelInsertSliceOp>(operand.getOwner()) &&
+ operand.getOwner() != tileableProducerClone.getOperation();
+ });
+ // Replace the use in containingOp.
+ rewriter.modifyOpInPlace(containingOp, [&]() {
+ containingOp->setOperand(pUse->getOperandNumber(),
+ destinationTensors.front());
+ });
+ return {tileableProducerClone};
+ }
// Tile the producer.
+ auto scopeGuard =
+ llvm::make_scope_exit([&]() { rewriter.eraseOp(tileableProducerClone); });
FailureOr<TilingResult> tileAndFuseResult =
tileableProducerClone.generateResultTileValue(
rewriter, resultNumber, sliceOpToTile.getMixedOffsets(),
diff --git a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
index 4115f2857a20c6..5944968487e2e9 100644
--- a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
@@ -202,6 +202,46 @@ module {
// -----
+module {
+ // CHECK-LABEL: func.func @fuse_tileable_op_through_bbarg
+ // CHECK-SAME: %[[CHUNK_SIZE:[0-9a-z]+]]: index
+ // CHECK-SAME: %[[IN:[0-9a-z]+]]: tensor<?xf32>
+ // CHECK-SAME: %[[OUT:[0-9a-z]+]]: tensor<?xf32>
+ func.func @fuse_tileable_op_through_bbarg_no_slice(%arg0: index, %arg1: tensor<?xf32>, %arg2: tensor<?xf32>) -> tensor<?xf32> {
+ %cst = arith.constant 4.200000e+01 : f32
+ %c0 = arith.constant 0 : index
+ %d0 = tensor.dim %arg1, %c0 : tensor<?xf32>
+
+ %0 = linalg.fill ins(%cst : f32) outs(%arg2 : tensor<?xf32>) -> tensor<?xf32>
+ // CHECK: scf.forall {{.*}} shared_outs(%[[BBARGOUT:.*]] = %[[OUT]]) -> (tensor<?xf32>) {
+ %1 = scf.forall (%arg3) in (%arg0) shared_outs(%o = %0) -> (tensor<?xf32>) {
+ // CHECK: %[[T0:.*]] = linalg.fill {{.*}} outs(%[[BBARGOUT]]
+
+ // CHECK: %[[T1:.*]] = linalg.elemwise_unary {{.*}} outs(%[[T0]]
+ %2 = linalg.elemwise_unary ins(%arg1 : tensor<?xf32>) outs(%o : tensor<?xf32>) -> tensor<?xf32>
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %2 into %o[0] [%d0] [1] : tensor<?xf32> into tensor<?xf32>
+ }
+ }
+ // CHECK: }
+ func.return %1 : tensor<?xf32>
+ }
+
+ module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+
+ // linalg.fill is tileable, but the bbArg has no extract_slice user here, so the op is cloned into the loop and fused without tiling.
+ transform.structured.fuse_into_containing_op %0 into %1
+ : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+ }
+}
+
+// -----
+
#map0 = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
#map2 = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>
|
Signed-off-by: Max Dawkins <[email protected]>
51ef7f1 to
432341c
Compare
The `transform.structured.fuse_into_containing_op` transform op can fuse producers into loops through block arguments, but it relies on having an extract_slice user of the block argument. This PR extends the transform to allow cases where there is no extract_slice user, but there are other users.