Skip to content

Commit 7a5ab7c

Browse files
author
Aviad Cohen
committed
[mlir][scf]: Avoid inserting affine.min when tiling dynamic operation sizes if possible
* During operation tiling using scf, we may avoid inserting affine.min to handle the last tile where `upper_bound = step * k` where stride may be a constant or a dynamic.
1 parent d2e9532 commit 7a5ab7c

File tree

2 files changed

+88
-8
lines changed

2 files changed

+88
-8
lines changed

mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp

Lines changed: 41 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,13 @@
2121
#include "mlir/Dialect/Utils/IndexingUtils.h"
2222
#include "mlir/IR/Dominance.h"
2323
#include "mlir/IR/Matchers.h"
24+
#include "mlir/IR/OpDefinition.h"
2425
#include "mlir/IR/PatternMatch.h"
2526
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
2627
#include "mlir/Interfaces/TilingInterface.h"
2728
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
2829
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
30+
#include "llvm/ADT/STLExtras.h"
2931
#include "llvm/ADT/TypeSwitch.h"
3032
#include "llvm/Support/Debug.h"
3133
#include <optional>
@@ -186,18 +188,49 @@ static void checkSafeToTileToForall(TilingInterface op,
186188
}
187189
}
188190

191+
/// Collect divider of the `ofr`.
192+
static void collectDividers(OpFoldResult ofr,
193+
SmallVector<OpFoldResult> &dividers) {
194+
dividers.push_back(ofr);
195+
if (ofr.is<Attribute>())
196+
return;
197+
auto mulOp = cast<Value>(ofr).getDefiningOp<arith::MulIOp>();
198+
if (!mulOp)
199+
return;
200+
201+
// Given `ofr` = `x` * `y`, all dividers of `x` and `y` are dividers of `ofr`.
202+
collectDividers(mulOp.getLhs(), dividers);
203+
collectDividers(mulOp.getRhs(), dividers);
204+
}
205+
189206
/// Check if `stride` evenly divides the trip count `size - offset`.
190207
static bool tileDividesIterationDomain(Range loopRange) {
208+
std::optional<int64_t> strideAsInt = getConstantIntValue(loopRange.stride);
191209
std::optional<int64_t> offsetAsInt = getConstantIntValue(loopRange.offset);
192-
if (!offsetAsInt)
193-
return false;
194210
std::optional<int64_t> sizeAsInt = getConstantIntValue(loopRange.size);
195-
if (!sizeAsInt)
196-
return false;
197-
std::optional<int64_t> strideAsInt = getConstantIntValue(loopRange.stride);
198-
if (!strideAsInt)
199-
return false;
200-
return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() == 0);
211+
if (strideAsInt && offsetAsInt && sizeAsInt)
212+
// `stride`/`size`/`offset` are static, checking (size - offset) % stride =
213+
// 0.
214+
return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() ==
215+
0);
216+
217+
// At least `stride`/`size`/`offset` is dynamic.
218+
SmallVector<OpFoldResult> dividersOfSize, dividersOfOffset;
219+
collectDividers(loopRange.size, dividersOfSize);
220+
collectDividers(loopRange.offset, dividersOfOffset);
221+
222+
// Return true if `stride` divides one of the dividers of both `size` and
223+
// `offset`.
224+
auto isStrideDividesDivider = [&](OpFoldResult divider) {
225+
if (!strideAsInt)
226+
// `stride` is dynamic.
227+
return divider == loopRange.stride;
228+
229+
std::optional<int64_t> dividerAsInt = getConstantIntValue(divider);
230+
return dividerAsInt && *dividerAsInt % *strideAsInt == 0;
231+
};
232+
return llvm::any_of(dividersOfSize, isStrideDividesDivider) &&
233+
llvm::any_of(dividersOfOffset, isStrideDividesDivider);
201234
}
202235

203236
/// Returns the bounded tile size given the current `offset`, `loopRange` and

mlir/test/Dialect/Linalg/transform-op-tile.mlir

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,3 +266,50 @@ func.func @tile_linalg_matmul(
266266
-> tensor<128x128xf32>
267267
return %0 : tensor<128x128xf32>
268268
}
269+
270+
// -----
271+
272+
#map = affine_map<(d0) -> (d0)>
273+
274+
// CHECK-LABEL: splited_dynamic_linalg_generic
275+
func.func @splited_dynamic_linalg_generic(%arg0: tensor<?xi16>, %arg1: tensor<?xi16>) -> tensor<?xi16> {
276+
%c80 = arith.constant 80 : index
277+
%c0 = arith.constant 0 : index
278+
%dim = tensor.dim %arg1, %c0 : tensor<?xi16>
279+
%0 = tensor.empty(%dim) : tensor<?xi16>
280+
%1 = arith.divui %dim, %c80 : index
281+
%2 = arith.muli %1, %c80 : index
282+
%3 = arith.remui %dim, %c80 : index
283+
%extracted_slice = tensor.extract_slice %arg0[0] [%2] [1] : tensor<?xi16> to tensor<?xi16>
284+
%extracted_slice_0 = tensor.extract_slice %arg1[0] [%2] [1] : tensor<?xi16> to tensor<?xi16>
285+
%extracted_slice_1 = tensor.extract_slice %0[0] [%2] [1] : tensor<?xi16> to tensor<?xi16>
286+
// CHECK: scf.for
287+
// CHECK-NOT: affine.min
288+
%4 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%extracted_slice, %extracted_slice_0 : tensor<?xi16>, tensor<?xi16>) outs(%extracted_slice_1 : tensor<?xi16>) {
289+
^bb0(%in_1: i16, %in_2: i16, %out: i16):
290+
%6 = arith.addi %in_1, %in_2 : i16
291+
linalg.yield %6 : i16
292+
} -> tensor<?xi16>
293+
%inserted_slice = tensor.insert_slice %4 into %0[%2] [%2] [1] : tensor<?xi16> into tensor<?xi16>
294+
%extracted_slice_2 = tensor.extract_slice %arg0[%2] [%3] [1] : tensor<?xi16> to tensor<?xi16>
295+
%extracted_slice_3 = tensor.extract_slice %arg1[%2] [%3] [1] : tensor<?xi16> to tensor<?xi16>
296+
%extracted_slice_4 = tensor.extract_slice %0[%2] [%3] [1] : tensor<?xi16> to tensor<?xi16>
297+
// CHECK-NOT: scf.for
298+
%5 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%extracted_slice_2, %extracted_slice_3 : tensor<?xi16>, tensor<?xi16>) outs(%extracted_slice_4 : tensor<?xi16>) {
299+
^bb0(%in_1: i16, %in_2: i16, %out: i16):
300+
%7 = arith.addi %in_1, %in_2 : i16
301+
linalg.yield %7 : i16
302+
} -> tensor<?xi16>
303+
%inserted_slice_0 = tensor.insert_slice %5 into %inserted_slice[%2] [%3] [1] : tensor<?xi16> into tensor<?xi16>
304+
return %inserted_slice_0 : tensor<?xi16>
305+
}
306+
307+
308+
module attributes {transform.with_named_sequence} {
309+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
310+
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
311+
%const = transform.structured.match ops{["arith.constant"]} in %arg1 : (!transform.any_op) -> !transform.any_op
312+
%1, %loop = transform.structured.tile_using_for %0 tile_sizes [%const] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
313+
transform.yield
314+
}
315+
}

0 commit comments

Comments
 (0)