Skip to content

Commit d5f0969

Browse files
[mlir][TilingInterface] Avoid looking at operands for getting slices to continue tile + fuse. (#107882)
Current implementation of `scf::tileConsumerAndFuseProducerUsingSCF` looks at operands of tiled/tiled+fused operations to see if they are produced by `extract_slice` operations to populate the worklist used to continue fusion. This implicit assumption does not always work. Instead make the implementations of `getTiledImplementation` return the slices to use to continue fusion. This is a breaking change - To continue to get the same behavior of `scf::tileConsumerAndFuseProducerUsingSCF`, change all out-of-tree implementation of `TilingInterface::getTiledImplementation` to return the slices to continue fusion on. All in-tree implementations have been adapted to this. - This change touches parts that required a simplification to the `ControlFn` in `scf::SCFTileAndFuseOptions`. It now returns a `std::optional<scf::SCFTileAndFuseOptions::ControlFnResult>` object that should be `std::nullopt` if fusion is not to be performed. Signed-off-by: MaheshRavishankar <[email protected]>
1 parent 1211d97 commit d5f0969

File tree

10 files changed

+271
-129
lines changed

10 files changed

+271
-129
lines changed

mlir/include/mlir/Dialect/Linalg/Utils/Utils.h

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -178,11 +178,12 @@ computeAllSliceParameters(OpBuilder &builder, Location loc, LinalgOp linalgOp,
178178
/// at offsets `lbs` and with sizes `subShapeSizes`. `omitPartialTileCheck`
179179
/// controls whether to omit the partial/boundary tile condition check in
180180
/// cases where we statically know that it is unnecessary.
181-
Value makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
182-
ArrayRef<OpFoldResult> tileSizes, AffineMap map,
183-
ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
184-
ArrayRef<OpFoldResult> subShapeSizes,
185-
bool omitPartialTileCheck);
181+
Operation *makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
182+
ArrayRef<OpFoldResult> tileSizes, AffineMap map,
183+
ArrayRef<OpFoldResult> lbs,
184+
ArrayRef<OpFoldResult> ubs,
185+
ArrayRef<OpFoldResult> subShapeSizes,
186+
bool omitPartialTileCheck);
186187

187188
/// Creates extract_slice/subview ops for all `valuesToTile` of the given
188189
/// `linalgOp` with `builder`, assuming `linalgOp` is being fused into a loop

mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,9 @@ struct SCFTilingResult {
106106
/// Values to use as replacements for the untiled op. Is the same size as the
107107
/// number of results of the untiled op.
108108
SmallVector<Value> replacements;
109+
/// Slices generated after tiling that can be used for fusing with the tiled
110+
/// producer.
111+
SmallVector<Operation *> generatedSlices;
109112
};
110113

111114
/// Method to tile an op that implements the `TilingInterface` using
@@ -129,18 +132,22 @@ struct SCFTileAndFuseOptions {
129132
/// 2) the producer value that is to be fused
130133
/// 3) a boolean value set to `true` if the fusion is from
131134
/// a destination operand.
132-
/// It retuns two booleans
133-
/// - returns `true` if the fusion should be done through the candidate slice
134-
/// - returns `true` if a replacement for the fused producer needs to be
135-
/// yielded from within the tiled loop. Note that it is valid to return
136-
/// `true` only if the slice fused is disjoint across all iterations of the
137-
/// tiled loop. It is up to the caller to ensure that this is true for the
138-
/// fused producers.
139-
using ControlFnTy = std::function<std::tuple<bool, bool>(
135+
/// The control function returns an `std::optiona<ControlFnResult>`.
136+
/// If the return value is `std::nullopt`, that implies no fusion
137+
/// is to be performed along that slice.
138+
struct ControlFnResult {
139+
/// Set to true if the loop nest has to return a replacement value
140+
/// for the fused producer.
141+
bool yieldProducerReplacement = false;
142+
};
143+
using ControlFnTy = std::function<std::optional<ControlFnResult>(
140144
tensor::ExtractSliceOp candidateSliceOp, OpResult originalProducer,
141145
bool isDestinationOperand)>;
142-
ControlFnTy fusionControlFn = [](tensor::ExtractSliceOp, OpResult, bool) {
143-
return std::make_tuple(true, false);
146+
/// The default control function implements greedy fusion without yielding
147+
/// a replacement for any of the fused results.
148+
ControlFnTy fusionControlFn = [](tensor::ExtractSliceOp, OpResult,
149+
bool) -> std::optional<ControlFnResult> {
150+
return ControlFnResult{};
144151
};
145152
SCFTileAndFuseOptions &setFusionControlFn(ControlFnTy controlFn) {
146153
fusionControlFn = controlFn;
@@ -156,6 +163,7 @@ struct SCFFuseProducerOfSliceResult {
156163
OpResult origProducer; // Original untiled producer.
157164
Value tiledAndFusedProducer; // Tile and fused producer value.
158165
SmallVector<Operation *> tiledOps;
166+
SmallVector<Operation *> generatedSlices;
159167
};
160168
std::optional<SCFFuseProducerOfSliceResult>
161169
tileAndFuseProducerOfSlice(RewriterBase &rewriter,
@@ -215,7 +223,10 @@ tileAndFuseProducerOfSlice(RewriterBase &rewriter,
215223
///
216224
/// The @param `yieldResultNumber` decides which result would be yield. If not
217225
/// given, yield all `opResult` of fused producer.
218-
LogicalResult yieldReplacementForFusedProducer(
226+
///
227+
/// The method returns the list of new slices added during the process (which
228+
/// can be used to fuse along).
229+
FailureOr<SmallVector<Operation *>> yieldReplacementForFusedProducer(
219230
RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
220231
scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
221232
MutableArrayRef<LoopLikeOpInterface> loops,

mlir/include/mlir/Interfaces/TilingInterface.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,15 @@ namespace mlir {
2525

2626
/// Container for result values of tiling.
2727
/// - `tiledOps` contains operations created by the tiling implementation that
28-
/// are returned to the caller for further transformations.
28+
/// are returned to the caller for further transformations.
2929
/// - `tiledValues` contains the tiled value corresponding to the result of the
30-
/// untiled operation.
30+
/// untiled operation.
31+
/// - `generatedSlices` contains the list of slices that are generated during
32+
/// tiling. These slices can be used for fusing producers.
3133
struct TilingResult {
3234
SmallVector<Operation *> tiledOps;
3335
SmallVector<Value> tiledValues;
36+
SmallVector<Operation *> generatedSlices;
3437
};
3538

3639
/// Container for the result of merge operation of tiling.

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 54 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -67,20 +67,20 @@ static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v,
6767

6868
/// Returns a memref.subview or a tensor.extract_slice based on the type of the
6969
/// `source`.
70-
static Value getSlice(OpBuilder &b, Location loc, Value source,
71-
ArrayRef<OpFoldResult> offsets,
72-
ArrayRef<OpFoldResult> sizes,
73-
ArrayRef<OpFoldResult> strides) {
74-
return TypeSwitch<Type, Value>(source.getType())
75-
.Case<RankedTensorType>([&](RankedTensorType t) -> Value {
70+
static Operation *getSlice(OpBuilder &b, Location loc, Value source,
71+
ArrayRef<OpFoldResult> offsets,
72+
ArrayRef<OpFoldResult> sizes,
73+
ArrayRef<OpFoldResult> strides) {
74+
return TypeSwitch<Type, Operation *>(source.getType())
75+
.Case<RankedTensorType>([&](RankedTensorType t) -> Operation * {
7676
return b.create<tensor::ExtractSliceOp>(loc, source, offsets, sizes,
7777
strides);
7878
})
79-
.Case<MemRefType>([&](MemRefType type) -> Value {
79+
.Case<MemRefType>([&](MemRefType type) -> Operation * {
8080
return b.create<memref::SubViewOp>(loc, source, offsets, sizes,
8181
strides);
8282
})
83-
.Default([&](Type t) { return nullptr; });
83+
.Default([&](Type t) -> Operation * { return nullptr; });
8484
}
8585

8686
//===----------------------------------------------------------------------===//
@@ -2634,18 +2634,29 @@ SoftmaxOp::getTiledImplementation(OpBuilder &builder,
26342634
auto oneAttr = builder.getI64IntegerAttr(1);
26352635
SmallVector<OpFoldResult> strides(rank, oneAttr);
26362636
SmallVector<Value> tiledOperands;
2637-
tiledOperands.emplace_back(
2638-
getSlice(builder, getLoc(), getInput(), offsets, sizes, strides));
2639-
tiledOperands.emplace_back(
2640-
getSlice(builder, getLoc(), getOutput(), offsets, sizes, strides));
2637+
Operation *inputSlice =
2638+
getSlice(builder, getLoc(), getInput(), offsets, sizes, strides);
2639+
if (!inputSlice) {
2640+
return emitOpError("failed to compute input slice");
2641+
}
2642+
tiledOperands.emplace_back(inputSlice->getResult(0));
2643+
Operation *outputSlice =
2644+
getSlice(builder, getLoc(), getOutput(), offsets, sizes, strides);
2645+
if (!outputSlice) {
2646+
return emitOpError("failed to compute output slice");
2647+
}
2648+
tiledOperands.emplace_back(outputSlice->getResult(0));
26412649

26422650
SmallVector<Type, 4> resultTypes;
26432651
if (hasPureTensorSemantics())
26442652
resultTypes.push_back(tiledOperands[1].getType());
26452653
Operation *tiledOp =
26462654
mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
26472655

2648-
return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
2656+
return TilingResult{
2657+
{tiledOp},
2658+
SmallVector<Value>(tiledOp->getResults()),
2659+
llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
26492660
}
26502661

26512662
LogicalResult SoftmaxOp::getResultTilePosition(
@@ -2992,8 +3003,9 @@ FailureOr<TilingResult> WinogradFilterTransformOp::getTiledImplementation(
29923003
int64_t filterRank = getFilterOperandRank();
29933004
SmallVector<OpFoldResult> filterStrides(filterRank, oneAttr);
29943005
Location loc = getLoc();
2995-
tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
2996-
loc, getFilter(), sliceOffsets, sliceSizes, filterStrides));
3006+
auto filterSlice = builder.create<tensor::ExtractSliceOp>(
3007+
loc, getFilter(), sliceOffsets, sliceSizes, filterStrides);
3008+
tiledOperands.emplace_back(filterSlice);
29973009

29983010
SmallVector<OpFoldResult> resultOffsets, resultSizes;
29993011
if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
@@ -3002,15 +3014,19 @@ FailureOr<TilingResult> WinogradFilterTransformOp::getTiledImplementation(
30023014

30033015
int64_t outputRank = getOutputOperandRank();
30043016
SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
3005-
tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
3006-
loc, getOutput(), resultOffsets, resultSizes, outputStrides));
3017+
auto outputSlice = builder.create<tensor::ExtractSliceOp>(
3018+
loc, getOutput(), resultOffsets, resultSizes, outputStrides);
3019+
tiledOperands.emplace_back(outputSlice);
30073020

30083021
SmallVector<Type> resultTypes;
30093022
resultTypes.push_back(tiledOperands[1].getType());
30103023
Operation *tiledOp =
30113024
mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
30123025

3013-
return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
3026+
return TilingResult{
3027+
{tiledOp},
3028+
SmallVector<Value>(tiledOp->getResults()),
3029+
llvm::to_vector(ArrayRef<Operation *>{filterSlice, outputSlice})};
30143030
}
30153031

30163032
//===----------------------------------------------------------------------===//
@@ -3159,8 +3175,9 @@ WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
31593175
{sizes[getOutputNDim()], sizeH, sizeW, sizes[getOutputCDim()]});
31603176
int64_t inputRank = getInputOperandRank();
31613177
SmallVector<OpFoldResult> inputStrides(inputRank, oneAttr);
3162-
tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
3163-
loc, getInput(), sliceOffsets, sliceSizes, inputStrides));
3178+
auto inputSlice = builder.create<tensor::ExtractSliceOp>(
3179+
loc, getInput(), sliceOffsets, sliceSizes, inputStrides);
3180+
tiledOperands.emplace_back(inputSlice);
31643181

31653182
SmallVector<OpFoldResult> resultOffsets, resultSizes;
31663183
if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
@@ -3169,15 +3186,19 @@ WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
31693186

31703187
int64_t outputRank = getOutputOperandRank();
31713188
SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
3172-
tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
3173-
loc, getOutput(), resultOffsets, resultSizes, outputStrides));
3189+
auto outputSlice = builder.create<tensor::ExtractSliceOp>(
3190+
loc, getOutput(), resultOffsets, resultSizes, outputStrides);
3191+
tiledOperands.emplace_back(outputSlice);
31743192

31753193
SmallVector<Type> resultTypes;
31763194
resultTypes.push_back(tiledOperands[1].getType());
31773195
Operation *tiledOp =
31783196
mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
31793197

3180-
return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
3198+
return TilingResult{
3199+
{tiledOp},
3200+
SmallVector<Value>(tiledOp->getResults()),
3201+
llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
31813202
}
31823203

31833204
//===----------------------------------------------------------------------===//
@@ -3321,8 +3342,9 @@ FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
33213342
sizes[getValueFDim()]});
33223343
int64_t valueRank = getValueOperandRank();
33233344
SmallVector<OpFoldResult> sliceStrides(valueRank, oneAttr);
3324-
tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
3325-
loc, getValue(), sliceOffsets, sliceSizes, sliceStrides));
3345+
auto valueSlice = builder.create<tensor::ExtractSliceOp>(
3346+
loc, getValue(), sliceOffsets, sliceSizes, sliceStrides);
3347+
tiledOperands.emplace_back(valueSlice);
33263348

33273349
SmallVector<OpFoldResult> resultOffsets, resultSizes;
33283350
if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
@@ -3331,15 +3353,19 @@ FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
33313353

33323354
int64_t outputRank = getOutputOperandRank();
33333355
SmallVector<OpFoldResult> strides(outputRank, oneAttr);
3334-
tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
3335-
loc, getOutput(), resultOffsets, resultSizes, strides));
3356+
auto outputSlice = builder.create<tensor::ExtractSliceOp>(
3357+
loc, getOutput(), resultOffsets, resultSizes, strides);
3358+
tiledOperands.emplace_back(outputSlice);
33363359

33373360
SmallVector<Type> resultTypes;
33383361
resultTypes.push_back(tiledOperands[1].getType());
33393362
Operation *tiledOp =
33403363
mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
33413364

3342-
return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
3365+
return TilingResult{
3366+
{tiledOp},
3367+
SmallVector<Value>(tiledOp->getResults()),
3368+
llvm::to_vector(ArrayRef<Operation *>{valueSlice, outputSlice})};
33433369
}
33443370

33453371
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -120,16 +120,25 @@ struct LinalgOpTilingInterface
120120
Location loc = op->getLoc();
121121
LinalgOp linalgOp = cast<LinalgOp>(op);
122122
SmallVector<Value> valuesToTile = linalgOp->getOperands();
123-
SmallVector<Value, 4> tiledOperands = makeTiledShapes(
123+
SmallVector<Value> tiledOperands = makeTiledShapes(
124124
b, loc, linalgOp, valuesToTile, offsets, sizes, {}, true);
125+
SmallVector<Operation *> generatedSlices = llvm::map_to_vector(
126+
llvm::make_filter_range(
127+
tiledOperands,
128+
[](Value v) -> bool {
129+
return isa_and_nonnull<tensor::ExtractSliceOp, memref::SubViewOp>(
130+
v.getDefiningOp());
131+
}),
132+
[](Value v) -> Operation * { return v.getDefiningOp(); });
125133

126134
SmallVector<Type> resultTensorTypes =
127135
getTensorOutputTypes(linalgOp, tiledOperands);
128136

129137
Operation *tiledOp = clone(b, linalgOp, resultTensorTypes, tiledOperands);
130138
offsetIndices(b, cast<LinalgOp>(tiledOp), offsets);
131139

132-
return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
140+
return TilingResult{
141+
{tiledOp}, SmallVector<Value>(tiledOp->getResults()), generatedSlices};
133142
}
134143

135144
/// Utility to fetch the offsets and sizes when applied as per the indexing
@@ -260,7 +269,8 @@ struct LinalgOpTilingInterface
260269

261270
return TilingResult{
262271
tilingResult->tiledOps,
263-
SmallVector<Value>{tilingResult->tiledValues[resultNumber]}};
272+
SmallVector<Value>{tilingResult->tiledValues[resultNumber]},
273+
tilingResult->generatedSlices};
264274
}
265275

