Skip to content

Commit d1bc17c

Browse files
committed
LogicalResult
1 parent 262e991 commit d1bc17c

File tree

1 file changed

+16
-15
lines changed

1 file changed

+16
-15
lines changed

mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -196,23 +196,24 @@ struct DeleteTrivialRem : public OpRewritePattern<RemOp> {
196196
};
197197

198198
/// Check if `type` is index or integer type with `getWidth() > targetBitwidth`.
199-
static bool checkIntType(Type type, unsigned targetBitwidth) {
199+
static LogicalResult checkIntType(Type type, unsigned targetBitwidth) {
200200
Type elemType = getElementTypeOrSelf(type);
201201
if (isa<IndexType>(elemType))
202-
return true;
202+
return success();
203203

204204
if (auto intType = dyn_cast<IntegerType>(elemType))
205205
if (intType.getWidth() > targetBitwidth)
206-
return true;
206+
return success();
207207

208-
return false;
208+
return failure();
209209
}
210210

211211
/// Check if op have same type for all operands and results and this type
212212
/// is suitable for truncation.
213-
static bool checkElementwiseOpType(Operation *op, unsigned targetBitwidth) {
213+
static LogicalResult checkElementwiseOpType(Operation *op,
214+
unsigned targetBitwidth) {
214215
if (op->getNumOperands() == 0 || op->getNumResults() == 0)
215-
return false;
216+
return failure();
216217

217218
Type type;
218219
for (Value val : llvm::concat<Value>(op->getOperands(), op->getResults())) {
@@ -222,7 +223,7 @@ static bool checkElementwiseOpType(Operation *op, unsigned targetBitwidth) {
222223
}
223224

224225
if (type != val.getType())
225-
return false;
226+
return failure();
226227
}
227228

228229
return checkIntType(type, targetBitwidth);
@@ -258,8 +259,8 @@ static Type getTargetType(Type srcType, unsigned targetBitwidth) {
258259
}
259260

260261
/// Check provided `range` is inside `smin, smax, umin, umax` bounds.
261-
static bool checkRange(const ConstantIntRanges &range, APInt smin, APInt smax,
262-
APInt umin, APInt umax) {
262+
static LogicalResult checkRange(const ConstantIntRanges &range, APInt smin,
263+
APInt smax, APInt umin, APInt umax) {
263264
auto sge = [](APInt val1, APInt val2) -> bool {
264265
unsigned width = std::max(val1.getBitWidth(), val2.getBitWidth());
265266
val1 = val1.sext(width);
@@ -284,8 +285,8 @@ static bool checkRange(const ConstantIntRanges &range, APInt smin, APInt smax,
284285
val2 = val2.zext(width);
285286
return val1.ule(val2);
286287
};
287-
return sge(range.smin(), smin) && sle(range.smax(), smax) &&
288-
uge(range.umin(), umin) && ule(range.umax(), umax);
288+
return success(sge(range.smin(), smin) && sle(range.smax(), smax) &&
289+
uge(range.umin(), umin) && ule(range.umax(), umax));
289290
}
290291

291292
static Value doCast(OpBuilder &builder, Location loc, Value src, Type dstType) {
@@ -324,7 +325,7 @@ struct NarrowElementwise final : OpTraitRewritePattern<OpTrait::Elementwise> {
324325
return failure();
325326

326327
for (unsigned targetBitwidth : targetBitwidths) {
327-
if (!checkElementwiseOpType(op, targetBitwidth))
328+
if (failed(checkElementwiseOpType(op, targetBitwidth)))
328329
continue;
329330

330331
Type srcType = op->getResult(0).getType();
@@ -337,7 +338,7 @@ struct NarrowElementwise final : OpTraitRewritePattern<OpTrait::Elementwise> {
337338
auto smax = APInt::getSignedMaxValue(targetBitwidth);
338339
auto umin = APInt::getMinValue(targetBitwidth);
339340
auto umax = APInt::getMaxValue(targetBitwidth);
340-
if (!checkRange(*range, smin, smax, umin, umax))
341+
if (failed(checkRange(*range, smin, smax, umin, umax)))
341342
continue;
342343

343344
Type targetType = getTargetType(srcType, targetBitwidth);
@@ -388,14 +389,14 @@ struct NarrowCmpI final : OpRewritePattern<arith::CmpIOp> {
388389

389390
for (unsigned targetBitwidth : targetBitwidths) {
390391
Type srcType = lhs.getType();
391-
if (!checkIntType(srcType, targetBitwidth))
392+
if (failed(checkIntType(srcType, targetBitwidth)))
392393
continue;
393394

394395
auto smin = APInt::getSignedMinValue(targetBitwidth);
395396
auto smax = APInt::getSignedMaxValue(targetBitwidth);
396397
auto umin = APInt::getMinValue(targetBitwidth);
397398
auto umax = APInt::getMaxValue(targetBitwidth);
398-
if (!checkRange(*range, smin, smax, umin, umax))
399+
if (failed(checkRange(*range, smin, smax, umin, umax)))
399400
continue;
400401

401402
Type targetType = getTargetType(srcType, targetBitwidth);

0 commit comments

Comments
 (0)