Skip to content

Commit d88f749

Browse files
author
Aviad Cohen
committed
[mlir][scf]: Avoid inserting affine.min when tiling dynamic operation sizes if possible
* When tiling an operation using scf, we can avoid inserting an affine.min to clamp the last tile if `upper_bound = step * k`, where `k` may be a constant or a dynamic value.
1 parent d2e9532 commit d88f749

File tree

2 files changed

+71
-0
lines changed

2 files changed

+71
-0
lines changed

mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,11 +186,35 @@ static void checkSafeToTileToForall(TilingInterface op,
186186
}
187187
}
188188

189+
/// Returns true if `size` is dynamic multiplication of `stride`.
190+
/// i.e. , `size = stride * k` where stride may be a constant or a dynamic.
191+
static bool dynamiclyDivisible(OpFoldResult size, OpFoldResult stride) {
192+
Value dynamicSize = dyn_cast_if_present<Value>(size);
193+
if (!dynamicSize)
194+
return false;
195+
auto mulOp = dynamicSize.getDefiningOp<arith::MulIOp>();
196+
if (!mulOp)
197+
return false;
198+
if (Value dynamicStride = dyn_cast_if_present<Value>(stride))
199+
return mulOp.getLhs() == dynamicStride || mulOp.getRhs() == dynamicStride;
200+
std::optional<int64_t> strideAsInt = getConstantIntValue(stride);
201+
std::optional<int64_t> lhsAsInt = getConstantIntValue(mulOp.getLhs());
202+
std::optional<int64_t> rhsAsInt = getConstantIntValue(mulOp.getRhs());
203+
if (strideAsInt && lhsAsInt && *strideAsInt == *lhsAsInt)
204+
return true;
205+
if (strideAsInt && rhsAsInt && *strideAsInt == *rhsAsInt)
206+
return true;
207+
208+
return false;
209+
}
210+
189211
/// Check if `stride` evenly divides the trip count `size - offset`.
190212
static bool tileDividesIterationDomain(Range loopRange) {
191213
std::optional<int64_t> offsetAsInt = getConstantIntValue(loopRange.offset);
192214
if (!offsetAsInt)
193215
return false;
216+
if (*offsetAsInt == 0 && dynamiclyDivisible(loopRange.size, loopRange.stride))
217+
return true;
194218
std::optional<int64_t> sizeAsInt = getConstantIntValue(loopRange.size);
195219
if (!sizeAsInt)
196220
return false;

mlir/test/Dialect/Linalg/transform-op-tile.mlir

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,3 +266,50 @@ func.func @tile_linalg_matmul(
266266
-> tensor<128x128xf32>
267267
return %0 : tensor<128x128xf32>
268268
}

// -----

#map = affine_map<(d0) -> (d0)>

// The iteration domain of the first generic is `%2 = (%dim floordiv 80) * 80`,
// i.e. a dynamic multiple of the (constant) tile size 80, so tiling must not
// insert an affine.min to clamp the last tile.
// CHECK-LABEL: split_dynamic_linalg_generic
func.func @split_dynamic_linalg_generic(%arg0: tensor<?xi16>, %arg1: tensor<?xi16>) -> tensor<?xi16> {
  %c80 = arith.constant 80 : index
  %c0 = arith.constant 0 : index
  %dim = tensor.dim %arg1, %c0 : tensor<?xi16>
  %0 = tensor.empty(%dim) : tensor<?xi16>
  // Split the domain into the evenly-divisible prefix (%2) and remainder (%3).
  %1 = arith.divui %dim, %c80 : index
  %2 = arith.muli %1, %c80 : index
  %3 = arith.remui %dim, %c80 : index
  %extracted_slice = tensor.extract_slice %arg0[0] [%2] [1] : tensor<?xi16> to tensor<?xi16>
  %extracted_slice_0 = tensor.extract_slice %arg1[0] [%2] [1] : tensor<?xi16> to tensor<?xi16>
  %extracted_slice_1 = tensor.extract_slice %0[0] [%2] [1] : tensor<?xi16> to tensor<?xi16>
  // CHECK: scf.for
  // CHECK-NOT: affine.min
  %4 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%extracted_slice, %extracted_slice_0 : tensor<?xi16>, tensor<?xi16>) outs(%extracted_slice_1 : tensor<?xi16>) {
  ^bb0(%in_1: i16, %in_2: i16, %out: i16):
    %6 = arith.addi %in_1, %in_2 : i16
    linalg.yield %6 : i16
  } -> tensor<?xi16>
  // NOTE(review): the prefix result is inserted at offset %2 rather than 0 —
  // looks like the offset should be 0; harmless for the CHECK lines, but
  // confirm the intended IR shape.
  %inserted_slice = tensor.insert_slice %4 into %0[%2] [%2] [1] : tensor<?xi16> into tensor<?xi16>
  %extracted_slice_2 = tensor.extract_slice %arg0[%2] [%3] [1] : tensor<?xi16> to tensor<?xi16>
  %extracted_slice_3 = tensor.extract_slice %arg1[%2] [%3] [1] : tensor<?xi16> to tensor<?xi16>
  %extracted_slice_4 = tensor.extract_slice %0[%2] [%3] [1] : tensor<?xi16> to tensor<?xi16>
  // The remainder generic has size %3 (not a multiple of the tile size), so no
  // loop may be generated for it here.
  // CHECK-NOT: scf.for
  %5 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%extracted_slice_2, %extracted_slice_3 : tensor<?xi16>, tensor<?xi16>) outs(%extracted_slice_4 : tensor<?xi16>) {
  ^bb0(%in_1: i16, %in_2: i16, %out: i16):
    %7 = arith.addi %in_1, %in_2 : i16
    linalg.yield %7 : i16
  } -> tensor<?xi16>
  %inserted_slice_0 = tensor.insert_slice %5 into %inserted_slice[%2] [%3] [1] : tensor<?xi16> into tensor<?xi16>
  return %inserted_slice_0 : tensor<?xi16>
}

// Drive the tiling: tile every linalg.generic with a dynamic tile size taken
// from the matched arith.constant payload ops (exercising the dynamic-stride
// path of the divisibility check).
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    // NOTE(review): this matches every arith.constant in the payload, not just
    // %c80 — presumably the constants line up per tiled op; confirm.
    %const = transform.structured.match ops{["arith.constant"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1, %loop = transform.structured.tile_using_for %0 tile_sizes [%const] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
    transform.yield
  }
}

0 commit comments

Comments
 (0)