Skip to content

Conversation

@Max191
Copy link
Contributor

@Max191 Max191 commented Oct 17, 2024

The transform.structured.fuse_into_containing_op transform op can fuse producers into loops through block arguments, but it relies on having an extract_slice user of the block argument. This PR extends the transform to allow cases where there is no extract_slice user, but there are other users.

@llvmbot
Copy link
Member

llvmbot commented Oct 17, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: None (Max191)

Changes

The transform.structured.fuse_into_containing_op transform op can fuse producers into loops through block arguments, but it relies on having an extract_slice user of the block argument. This PR extends the transform to allow cases where there is no extract_slice user, but there are other users.


Full diff: https://github.com/llvm/llvm-project/pull/112755.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+37-17)
  • (modified) mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir (+40)
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index ad72b5d7beccde..2bc1d5dde6b5d9 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -818,27 +818,23 @@ tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
   // Search the producer slices accessed within the containing operation.
   // TODO: Generalize to more extract/insert/parallel_insert triples, maybe
   // evolve into an interface.
+  if (bbArg.getUsers().empty()) {
+    diag.attachNote(containingOp->getLoc())
+        << "could not find fusion opportunity for bbArg: " << bbArg;
+    return {};
+  }
   auto itBBArgUsers = llvm::find_if(bbArg.getUsers(), [&](Operation *user) {
     auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
     return sliceOp && containingOp->isProperAncestor(sliceOp);
   });
-
-  // Find a fusion opportunity.
+  OpBuilder::InsertionGuard guard(rewriter);
+  tensor::ExtractSliceOp sliceOpToTile;
   if (itBBArgUsers == bbArg.getUsers().end()) {
-    diag.attachNote(containingOp->getLoc())
-        << "could not find fusion opportunity for bbArg: " << bbArg;
-    return {};
+    rewriter.setInsertionPoint(&bbArg.getOwner()->front());
+  } else {
+    sliceOpToTile = llvm::cast<tensor::ExtractSliceOp>(*itBBArgUsers);
+    rewriter.setInsertionPoint(sliceOpToTile);
   }
-  auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*itBBArgUsers);
-
-  // Try to fuse the producer in-place.
-  OpBuilder::InsertionGuard guard(rewriter);
-  rewriter.setInsertionPoint(sliceOpToTile);
-
-  // Replace the use in the tileableProducer before tiling: clone, replace and
-  // then tile.
-  int64_t resultNumber = cast<OpResult>(pUse->get()).getResultNumber();
-  LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n");
 
   // Gather destination tensors.
   SmallVector<Value> destinationTensors;
@@ -850,14 +846,38 @@ tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
     return {};
   }
 
+  // Replace the use in the tileableProducer before tiling: clone, replace and
+  // then tile.
+  SmallVector<Operation *> oldBbArgUsers(bbArg.getUsers());
+  int64_t resultNumber = cast<OpResult>(pUse->get()).getResultNumber();
+  LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n");
   IRMapping bvm;
   bvm.map(destinationTensors[resultNumber], bbArg);
   auto tileableProducerClone =
       cast<TilingInterface>(rewriter.clone(*tileableProducer, bvm));
-  auto scopeGuard =
-      llvm::make_scope_exit([&]() { rewriter.eraseOp(tileableProducerClone); });
+
+  // If there was no extract_slice user, then no need to tile.
+  if (!sliceOpToTile) {
+    LLVM_DEBUG(DBGS() << "No extract_slice user. No need to tile cloned op.\n");
+    // Replace the old uses of bbArg with the cloned op, except for any parallel
+    // insert ops.
+    rewriter.replaceUsesWithIf(
+        bbArg, tileableProducerClone->getResult(resultNumber),
+        [&](OpOperand &operand) {
+          return !isa<tensor::ParallelInsertSliceOp>(operand.getOwner()) &&
+                 operand.getOwner() != tileableProducerClone.getOperation();
+        });
+    // Replace the use in containingOp.
+    rewriter.modifyOpInPlace(containingOp, [&]() {
+      containingOp->setOperand(pUse->getOperandNumber(),
+                               destinationTensors.front());
+    });
+    return {tileableProducerClone};
+  }
 
   // Tile the producer.
+  auto scopeGuard =
+      llvm::make_scope_exit([&]() { rewriter.eraseOp(tileableProducerClone); });
   FailureOr<TilingResult> tileAndFuseResult =
       tileableProducerClone.generateResultTileValue(
           rewriter, resultNumber, sliceOpToTile.getMixedOffsets(),
diff --git a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
index 4115f2857a20c6..5944968487e2e9 100644
--- a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
@@ -202,6 +202,46 @@ module {
 
 // -----
 
+module {
+  // CHECK-LABEL: func.func @fuse_tileable_op_through_bbarg
+  //  CHECK-SAME:   %[[CHUNK_SIZE:[0-9a-z]+]]: index
+  //  CHECK-SAME:   %[[IN:[0-9a-z]+]]: tensor<?xf32>
+  //  CHECK-SAME:   %[[OUT:[0-9a-z]+]]: tensor<?xf32>
+  func.func @fuse_tileable_op_through_bbarg_no_slice(%arg0: index, %arg1: tensor<?xf32>, %arg2: tensor<?xf32>) -> tensor<?xf32> {
+    %cst = arith.constant 4.200000e+01 : f32
+    %c0 = arith.constant 0 : index
+    %d0 = tensor.dim %arg1, %c0 : tensor<?xf32>
+
+    %0 = linalg.fill ins(%cst : f32) outs(%arg2 : tensor<?xf32>) -> tensor<?xf32>
+    // CHECK: scf.forall {{.*}} shared_outs(%[[BBARGOUT:.*]] = %[[OUT]]) -> (tensor<?xf32>) {
+    %1 = scf.forall (%arg3) in (%arg0) shared_outs(%o = %0) -> (tensor<?xf32>) {
+      // CHECK: %[[T0:.*]] = linalg.fill {{.*}} outs(%[[BBARGOUT]]
+
+      // CHECK: %[[T1:.*]] = linalg.elemwise_unary {{.*}} outs(%[[T0]]
+      %2 = linalg.elemwise_unary ins(%arg1 : tensor<?xf32>) outs(%o : tensor<?xf32>) -> tensor<?xf32>
+      scf.forall.in_parallel {
+        tensor.parallel_insert_slice %2 into %o[0] [%d0] [1] : tensor<?xf32> into tensor<?xf32>
+      }
+    }
+    // CHECK: }
+    func.return %1 : tensor<?xf32>
+  }
+
+  module attributes {transform.with_named_sequence} {
+    transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+      %0 = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+      %1 = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+
+      // linalg.fill is tileable. The op is tiled and fused.
+      transform.structured.fuse_into_containing_op %0 into %1
+        : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+        transform.yield
+    }
+  }
+}
+
+// -----
+
 #map0 = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
 #map1 = affine_map<(d0)[s0] -> (d0 * s0)>
 #map2 = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>

@Max191 Max191 force-pushed the transform-fusion-with-no-extract branch from 51ef7f1 to 432341c Compare October 17, 2024 18:04
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants