Skip to content

Commit 0c37ff2

Browse files
committed
Check for monotonic functions
1 parent 9406825 commit 0c37ff2

File tree

2 files changed

+57
-10
lines changed

2 files changed

+57
-10
lines changed

mlir/lib/Dialect/Linalg/Utils/Utils.cpp

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -56,19 +56,24 @@ namespace {
5656
// `d0 + 2 * d1 + d3` is tiled by [0, 0, 0, 2] but not by [0, 0, 2, 0]
5757
//
5858
struct TileCheck : public AffineExprVisitor<TileCheck> {
59-
TileCheck(ArrayRef<OpFoldResult> tileSizes, ArrayRef<OpFoldResult> sizeBounds)
60-
: tileSizes(tileSizes), sizeBounds(sizeBounds) {}
59+
TileCheck(ArrayRef<OpFoldResult> tileSizes, ArrayRef<OpFoldResult> sizeBounds,
60+
bool isMonotonicallyIncreasing)
61+
: tileSizes(tileSizes), sizeBounds(sizeBounds),
62+
isMonotonicallyIncreasing(isMonotonicallyIncreasing) {}
6163

6264
void visitDimExpr(AffineDimExpr expr) {
6365
unsigned pos = expr.getPosition();
6466

65-
// This dimension is tiled if the tile size is larger than zero and not
66-
// equal to its domain size (if statically known).
67-
std::optional<int64_t> tileSize = getConstantIntValue(tileSizes[pos]);
68-
if (tileSize && !sizeBounds.empty()) {
69-
std::optional<int64_t> sizeBound = getConstantIntValue(sizeBounds[pos]);
70-
if (sizeBound && *sizeBound == *tileSize) {
71-
return;
67+
// If the expression is non monotonic, this dimension is tiled if the tile
68+
// size is larger than zero and not equal to its domain size (if statically
69+
// known).
70+
if (!isMonotonicallyIncreasing) {
71+
std::optional<int64_t> tileSize = getConstantIntValue(tileSizes[pos]);
72+
if (tileSize && !sizeBounds.empty()) {
73+
std::optional<int64_t> sizeBound = getConstantIntValue(sizeBounds[pos]);
74+
if (sizeBound && *sizeBound == *tileSize) {
75+
return;
76+
}
7277
}
7378
}
7479

@@ -84,6 +89,7 @@ struct TileCheck : public AffineExprVisitor<TileCheck> {
8489
bool isTiled = false;
8590
ArrayRef<OpFoldResult> tileSizes;
8691
ArrayRef<OpFoldResult> sizeBounds;
92+
bool isMonotonicallyIncreasing;
8793
};
8894

8995
} // namespace
@@ -92,7 +98,7 @@ static bool isTiled(AffineExpr expr, ArrayRef<OpFoldResult> tileSizes,
9298
ArrayRef<OpFoldResult> sizeBounds) {
9399
if (!expr)
94100
return false;
95-
TileCheck t(tileSizes, sizeBounds);
101+
TileCheck t(tileSizes, sizeBounds, expr.isMonotonicallyIncreasing());
96102
t.visit(expr);
97103
return t.isTiled;
98104
}

mlir/test/Dialect/Linalg/tile-tensors.mlir

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,3 +199,44 @@ module attributes {transform.with_named_sequence} {
199199
transform.yield
200200
}
201201
}
202+
203+
// -----
204+
205+
#identity = affine_map<(d0, d1) -> (d0, d1)>
206+
#identity1 = affine_map<(d0, d1) -> (d0 mod 3, d1)>
207+
208+
// CHECK-LABEL: func @tile_monotonic_outer_dim
209+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<4x10xf32>
210+
func.func @tile_monotonic_outer_dim(%in: tensor<4x10xf32>) -> tensor<4x10xf32> {
211+
%empty = tensor.empty() : tensor<4x10xf32>
212+
%1 = linalg.generic {indexing_maps = [#identity, #identity1], iterator_types = ["parallel", "parallel"]}
213+
ins(%in : tensor<4x10xf32>) outs(%empty : tensor<4x10xf32>) {
214+
^bb1(%a: f32, %b: f32):
215+
linalg.yield %a : f32
216+
} -> tensor<4x10xf32>
217+
218+
// CHECK: %[[C4:.+]] = arith.constant 4 : index
219+
// CHECK: %[[C4_1:.+]] = arith.constant 4 : index
220+
// CHECK: %[[C5:.+]] = arith.constant 5 : index
221+
// CHECK: scf.for %[[IV0:.+]] = %{{.+}} to %[[C4]] step %[[C4_1]] iter_args(%[[ARG1:.+]] = %[[OUT:.+]]) -> (tensor<4x10xf32>) {
222+
// CHECK: scf.for %[[IV1:.+]] = %{{.+}} to %{{.+}} step %[[C5]] iter_args(%[[ARG2:.+]] = %[[ARG1]]) -> (tensor<4x10xf32>) {
223+
// CHECK: %[[INSLICE:.+]] = tensor.extract_slice %[[ARG0]][0, %[[IV1]]] [4, 5] [1, 1] : tensor<4x10xf32> to tensor<4x5xf32>
224+
// CHECK: %[[OUTSLICE:.+]] = tensor.extract_slice %[[ARG2]][0, %[[IV1]]] [4, 5] [1, 1] : tensor<4x10xf32> to tensor<4x5xf32>
225+
// CHECK: %[[RES:.+]] = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel"]} ins(%[[INSLICE]] : tensor<4x5xf32>) outs(%[[OUTSLICE]] : tensor<4x5xf32>) {
226+
// CHECK: ^bb0(%in: f32, %out: f32):
227+
// CHECK: linalg.yield %in : f32
228+
// CHECK: } -> tensor<4x5xf32>
229+
// CHECK: %[[INSERT_SLICE:.+]] = tensor.insert_slice %[[RES]] into %[[ARG2]][0, %[[IV1]]] [4, 5] [1, 1] : tensor<4x5xf32> into tensor<4x10xf32>
230+
// CHECK: scf.yield %[[INSERT_SLICE]] : tensor<4x10xf32>
231+
// CHECK: }
232+
233+
return %1 : tensor<4x10xf32>
234+
}
235+
236+
module attributes {transform.with_named_sequence} {
237+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
238+
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
239+
%1, %loops:2 = transform.structured.tile_using_for %0 tile_sizes [4, 5] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
240+
transform.yield
241+
}
242+
}

0 commit comments

Comments
 (0)