266276
/// Method to generate the tiled implementation of an operation from the tile
@@ -406,8 +416,12 @@ struct LinalgOpPartialReductionInterface
406416
}
407417

408418
// Step 2a: Extract a slice of the input operands.
409-
SmallVector<Value, 4> tiledInputs = makeTiledShapes(
419+
SmallVector<Value> tiledInputs = makeTiledShapes(
410420
b, loc, linalgOp, linalgOp.getDpsInputs(), offsets, sizes, {}, true);
421+
SmallVector<Operation *> generatedSlices = llvm::map_to_vector(
422+
llvm::make_filter_range(
423+
tiledInputs, [](Value v) -> bool { return v.getDefiningOp(); }),
424+
[](Value v) -> Operation * { return v.getDefiningOp(); });
411425

412426
// Step 2b: Extract a slice of the init operands.
413427
SmallVector<Value, 1> tiledInits;
@@ -424,6 +438,7 @@ struct LinalgOpPartialReductionInterface
424438
auto extractSlice = b.create<tensor::ExtractSliceOp>(
425439
loc, valueToTile, initOffset, initSizes, initStride);
426440
tiledInits.push_back(extractSlice);
441+
generatedSlices.push_back(extractSlice);
427442
}
428443

429444
// Update the indexing maps.
@@ -453,7 +468,8 @@ struct LinalgOpPartialReductionInterface
453468
return TilingResult{
454469
{genericOp.getOperation()},
455470
llvm::map_to_vector(genericOp->getResults(),
456-
[](OpResult r) -> Value { return r; })};
471+
[](OpResult r) -> Value { return r; }),
472+
generatedSlices};
457473
}
458474

