Skip to content

Commit 45af50d

Browse files
authored
[MLIR] Add fusability query to TilingInterface (#166502)
This introduces the `isOpFusableWithProducerSlices` and `isOpFusableWithConsumerSlice` methods to the TilingInterface. They make it possible to query, without generating any IR, whether a tilable op can be fused with a given set of producer slices or with a given consumer slice. This is needed to use the tiling interface in pattern rewrites: without such a query, a pattern that invokes the tiling methods may generate IR and then fail, because those methods are allowed to do so.
1 parent b32e067 commit 45af50d

File tree

6 files changed

+285
-4
lines changed

6 files changed

+285
-4
lines changed

mlir/include/mlir/Interfaces/TilingInterface.td

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,43 @@ def TilingInterface : OpInterface<"TilingInterface"> {
360360
/*defaultImplementation=*/[{
361361
return failure();
362362
}]
363+
>,
364+
//===------------------------------------------------------------------===//
365+
// Interface methods for querying fusability.
366+
//===------------------------------------------------------------------===//
367+
InterfaceMethod<
368+
/*desc=*/[{
369+
Indicates whether it is possible to fuse this operation with the given
370+
result slice. This method is not allowed to generate any IR.
371+
}],
372+
/*retTy=*/"bool",
373+
/*methodName=*/"isOpFusableWithConsumerSlice",
374+
/*args=*/(ins
375+
"unsigned":$resultNumber,
376+
"::mlir::ArrayRef<::mlir::OpFoldResult>":$offsets,
377+
"::mlir::ArrayRef<::mlir::OpFoldResult>":$sizes
378+
),
379+
/*methodBody=*/"",
380+
/*defaultImplementation=*/[{
381+
return false;
382+
}]
383+
>,
384+
InterfaceMethod<
385+
/*desc=*/[{
386+
Indicates whether it is possible to fuse this operation with the given
387+
list of operand slices. This method is not allowed to generate any IR.
388+
}],
389+
/*retTy=*/"bool",
390+
/*methodName=*/"isOpFusableWithProducerSlices",
391+
/*args=*/(ins
392+
"::mlir::ArrayRef<unsigned>":$operandNumbers,
393+
"::mlir::ArrayRef<::mlir::SmallVector<::mlir::OpFoldResult>>":$allOffsets,
394+
"::mlir::ArrayRef<::mlir::SmallVector<::mlir::OpFoldResult>>":$allSizes
395+
),
396+
/*methodBody=*/"",
397+
/*defaultImplementation=*/[{
398+
return false;
399+
}]
363400
>
364401
];
365402
}

mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -452,8 +452,7 @@ tileLinalgOpImpl(RewriterBase &b, LinalgOp op, ArrayRef<OpFoldResult> tileSizes,
452452
SmallVector<OpFoldResult> allShapeSizes =
453453
op.createFlatListOfOperandDims(b, op.getLoc());
454454
AffineMap shapeSizesToLoopsMap = op.getShapesToLoopsMap();
455-
if (!shapeSizesToLoopsMap)
456-
return failure();
455+
assert(shapeSizesToLoopsMap && "invalid linalgOp with null ShapesToLoopsMap");
457456

458457
auto [loopRanges, loopIndexToRangeIndex] = makeTiledLoopRanges(
459458
b, op.getLoc(), shapeSizesToLoopsMap, allShapeSizes, tileSizes);

mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ struct LinalgOpTilingInterface
167167
llvm::zip_equal(indexingMap.getResults(), offsets, sizes)) {
168168
auto dimExpr = dyn_cast<AffineDimExpr>(resultExpr);
169169
if (!dimExpr)
170-
continue;
170+
return failure();
171171
unsigned position = dimExpr.getPosition();
172172
auto it = mappedOffsets.find(position);
173173
if (it != mappedOffsets.end()) {
@@ -357,6 +357,32 @@ struct LinalgOpTilingInterface
357357
/// Inline the op payload and store the result.
358358
return inlinePayload(builder, linalgOp, ivs, indexedValues);
359359
}
360+
361+
bool isOpFusableWithConsumerSlice(Operation *op, unsigned resultNumber,
362+
ArrayRef<OpFoldResult> offsets,
363+
ArrayRef<OpFoldResult> sizes) const {
364+
// The op verifier already enforces all the requirements for consumer fusion.
365+
return true;
366+
}
367+
368+
bool isOpFusableWithProducerSlices(
369+
Operation *op, ArrayRef<unsigned> operandNumbers,
370+
ArrayRef<SmallVector<OpFoldResult>> allOffsets,
371+
ArrayRef<SmallVector<OpFoldResult>> allSizes) const {
372+
373+
auto linalgOp = cast<LinalgOp>(op);
374+
SmallVector<AffineMap> indexingMaps =
375+
llvm::map_to_vector(operandNumbers, [&](unsigned operandNumber) {
376+
OpOperand &opOperand = linalgOp->getOpOperand(operandNumber);
377+
return linalgOp.getMatchingIndexingMap(&opOperand);
378+
});
379+
// Check that offsets/sizes are consistent across all operands.
380+
OpBuilder b(op);
381+
SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
382+
return succeeded(getMappedOffsetAndSize(linalgOp, b, indexingMaps,
383+
allOffsets, allSizes, mappedOffsets,
384+
mappedSizes));
385+
}
360386
};
361387

