From 7a5ab7c53ce3d7e9ec1408360e3037758d33d897 Mon Sep 17 00:00:00 2001
From: Aviad Cohen
Date: Sun, 27 Oct 2024 14:22:47 +0200
Subject: [PATCH] [mlir][scf]: Avoid inserting affine.min when tiling dynamic operation sizes if possible

* When tiling an operation using scf, we may avoid inserting the affine.min
  that handles the last tile when `upper_bound = step * k`, where the step
  may be a constant or a dynamic value.
---
 .../SCF/Transforms/TileUsingInterface.cpp     | 54 ++++++++++++++++---
 .../Dialect/Linalg/transform-op-tile.mlir     | 47 ++++++++++++++++
 2 files changed, 93 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index e2feb10b31454..ecb7c265305bd 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -21,11 +21,13 @@
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/IR/Dominance.h"
 #include "mlir/IR/Matchers.h"
+#include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
 #include "mlir/Interfaces/TilingInterface.h"
 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Debug.h"
 #include <optional>
@@ -186,18 +188,54 @@ static void checkSafeToTileToForall(TilingInterface op,
   }
 }
 
+/// Collects `ofr` and, recursively through `arith.muli` definitions, the
+/// operands it is a product of into `dividers`. Each collected entry divides
+/// `ofr`, assuming the multiplications do not wrap around.
+static void collectDividers(OpFoldResult ofr,
+                            SmallVector<OpFoldResult> &dividers) {
+  dividers.push_back(ofr);
+  if (ofr.is<Attribute>())
+    return;
+  auto mulOp = cast<Value>(ofr).getDefiningOp<arith::MulIOp>();
+  if (!mulOp)
+    return;
+
+  // Given `ofr` = `x` * `y`, all dividers of `x` and `y` are dividers of `ofr`.
+  collectDividers(mulOp.getLhs(), dividers);
+  collectDividers(mulOp.getRhs(), dividers);
+}
+
 /// Check if `stride` evenly divides the trip count `size - offset`.
+/// Fully static ranges are checked exactly; otherwise `stride` has to divide
+/// one of the dividers collected for `size` and one collected for `offset`.
 static bool tileDividesIterationDomain(Range loopRange) {
+  std::optional<int64_t> strideAsInt = getConstantIntValue(loopRange.stride);
+  // A constant zero stride divides nothing; bail out before it reaches `%`.
+  if (strideAsInt && *strideAsInt == 0)
+    return false;
   std::optional<int64_t> offsetAsInt = getConstantIntValue(loopRange.offset);
-  if (!offsetAsInt)
-    return false;
   std::optional<int64_t> sizeAsInt = getConstantIntValue(loopRange.size);
-  if (!sizeAsInt)
-    return false;
-  std::optional<int64_t> strideAsInt = getConstantIntValue(loopRange.stride);
-  if (!strideAsInt)
-    return false;
-  return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() == 0);
+  // Fully static range: check (size - offset) % stride == 0 directly.
+  if (strideAsInt && offsetAsInt && sizeAsInt)
+    return (*sizeAsInt - *offsetAsInt) % *strideAsInt == 0;
+
+  // At least one of `stride`/`size`/`offset` is dynamic.
+  SmallVector<OpFoldResult> dividersOfSize, dividersOfOffset;
+  collectDividers(loopRange.size, dividersOfSize);
+  collectDividers(loopRange.offset, dividersOfOffset);
+
+  // `stride` has to divide at least one collected divider of `size` and at
+  // least one collected divider of `offset`.
+  auto isStrideDividesDivider = [&](OpFoldResult divider) {
+    // A dynamic `stride` can only be matched by identity.
+    if (!strideAsInt)
+      return divider == loopRange.stride;
+
+    std::optional<int64_t> dividerAsInt = getConstantIntValue(divider);
+    return dividerAsInt && *dividerAsInt % *strideAsInt == 0;
+  };
+  return llvm::any_of(dividersOfSize, isStrideDividesDivider) &&
+         llvm::any_of(dividersOfOffset, isStrideDividesDivider);
 }
 
 /// Returns the bounded tile size given the current `offset`, `loopRange` and
diff --git a/mlir/test/Dialect/Linalg/transform-op-tile.mlir b/mlir/test/Dialect/Linalg/transform-op-tile.mlir
index 7bac850d0b7fe..ade523ef378f3 100644
--- a/mlir/test/Dialect/Linalg/transform-op-tile.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-tile.mlir
@@ -266,3 +266,50 @@ func.func @tile_linalg_matmul(
       -> tensor<128x128xf32>
   return %0 : tensor<128x128xf32>
 }
+
+// -----
+
+#map = affine_map<(d0) -> (d0)>
+
+// CHECK-LABEL: splited_dynamic_linalg_generic
+func.func @splited_dynamic_linalg_generic(%arg0: tensor<?xi16>, %arg1: tensor<?xi16>) -> tensor<?xi16> {
+  %c80 = arith.constant 80 : index
+  %c0 = arith.constant 0 : index
+  %dim = tensor.dim %arg1, %c0 : tensor<?xi16>
+  %0 = tensor.empty(%dim) : tensor<?xi16>
+  %1 = arith.divui %dim, %c80 : index
+  %2 = arith.muli %1, %c80 : index
+  %3 = arith.remui %dim, %c80 : index
+  %extracted_slice = tensor.extract_slice %arg0[0] [%2] [1] : tensor<?xi16> to tensor<?xi16>
+  %extracted_slice_0 = tensor.extract_slice %arg1[0] [%2] [1] : tensor<?xi16> to tensor<?xi16>
+  %extracted_slice_1 = tensor.extract_slice %0[0] [%2] [1] : tensor<?xi16> to tensor<?xi16>
+  // CHECK: scf.for
+  // CHECK-NOT: affine.min
+  %4 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%extracted_slice, %extracted_slice_0 : tensor<?xi16>, tensor<?xi16>) outs(%extracted_slice_1 : tensor<?xi16>) {
+  ^bb0(%in_1: i16, %in_2: i16, %out: i16):
+    %6 = arith.addi %in_1, %in_2 : i16
+    linalg.yield %6 : i16
+  } -> tensor<?xi16>
+  %inserted_slice = tensor.insert_slice %4 into %0[%2] [%2] [1] : tensor<?xi16> into tensor<?xi16>
+  %extracted_slice_2 = tensor.extract_slice %arg0[%2] [%3] [1] : tensor<?xi16> to tensor<?xi16>
+  %extracted_slice_3 = tensor.extract_slice %arg1[%2] [%3] [1] : tensor<?xi16> to tensor<?xi16>
+  %extracted_slice_4 = tensor.extract_slice %0[%2] [%3] [1] : tensor<?xi16> to tensor<?xi16>
+  // CHECK-NOT: scf.for
+  %5 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%extracted_slice_2, %extracted_slice_3 : tensor<?xi16>, tensor<?xi16>) outs(%extracted_slice_4 : tensor<?xi16>) {
+  ^bb0(%in_1: i16, %in_2: i16, %out: i16):
+    %7 = arith.addi %in_1, %in_2 : i16
+    linalg.yield %7 : i16
+  } -> tensor<?xi16>
+  %inserted_slice_0 = tensor.insert_slice %5 into %inserted_slice[%2] [%3] [1] : tensor<?xi16> into tensor<?xi16>
+  return %inserted_slice_0 : tensor<?xi16>
+}
+
+
+module attributes {transform.with_named_sequence} {
+transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+  %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  %const = transform.structured.match ops{["arith.constant"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  %1, %loop = transform.structured.tile_using_for %0 tile_sizes [%const] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+  transform.yield
+}
+}