Skip to content

Commit 432341c

Browse files
committed
[mlir] Don't require extract_slice in fusion with transform op
Signed-off-by: Max Dawkins <[email protected]>
1 parent 87645e9 commit 432341c

File tree

2 files changed: +77 additions, −17 deletions

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 37 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -818,27 +818,23 @@ tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
818818
// Search the producer slices accessed within the containing operation.
819819
// TODO: Generalize to more extract/insert/parallel_insert triples, maybe
820820
// evolve into an interface.
821+
if (bbArg.getUsers().empty()) {
822+
diag.attachNote(containingOp->getLoc())
823+
<< "could not find fusion opportunity for bbArg: " << bbArg;
824+
return {};
825+
}
821826
auto itBBArgUsers = llvm::find_if(bbArg.getUsers(), [&](Operation *user) {
822827
auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
823828
return sliceOp && containingOp->isProperAncestor(sliceOp);
824829
});
825-
826-
// Find a fusion opportunity.
830+
OpBuilder::InsertionGuard guard(rewriter);
831+
tensor::ExtractSliceOp sliceOpToTile;
827832
if (itBBArgUsers == bbArg.getUsers().end()) {
828-
diag.attachNote(containingOp->getLoc())
829-
<< "could not find fusion opportunity for bbArg: " << bbArg;
830-
return {};
833+
rewriter.setInsertionPoint(&bbArg.getOwner()->front());
834+
} else {
835+
sliceOpToTile = llvm::cast<tensor::ExtractSliceOp>(*itBBArgUsers);
836+
rewriter.setInsertionPoint(sliceOpToTile);
831837
}
832-
auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*itBBArgUsers);
833-
834-
// Try to fuse the producer in-place.
835-
OpBuilder::InsertionGuard guard(rewriter);
836-
rewriter.setInsertionPoint(sliceOpToTile);
837-
838-
// Replace the use in the tileableProducer before tiling: clone, replace and
839-
// then tile.
840-
int64_t resultNumber = cast<OpResult>(pUse->get()).getResultNumber();
841-
LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n");
842838

843839
// Gather destination tensors.
844840
SmallVector<Value> destinationTensors;
@@ -850,14 +846,38 @@ tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
850846
return {};
851847
}
852848

849+
// Replace the use in the tileableProducer before tiling: clone, replace and
850+
// then tile.
851+
SmallVector<Operation *> oldBbArgUsers(bbArg.getUsers());
852+
int64_t resultNumber = cast<OpResult>(pUse->get()).getResultNumber();
853+
LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n");
853854
IRMapping bvm;
854855
bvm.map(destinationTensors[resultNumber], bbArg);
855856
auto tileableProducerClone =
856857
cast<TilingInterface>(rewriter.clone(*tileableProducer, bvm));
857-
auto scopeGuard =
858-
llvm::make_scope_exit([&]() { rewriter.eraseOp(tileableProducerClone); });
858+
859+
// If there was no extract_slice user, then no need to tile.
860+
if (!sliceOpToTile) {
861+
LLVM_DEBUG(DBGS() << "No extract_slice user. No need to tile cloned op.\n");
862+
// Replace the old uses of bbArg with the cloned op, except for any parallel
863+
// insert ops.
864+
rewriter.replaceUsesWithIf(
865+
bbArg, tileableProducerClone->getResult(resultNumber),
866+
[&](OpOperand &operand) {
867+
return !isa<tensor::ParallelInsertSliceOp>(operand.getOwner()) &&
868+
operand.getOwner() != tileableProducerClone.getOperation();
869+
});
870+
// Replace the use in containingOp.
871+
rewriter.modifyOpInPlace(containingOp, [&]() {
872+
containingOp->setOperand(pUse->getOperandNumber(),
873+
destinationTensors.front());
874+
});
875+
return {tileableProducerClone};
876+
}
859877

860878
// Tile the producer.
879+
auto scopeGuard =
880+
llvm::make_scope_exit([&]() { rewriter.eraseOp(tileableProducerClone); });
861881
FailureOr<TilingResult> tileAndFuseResult =
862882
tileableProducerClone.generateResultTileValue(
863883
rewriter, resultNumber, sliceOpToTile.getMixedOffsets(),

mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir

Lines changed: 40 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -202,6 +202,46 @@ module {
202202

203203
// -----
204204

205+
module {
206+
// CHECK-LABEL: func.func @fuse_tileable_op_through_bbarg_no_slice
207+
// CHECK-SAME: %[[CHUNK_SIZE:[0-9a-z]+]]: index
208+
// CHECK-SAME: %[[IN:[0-9a-z]+]]: tensor<?xf32>
209+
// CHECK-SAME: %[[OUT:[0-9a-z]+]]: tensor<?xf32>
210+
func.func @fuse_tileable_op_through_bbarg_no_slice(%arg0: index, %arg1: tensor<?xf32>, %arg2: tensor<?xf32>) -> tensor<?xf32> {
211+
%cst = arith.constant 4.200000e+01 : f32
212+
%c0 = arith.constant 0 : index
213+
%d0 = tensor.dim %arg1, %c0 : tensor<?xf32>
214+
215+
%0 = linalg.fill ins(%cst : f32) outs(%arg2 : tensor<?xf32>) -> tensor<?xf32>
216+
// CHECK: scf.forall {{.*}} shared_outs(%[[BBARGOUT:.*]] = %[[OUT]]) -> (tensor<?xf32>) {
217+
%1 = scf.forall (%arg3) in (%arg0) shared_outs(%o = %0) -> (tensor<?xf32>) {
218+
// CHECK: %[[T0:.*]] = linalg.fill {{.*}} outs(%[[BBARGOUT]]
219+
220+
// CHECK: %[[T1:.*]] = linalg.elemwise_unary {{.*}} outs(%[[T0]]
221+
%2 = linalg.elemwise_unary ins(%arg1 : tensor<?xf32>) outs(%o : tensor<?xf32>) -> tensor<?xf32>
222+
scf.forall.in_parallel {
223+
tensor.parallel_insert_slice %2 into %o[0] [%d0] [1] : tensor<?xf32> into tensor<?xf32>
224+
}
225+
}
226+
// CHECK: }
227+
func.return %1 : tensor<?xf32>
228+
}
229+
230+
module attributes {transform.with_named_sequence} {
231+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
232+
%0 = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.any_op
233+
%1 = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.any_op
234+
235+
// linalg.fill is tileable. The op is tiled and fused.
236+
transform.structured.fuse_into_containing_op %0 into %1
237+
: (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
238+
transform.yield
239+
}
240+
}
241+
}
242+
243+
// -----
244+
205245
#map0 = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
206246
#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
207247
#map2 = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>

0 commit comments

Comments
 (0)