362388
//===----------------------------------------------------------------------===//
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
// RUN: mlir-opt %s --transform-interpreter --split-input-file --verify-diagnostics
2+
3+
func.func @fusable_with_matching_offsets(%arg0: tensor<10x20xf32>, %arg1: tensor<10x20xf32>, %dest: tensor<100x200xf32>) -> tensor<100x200xf32> {
4+
%c0 = arith.constant 0 : index
5+
%c10 = arith.constant 10 : index
6+
%c20 = arith.constant 20 : index
7+
8+
%slice0 = tensor.insert_slice %arg0 into %dest[%c0, %c0] [10, 20] [1, 1] : tensor<10x20xf32> into tensor<100x200xf32>
9+
%slice1 = tensor.insert_slice %arg1 into %dest[%c0, %c0] [10, 20] [1, 1] : tensor<10x20xf32> into tensor<100x200xf32>
10+
11+
// expected-remark @+1 {{can be fused with producer tensor.insert_slice ops}}
12+
%result = linalg.add ins(%slice0, %slice1 : tensor<100x200xf32>, tensor<100x200xf32>)
13+
outs(%dest : tensor<100x200xf32>) -> tensor<100x200xf32>
14+
15+
return %result : tensor<100x200xf32>
16+
}
17+
18+
module attributes {transform.with_named_sequence} {
19+
transform.named_sequence @__transform_main(%arg: !transform.any_op) {
20+
%add = transform.structured.match ops{["linalg.add"]} in %arg : (!transform.any_op) -> !transform.any_op
21+
transform.test.query_producer_fusability %add : !transform.any_op
22+
transform.yield
23+
}
24+
}
25+
26+
// -----
27+
28+
func.func @not_fusable_with_different_offsets(%arg0: tensor<10x20xf32>, %arg1: tensor<10x20xf32>, %dest: tensor<100x200xf32>) -> tensor<100x200xf32> {
29+
%c0 = arith.constant 0 : index
30+
%c10 = arith.constant 10 : index
31+
%c20 = arith.constant 20 : index
32+
33+
%slice0 = tensor.insert_slice %arg0 into %dest[%c0, %c0] [10, 20] [1, 1] : tensor<10x20xf32> into tensor<100x200xf32>
34+
%slice1 = tensor.insert_slice %arg1 into %dest[%c10, %c20] [10, 20] [1, 1] : tensor<10x20xf32> into tensor<100x200xf32>
35+
36+
// expected-remark @+1 {{cannot be fused with producer tensor.insert_slice ops}}
37+
%result = linalg.add ins(%slice0, %slice1 : tensor<100x200xf32>, tensor<100x200xf32>)
38+
outs(%dest : tensor<100x200xf32>) -> tensor<100x200xf32>
39+
40+
return %result : tensor<100x200xf32>
41+
}
42+
43+
module attributes {transform.with_named_sequence} {
44+
transform.named_sequence @__transform_main(%arg: !transform.any_op) {
45+
%add = transform.structured.match ops{["linalg.add"]} in %arg : (!transform.any_op) -> !transform.any_op
46+
transform.test.query_producer_fusability %add : !transform.any_op
47+
transform.yield
48+
}
49+
}
50+
51+
// -----
52+
53+
func.func @fusable_with_consumer_extract_slice(%arg0: tensor<100x200xf32>, %arg1: tensor<100x200xf32>, %dest: tensor<100x200xf32>) -> tensor<10x20xf32> {
54+
// expected-remark @+1 {{can be fused with consumer tensor.extract_slice op}}
55+
%add = linalg.add ins(%arg0, %arg1 : tensor<100x200xf32>, tensor<100x200xf32>)
56+
outs(%dest : tensor<100x200xf32>) -> tensor<100x200xf32>
57+
58+
%c0 = arith.constant 0 : index
59+
%slice = tensor.extract_slice %add[%c0, %c0] [10, 20] [1, 1] : tensor<100x200xf32> to tensor<10x20xf32>
60+
61+
return %slice : tensor<10x20xf32>
62+
}
63+
64+
module attributes {transform.with_named_sequence} {
65+
transform.named_sequence @__transform_main(%arg: !transform.any_op) {
66+
%add = transform.structured.match ops{["linalg.add"]} in %arg : (!transform.any_op) -> !transform.any_op
67+
transform.test.query_consumer_fusability %add : !transform.any_op
68+
transform.yield
69+
}
70+
}

mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "mlir/Dialect/Index/IR/IndexDialect.h"
1616
#include "mlir/Dialect/SCF/IR/SCF.h"
1717
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
18+
#include "mlir/Dialect/Tensor/IR/Tensor.h"
1819
#include "mlir/Dialect/Transform/IR/TransformAttrs.h"
1920
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
2021
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
@@ -683,6 +684,110 @@ DiagnosedSilenceableFailure transform::TestTileUsingCustomLoopOp::apply(
683684
return DiagnosedSilenceableFailure::success();
684685
}
685686

687+
//===----------------------------------------------------------------------===//
688+
// TestQueryProducerFusability
689+
//===----------------------------------------------------------------------===//
690+
691+
DiagnosedSilenceableFailure transform::TestQueryProducerFusability::apply(
692+
TransformRewriter &rewriter, TransformResults &transformResults,
693+
TransformState &state) {
694+
for (Operation *target : state.getPayloadOps(getTarget())) {
695+
auto tilingInterfaceOp = dyn_cast<TilingInterface>(target);
696+
if (!tilingInterfaceOp) {
697+
return emitSilenceableError()
698+
<< "target operation does not implement TilingInterface";
699+
}
700+
701+
// Collect operand numbers and their corresponding producer insert_slice
702+
// offsets and sizes.
703+
SmallVector<unsigned> operandNumbers;
704+
SmallVector<SmallVector<OpFoldResult>> allOffsets;
705+
SmallVector<SmallVector<OpFoldResult>> allSizes;
706+
707+
for (OpOperand &operand : target->getOpOperands()) {
708+
Value operandValue = operand.get();
709+
Operation *definingOp = operandValue.getDefiningOp();
710+
711+
// Look for a producer tensor.insert_slice. This is only for testing
712+
// purposes and otherwise is not a useful transformation.
713+
if (auto insertSliceOp =
714+
dyn_cast_or_null<tensor::InsertSliceOp>(definingOp)) {
715+
operandNumbers.push_back(operand.getOperandNumber());
716+
allOffsets.push_back(insertSliceOp.getMixedOffsets());
717+
allSizes.push_back(insertSliceOp.getMixedSizes());
718+
}
719+
}
720+
721+
if (!operandNumbers.empty()) {
722+
bool isFusable = tilingInterfaceOp.isOpFusableWithProducerSlices(
723+
operandNumbers, allOffsets, allSizes);
724+
725+
if (isFusable) {
726+
target->emitRemark()
727+
<< "can be fused with producer tensor.insert_slice ops";
728+
} else {
729+
target->emitRemark()
730+
<< "cannot be fused with producer tensor.insert_slice ops";
731+
}
732+
}
733+
}
734+
735+
return DiagnosedSilenceableFailure::success();
736+
}
737+
738+
void transform::TestQueryProducerFusability::getEffects(
739+
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
740+
onlyReadsHandle(getTargetMutable(), effects);
741+
onlyReadsPayload(effects);
742+
}
743+
744+
//===----------------------------------------------------------------------===//
745+
// TestQueryConsumerFusability
746+
//===----------------------------------------------------------------------===//
747+
748+
DiagnosedSilenceableFailure transform::TestQueryConsumerFusability::apply(
749+
TransformRewriter &rewriter, TransformResults &transformResults,
750+
TransformState &state) {
751+
for (Operation *target : state.getPayloadOps(getTarget())) {
752+
auto tilingInterfaceOp = dyn_cast<TilingInterface>(target);
753+
if (!tilingInterfaceOp) {
754+
return emitSilenceableError()
755+
<< "target operation does not implement TilingInterface";
756+
}
757+
758+
// Look for tensor.extract_slice ops that consume results of the tilable op.
759+
for (OpResult result : target->getResults()) {
760+
for (OpOperand &use : result.getUses()) {
761+
Operation *user = use.getOwner();
762+
763+
// Look for a consumer tensor.extract_slice. This is only for testing
764+
// purposes and otherwise is not a useful transformation.
765+
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(user)) {
766+
bool isFusable = tilingInterfaceOp.isOpFusableWithConsumerSlice(
767+
result.getResultNumber(), extractSliceOp.getMixedOffsets(),
768+
extractSliceOp.getMixedSizes());
769+
770+
if (isFusable) {
771+
target->emitRemark()
772+
<< "can be fused with consumer tensor.extract_slice op";
773+
} else {
774+
target->emitRemark()
775+
<< "cannot be fused with consumer tensor.extract_slice op";
776+
}
777+
}
778+
}
779+
}
780+
}
781+
782+
return DiagnosedSilenceableFailure::success();
783+
}
784+
785+
void transform::TestQueryConsumerFusability::getEffects(
786+
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
787+
onlyReadsHandle(getTargetMutable(), effects);
788+
onlyReadsPayload(effects);
789+
}
790+
686791
#define GET_OP_CLASSES
687792
#include "TestTilingInterfaceTransformOps.cpp.inc"
688793

mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,11 +197,55 @@ def TestTileUsingCustomLoopOp : Op<
197197
DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_sizes);
198198
let results = (outs TransformHandleTypeInterface:$tiled_ops,
199199
Variadic<TransformHandleTypeInterface>:$loops);
200-
200+
201201
let assemblyFormat = [{
202202
$root_op `tile_sizes` `=` $tile_sizes
203203
attr-dict `:` functional-type(operands, results)
204204
}];
205205
}
206206

207+
def TestQueryProducerFusability : Op<
208+
Transform_Dialect, "test.query_producer_fusability",
209+
[DeclareOpInterfaceMethods<TransformOpInterface>,
210+
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
211+
let description = [{
212+
Test operation for the producer fusability query method in the
213+
TilingInterface.
214+
215+
For each operation in the target handle, this looks for tensor.insert_slice
216+
ops that produce operands to the tilable op. The offsets/sizes from those
217+
inserts are used as the arguments to `isOpFusableWithProducerSlices`, and a
218+
remark is emitted with the result of the query.
219+
}];
220+
221+
let arguments = (ins TransformHandleTypeInterface:$target);
222+
let results = (outs);
223+
224+
let assemblyFormat = [{
225+
$target attr-dict `:` type($target)
226+
}];
227+
}
228+
229+
def TestQueryConsumerFusability
230+
: Op<Transform_Dialect, "test.query_consumer_fusability",
231+
[DeclareOpInterfaceMethods<TransformOpInterface>,
232+
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
233+
let description = [{
234+
Test operation for the consumer fusability query method in the
235+
TilingInterface.
236+
237+
For each operation in the target handle, this looks for tensor.extract_slice
238+
ops that consume results of the tilable op. The offsets/sizes from those
239+
extracts are used as the arguments to `isOpFusableWithConsumerSlice`, and a
240+
remark is emitted with the result of the query.
241+
}];
242+
243+
let arguments = (ins TransformHandleTypeInterface:$target);
244+
let results = (outs);
245+
246+
let assemblyFormat = [{
247+
$target attr-dict `:` type($target)
248+
}];
249+
}
250+
207251
#endif // TEST_TILINGINTERFACE_TRANSFORM_OPS

0 commit comments

Comments
 (0)