Skip to content

Commit 385c9f5

Browse files
authored
[MLIR] Cleanup constantTripCount() (NFC) (#159307)
Add post-merge review comments on #158679
1 parent 01ee9fe commit 385c9f5

File tree

2 files changed

+20
-18
lines changed

2 files changed

+20
-18
lines changed

mlir/lib/Dialect/SCF/Utils/Utils.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -483,7 +483,7 @@ LogicalResult mlir::loopUnrollFull(scf::ForOp forOp) {
483483
std::optional<APInt> mayBeConstantTripCount = forOp.getStaticTripCount();
484484
if (!mayBeConstantTripCount.has_value())
485485
return failure();
486-
APInt &tripCount = *mayBeConstantTripCount;
486+
const APInt &tripCount = *mayBeConstantTripCount;
487487
if (tripCount.isZero())
488488
return success();
489489
if (tripCount.getSExtValue() == 1)

mlir/lib/Dialect/Utils/StaticValueUtils.cpp

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "mlir/Dialect/Utils/StaticValueUtils.h"
10+
#include "mlir/IR/Attributes.h"
1011
#include "mlir/IR/Matchers.h"
1112
#include "mlir/Support/LLVM.h"
1213
#include "llvm/ADT/APSInt.h"
@@ -280,27 +281,28 @@ std::optional<APInt> constantTripCount(
280281
computeUbMinusLb) {
281282
// This is the bitwidth used to return 0 when loop does not execute.
282283
// We infer it from the type of the bound if it isn't an index type.
283-
bool isIndex = true;
284-
auto getBitwidth = [&](OpFoldResult ofr) -> int {
285-
if (auto attr = dyn_cast<Attribute>(ofr)) {
286-
if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
287-
if (auto intType = dyn_cast<IntegerType>(intAttr.getType())) {
288-
isIndex = intType.isIndex();
289-
return intType.getWidth();
290-
}
291-
}
284+
auto getBitwidth = [&](OpFoldResult ofr) -> std::tuple<int, bool> {
285+
if (auto intAttr =
286+
dyn_cast_or_null<IntegerAttr>(dyn_cast<Attribute>(ofr))) {
287+
if (auto intType = dyn_cast<IntegerType>(intAttr.getType()))
288+
return std::make_tuple(intType.getWidth(), intType.isIndex());
292289
} else {
293290
auto val = cast<Value>(ofr);
294-
if (auto intType = dyn_cast<IntegerType>(val.getType())) {
295-
isIndex = intType.isIndex();
296-
return intType.getWidth();
297-
}
291+
if (auto intType = dyn_cast<IntegerType>(val.getType()))
292+
return std::make_tuple(intType.getWidth(), intType.isIndex());
298293
}
299-
return IndexType::kInternalStorageBitWidth;
294+
return std::make_tuple(IndexType::kInternalStorageBitWidth, true);
300295
};
301-
int bitwidth = getBitwidth(lb);
302-
assert(bitwidth == getBitwidth(ub) &&
303-
"lb and ub must have the same bitwidth");
296+
auto [bitwidth, isIndex] = getBitwidth(lb);
297+
// This would better be an assert, but unfortunately it breaks scf.for_all
298+
// which is missing attributes and SSA value optionally for its bounds, and
299+
// uses Index type for the dynamic bounds but i64 for the static bounds. This
300+
// is broken...
301+
if (std::tie(bitwidth, isIndex) != getBitwidth(ub)) {
302+
LDBG() << "mismatch between lb and ub bitwidth/type: " << ub << " vs "
303+
<< lb;
304+
return std::nullopt;
305+
}
304306
if (lb == ub)
305307
return APInt(bitwidth, 0);
306308

0 commit comments

Comments
 (0)