459475
FailureOr<MergeResult> mergeReductions(Operation *op, OpBuilder &b,

mlir/lib/Dialect/Linalg/Utils/Utils.cpp

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -565,9 +565,9 @@ void GenerateLoopNest<scf::ParallelOp>::doit(
565565
assert(ivs.size() == iteratorTypes.size() && "did not generate enough loops");
566566
}
567567

568-
static Value materializeTiledShape(OpBuilder &builder, Location loc,
569-
Value valueToTile,
570-
const SliceParameters &sliceParams) {
568+
static Operation *materializeTiledShape(OpBuilder &builder, Location loc,
569+
Value valueToTile,
570+
const SliceParameters &sliceParams) {
571571
auto shapedType = dyn_cast<ShapedType>(valueToTile.getType());
572572
auto *sliceOp = TypeSwitch<ShapedType, Operation *>(shapedType)
573573
.Case([&](MemRefType) {
@@ -583,14 +583,15 @@ static Value materializeTiledShape(OpBuilder &builder, Location loc,
583583
.Default([](ShapedType) -> Operation * {
584584
llvm_unreachable("Unexpected shaped type");
585585
});
586-
return sliceOp->getResult(0);
586+
return sliceOp;
587587
}
588588

589-
Value makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
590-
ArrayRef<OpFoldResult> tileSizes, AffineMap map,
591-
ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
592-
ArrayRef<OpFoldResult> subShapeSizes,
593-
bool omitPartialTileCheck) {
589+
Operation *makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
590+
ArrayRef<OpFoldResult> tileSizes, AffineMap map,
591+
ArrayRef<OpFoldResult> lbs,
592+
ArrayRef<OpFoldResult> ubs,
593+
ArrayRef<OpFoldResult> subShapeSizes,
594+
bool omitPartialTileCheck) {
594595
SliceParameters sliceParams =
595596
computeSliceParameters(builder, loc, valueToTile, tileSizes, map, lbs,
596597
ubs, subShapeSizes, omitPartialTileCheck);
@@ -841,6 +842,7 @@ SmallVector<Value> makeTiledShapes(OpBuilder &builder, Location loc,
841842
tiledShapes.push_back(
842843
sliceParams.has_value()
843844
? materializeTiledShape(builder, loc, valueToTile, *sliceParams)
845+
->getResult(0)
844846
: valueToTile);
845847
}
846848
return tiledShapes;

0 commit comments

Comments
 (0)