@@ -196,23 +196,24 @@ struct DeleteTrivialRem : public OpRewritePattern<RemOp> {
196196};
197197
198198// / Check if `type` is index or integer type with `getWidth() > targetBitwidth`.
199- static bool checkIntType (Type type, unsigned targetBitwidth) {
199+ static LogicalResult checkIntType (Type type, unsigned targetBitwidth) {
200200 Type elemType = getElementTypeOrSelf (type);
201201 if (isa<IndexType>(elemType))
202- return true ;
202+ return success () ;
203203
204204 if (auto intType = dyn_cast<IntegerType>(elemType))
205205 if (intType.getWidth () > targetBitwidth)
206- return true ;
206+ return success () ;
207207
208- return false ;
208+ return failure () ;
209209}
210210
211211// / Check if op have same type for all operands and results and this type
212212// / is suitable for truncation.
213- static bool checkElementwiseOpType (Operation *op, unsigned targetBitwidth) {
213+ static LogicalResult checkElementwiseOpType (Operation *op,
214+ unsigned targetBitwidth) {
214215 if (op->getNumOperands () == 0 || op->getNumResults () == 0 )
215- return false ;
216+ return failure () ;
216217
217218 Type type;
218219 for (Value val : llvm::concat<Value>(op->getOperands (), op->getResults ())) {
@@ -222,7 +223,7 @@ static bool checkElementwiseOpType(Operation *op, unsigned targetBitwidth) {
222223 }
223224
224225 if (type != val.getType ())
225- return false ;
226+ return failure () ;
226227 }
227228
228229 return checkIntType (type, targetBitwidth);
@@ -258,8 +259,8 @@ static Type getTargetType(Type srcType, unsigned targetBitwidth) {
258259}
259260
260261// / Check provided `range` is inside `smin, smax, umin, umax` bounds.
261- static bool checkRange (const ConstantIntRanges &range, APInt smin, APInt smax ,
262- APInt umin, APInt umax) {
262+ static LogicalResult checkRange (const ConstantIntRanges &range, APInt smin,
263+ APInt smax, APInt umin, APInt umax) {
263264 auto sge = [](APInt val1, APInt val2) -> bool {
264265 unsigned width = std::max (val1.getBitWidth (), val2.getBitWidth ());
265266 val1 = val1.sext (width);
@@ -284,8 +285,8 @@ static bool checkRange(const ConstantIntRanges &range, APInt smin, APInt smax,
284285 val2 = val2.zext (width);
285286 return val1.ule (val2);
286287 };
287- return sge (range.smin (), smin) && sle (range.smax (), smax) &&
288- uge (range.umin (), umin) && ule (range.umax (), umax);
288+ return success ( sge (range.smin (), smin) && sle (range.smax (), smax) &&
289+ uge (range.umin (), umin) && ule (range.umax (), umax) );
289290}
290291
291292static Value doCast (OpBuilder &builder, Location loc, Value src, Type dstType) {
@@ -324,7 +325,7 @@ struct NarrowElementwise final : OpTraitRewritePattern<OpTrait::Elementwise> {
324325 return failure ();
325326
326327 for (unsigned targetBitwidth : targetBitwidths) {
327- if (! checkElementwiseOpType (op, targetBitwidth))
328+ if (failed ( checkElementwiseOpType (op, targetBitwidth) ))
328329 continue ;
329330
330331 Type srcType = op->getResult (0 ).getType ();
@@ -337,7 +338,7 @@ struct NarrowElementwise final : OpTraitRewritePattern<OpTrait::Elementwise> {
337338 auto smax = APInt::getSignedMaxValue (targetBitwidth);
338339 auto umin = APInt::getMinValue (targetBitwidth);
339340 auto umax = APInt::getMaxValue (targetBitwidth);
340- if (! checkRange (*range, smin, smax, umin, umax))
341+ if (failed ( checkRange (*range, smin, smax, umin, umax) ))
341342 continue ;
342343
343344 Type targetType = getTargetType (srcType, targetBitwidth);
@@ -388,14 +389,14 @@ struct NarrowCmpI final : OpRewritePattern<arith::CmpIOp> {
388389
389390 for (unsigned targetBitwidth : targetBitwidths) {
390391 Type srcType = lhs.getType ();
391- if (! checkIntType (srcType, targetBitwidth))
392+ if (failed ( checkIntType (srcType, targetBitwidth) ))
392393 continue ;
393394
394395 auto smin = APInt::getSignedMinValue (targetBitwidth);
395396 auto smax = APInt::getSignedMaxValue (targetBitwidth);
396397 auto umin = APInt::getMinValue (targetBitwidth);
397398 auto umax = APInt::getMaxValue (targetBitwidth);
398- if (! checkRange (*range, smin, smax, umin, umax))
399+ if (failed ( checkRange (*range, smin, smax, umin, umax) ))
399400 continue ;
400401
401402 Type targetType = getTargetType (srcType, targetBitwidth);
0 commit comments