|
21 | 21 | #include "mlir/Dialect/Utils/IndexingUtils.h" |
22 | 22 | #include "mlir/IR/Dominance.h" |
23 | 23 | #include "mlir/IR/Matchers.h" |
| 24 | +#include "mlir/IR/OpDefinition.h" |
24 | 25 | #include "mlir/IR/PatternMatch.h" |
25 | 26 | #include "mlir/Interfaces/DestinationStyleOpInterface.h" |
26 | 27 | #include "mlir/Interfaces/TilingInterface.h" |
27 | 28 | #include "mlir/Rewrite/FrozenRewritePatternSet.h" |
28 | 29 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| 30 | +#include "llvm/ADT/STLExtras.h" |
29 | 31 | #include "llvm/ADT/TypeSwitch.h" |
30 | 32 | #include "llvm/Support/Debug.h" |
31 | 33 | #include <optional> |
@@ -186,18 +188,49 @@ static void checkSafeToTileToForall(TilingInterface op, |
186 | 188 | } |
187 | 189 | } |
188 | 190 |
|
| 191 | +/// Collect divider of the `ofr`. |
| 192 | +static void collectDividers(OpFoldResult ofr, |
| 193 | + SmallVector<OpFoldResult> ÷rs) { |
| 194 | + dividers.push_back(ofr); |
| 195 | + if (ofr.is<Attribute>()) |
| 196 | + return; |
| 197 | + auto mulOp = cast<Value>(ofr).getDefiningOp<arith::MulIOp>(); |
| 198 | + if (!mulOp) |
| 199 | + return; |
| 200 | + |
| 201 | + // Given `ofr` = `x` * `y`, all dividers of `x` and `y` are dividers of `ofr`. |
| 202 | + collectDividers(mulOp.getLhs(), dividers); |
| 203 | + collectDividers(mulOp.getRhs(), dividers); |
| 204 | +} |
| 205 | + |
189 | 206 | /// Check if `stride` evenly divides the trip count `size - offset`. |
190 | 207 | static bool tileDividesIterationDomain(Range loopRange) { |
| 208 | + std::optional<int64_t> strideAsInt = getConstantIntValue(loopRange.stride); |
191 | 209 | std::optional<int64_t> offsetAsInt = getConstantIntValue(loopRange.offset); |
192 | | - if (!offsetAsInt) |
193 | | - return false; |
194 | 210 | std::optional<int64_t> sizeAsInt = getConstantIntValue(loopRange.size); |
195 | | - if (!sizeAsInt) |
196 | | - return false; |
197 | | - std::optional<int64_t> strideAsInt = getConstantIntValue(loopRange.stride); |
198 | | - if (!strideAsInt) |
199 | | - return false; |
200 | | - return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() == 0); |
| 211 | + if (strideAsInt && offsetAsInt && sizeAsInt) |
| 212 | + // `stride`/`size`/`offset` are static, checking (size - offset) % stride = |
| 213 | + // 0. |
| 214 | + return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() == |
| 215 | + 0); |
| 216 | + |
| 217 | + // At least `stride`/`size`/`offset` is dynamic. |
| 218 | + SmallVector<OpFoldResult> dividersOfSize, dividersOfOffset; |
| 219 | + collectDividers(loopRange.size, dividersOfSize); |
| 220 | + collectDividers(loopRange.offset, dividersOfOffset); |
| 221 | + |
| 222 | + // Return true if `stride` divides one of the dividers of both `size` and |
| 223 | + // `offset`. |
| 224 | + auto isStrideDividesDivider = [&](OpFoldResult divider) { |
| 225 | + if (!strideAsInt) |
| 226 | + // `stride` is dynamic. |
| 227 | + return divider == loopRange.stride; |
| 228 | + |
| 229 | + std::optional<int64_t> dividerAsInt = getConstantIntValue(divider); |
| 230 | + return dividerAsInt && *dividerAsInt % *strideAsInt == 0; |
| 231 | + }; |
| 232 | + return llvm::any_of(dividersOfSize, isStrideDividesDivider) && |
| 233 | + llvm::any_of(dividersOfOffset, isStrideDividesDivider); |
201 | 234 | } |
202 | 235 |
|
203 | 236 | /// Returns the bounded tile size given the current `offset`, `loopRange` and |
|
0 commit comments