|
7 | 7 | //===----------------------------------------------------------------------===//
|
8 | 8 |
|
9 | 9 | #include "mlir/Dialect/Utils/StaticValueUtils.h"
|
| 10 | +#include "mlir/IR/Attributes.h" |
10 | 11 | #include "mlir/IR/Matchers.h"
|
11 | 12 | #include "mlir/Support/LLVM.h"
|
12 | 13 | #include "llvm/ADT/APSInt.h"
|
@@ -280,27 +281,28 @@ std::optional<APInt> constantTripCount(
|
280 | 281 | computeUbMinusLb) {
|
281 | 282 | // This is the bitwidth used to return 0 when loop does not execute.
|
282 | 283 | // We infer it from the type of the bound if it isn't an index type.
|
283 |
| - bool isIndex = true; |
284 |
| - auto getBitwidth = [&](OpFoldResult ofr) -> int { |
285 |
| - if (auto attr = dyn_cast<Attribute>(ofr)) { |
286 |
| - if (auto intAttr = dyn_cast<IntegerAttr>(attr)) { |
287 |
| - if (auto intType = dyn_cast<IntegerType>(intAttr.getType())) { |
288 |
| - isIndex = intType.isIndex(); |
289 |
| - return intType.getWidth(); |
290 |
| - } |
291 |
| - } |
| 284 | + auto getBitwidth = [&](OpFoldResult ofr) -> std::tuple<int, bool> { |
| 285 | + if (auto intAttr = |
| 286 | + dyn_cast_or_null<IntegerAttr>(dyn_cast<Attribute>(ofr))) { |
| 287 | + if (auto intType = dyn_cast<IntegerType>(intAttr.getType())) |
| 288 | + return std::make_tuple(intType.getWidth(), intType.isIndex()); |
292 | 289 | } else {
|
293 | 290 | auto val = cast<Value>(ofr);
|
294 |
| - if (auto intType = dyn_cast<IntegerType>(val.getType())) { |
295 |
| - isIndex = intType.isIndex(); |
296 |
| - return intType.getWidth(); |
297 |
| - } |
| 291 | + if (auto intType = dyn_cast<IntegerType>(val.getType())) |
| 292 | + return std::make_tuple(intType.getWidth(), intType.isIndex()); |
298 | 293 | }
|
299 |
| - return IndexType::kInternalStorageBitWidth; |
| 294 | + return std::make_tuple(IndexType::kInternalStorageBitWidth, true); |
300 | 295 | };
|
301 |
| - int bitwidth = getBitwidth(lb); |
302 |
| - assert(bitwidth == getBitwidth(ub) && |
303 |
| - "lb and ub must have the same bitwidth"); |
| 296 | + auto [bitwidth, isIndex] = getBitwidth(lb); |
| 297 | + // This would better be an assert, but unfortunately it breaks scf.for_all |
| 298 | + // which is missing attributes and SSA value optionally for its bounds, and |
| 299 | + // uses Index type for the dynamic bounds but i64 for the static bounds. This |
| 300 | + // is broken... |
| 301 | + if (std::tie(bitwidth, isIndex) != getBitwidth(ub)) { |
| 302 | + LDBG() << "mismatch between lb and ub bitwidth/type: " << ub << " vs " |
| 303 | + << lb; |
| 304 | + return std::nullopt; |
| 305 | + } |
304 | 306 | if (lb == ub)
|
305 | 307 | return APInt(bitwidth, 0);
|
306 | 308 |
|
|
0 commit comments