diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index ad72b5d7beccd..2bc1d5dde6b5d 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -818,27 +818,23 @@ tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
   // Search the producer slices accessed within the containing operation.
   // TODO: Generalize to more extract/insert/parallel_insert triples, maybe
   // evolve into an interface.
+  if (bbArg.getUsers().empty()) {
+    diag.attachNote(containingOp->getLoc())
+        << "could not find fusion opportunity for bbArg: " << bbArg;
+    return {};
+  }
   auto itBBArgUsers = llvm::find_if(bbArg.getUsers(), [&](Operation *user) {
     auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
     return sliceOp && containingOp->isProperAncestor(sliceOp);
   });
-
-  // Find a fusion opportunity.
+  OpBuilder::InsertionGuard guard(rewriter);
+  tensor::ExtractSliceOp sliceOpToTile;
   if (itBBArgUsers == bbArg.getUsers().end()) {
-    diag.attachNote(containingOp->getLoc())
-        << "could not find fusion opportunity for bbArg: " << bbArg;
-    return {};
+    rewriter.setInsertionPoint(&bbArg.getOwner()->front());
+  } else {
+    sliceOpToTile = llvm::cast<tensor::ExtractSliceOp>(*itBBArgUsers);
+    rewriter.setInsertionPoint(sliceOpToTile);
   }
-  auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*itBBArgUsers);
-
-  // Try to fuse the producer in-place.
-  OpBuilder::InsertionGuard guard(rewriter);
-  rewriter.setInsertionPoint(sliceOpToTile);
-
-  // Replace the use in the tileableProducer before tiling: clone, replace and
-  // then tile.
-  int64_t resultNumber = cast<OpResult>(pUse->get()).getResultNumber();
-  LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n");
 
   // Gather destination tensors.
   SmallVector<Value> destinationTensors;
@@ -850,14 +846,38 @@ tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
     return {};
   }
 
+  // Replace the use in the tileableProducer before tiling: clone, replace and
+  // then tile.
+  SmallVector<Operation *> oldBbArgUsers(bbArg.getUsers());
+  int64_t resultNumber = cast<OpResult>(pUse->get()).getResultNumber();
+  LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n");
   IRMapping bvm;
   bvm.map(destinationTensors[resultNumber], bbArg);
   auto tileableProducerClone =
       cast<TilingInterface>(rewriter.clone(*tileableProducer, bvm));
-  auto scopeGuard =
-      llvm::make_scope_exit([&]() { rewriter.eraseOp(tileableProducerClone); });
+
+  // If there was no extract_slice user, then no need to tile.
+  if (!sliceOpToTile) {
+    LLVM_DEBUG(DBGS() << "No extract_slice user. No need to tile cloned op.\n");
+    // Replace the old uses of bbArg with the cloned op, except for any parallel
+    // insert ops.
+    rewriter.replaceUsesWithIf(
+        bbArg, tileableProducerClone->getResult(resultNumber),
+        [&](OpOperand &operand) {
+          return !isa<tensor::ParallelInsertSliceOp>(operand.getOwner()) &&
+                 operand.getOwner() != tileableProducerClone.getOperation();
+        });
+    // Replace the use in containingOp.
+    rewriter.modifyOpInPlace(containingOp, [&]() {
+      containingOp->setOperand(pUse->getOperandNumber(),
+                               destinationTensors.front());
+    });
+    return {tileableProducerClone};
+  }
 
   // Tile the producer.
+  auto scopeGuard =
+      llvm::make_scope_exit([&]() { rewriter.eraseOp(tileableProducerClone); });
   FailureOr<TilingResult> tileAndFuseResult =
       tileableProducerClone.generateResultTileValue(
           rewriter, resultNumber, sliceOpToTile.getMixedOffsets(),
diff --git a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
index 4115f2857a20c..c0c7b8ec9598b 100644
--- a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
@@ -202,6 +202,46 @@ module {
 
 // -----
 
+module {
+  // CHECK-LABEL: func.func @fuse_tileable_op_through_bbarg_no_slice
+  // CHECK-SAME:    %[[CHUNK_SIZE:[0-9a-z]+]]: index
+  // CHECK-SAME:    %[[IN:[0-9a-z]+]]: tensor<?xf32>
+  // CHECK-SAME:    %[[OUT:[0-9a-z]+]]: tensor<?xf32>
+  func.func @fuse_tileable_op_through_bbarg_no_slice(%arg0: index, %arg1: tensor<?xf32>, %arg2: tensor<?xf32>) -> tensor<?xf32> {
+    %cst = arith.constant 4.200000e+01 : f32
+    %c0 = arith.constant 0 : index
+    %d0 = tensor.dim %arg1, %c0 : tensor<?xf32>
+
+    %0 = linalg.fill ins(%cst : f32) outs(%arg2 : tensor<?xf32>) -> tensor<?xf32>
+    // CHECK: scf.forall {{.*}} shared_outs(%[[BBARGOUT:.*]] = %[[OUT]]) -> (tensor<?xf32>) {
+    %1 = scf.forall (%arg3) in (%arg0) shared_outs(%o = %0) -> (tensor<?xf32>) {
+      // CHECK: %[[T0:.*]] = linalg.fill {{.*}} outs(%[[BBARGOUT]]
+
+      // CHECK: %[[T1:.*]] = linalg.elemwise_unary {{.*}} outs(%[[T0]]
+      %2 = linalg.elemwise_unary ins(%arg1 : tensor<?xf32>) outs(%o : tensor<?xf32>) -> tensor<?xf32>
+      scf.forall.in_parallel {
+        tensor.parallel_insert_slice %2 into %o[0] [%d0] [1] : tensor<?xf32> into tensor<?xf32>
+      }
+    }
+    // CHECK: }
+    func.return %1 : tensor<?xf32>
+  }
+
+  module attributes {transform.with_named_sequence} {
+    transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+      %0 = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+      %1 = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+
+      // linalg.fill is tileable. The op is tiled and fused.
+      transform.structured.fuse_into_containing_op %0 into %1
+        : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+      transform.yield
+    }
+  }
+}
+
+// -----
+
 #map0 = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
 #map1 = affine_map<(d0)[s0] -> (d0 * s0)>
 #map2 = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>