@@ -195,7 +195,8 @@ struct DeleteTrivialRem : public OpRewritePattern<RemOp> {
195195 DataFlowSolver &solver;
196196};
197197
198- static Type checkArithType (Type type, unsigned targetBitwidth) {
198+ // / Check if `type` is index or integer type with `getWidth() > targetBitwidth`.
199+ static Type checkIntType (Type type, unsigned targetBitwidth) {
199200 type = getElementTypeOrSelf (type);
200201 if (isa<IndexType>(type))
201202 return type;
@@ -207,6 +208,9 @@ static Type checkArithType(Type type, unsigned targetBitwidth) {
207208 return nullptr ;
208209}
209210
211+ // / Check if op have same type for all operands and results and this type
212+ // / is suitable for truncation.
213+ // / Retuns args type or empty.
210214static Type checkElementwiseOpType (Operation *op, unsigned targetBitwidth) {
211215 if (op->getNumOperands () == 0 || op->getNumResults () == 0 )
212216 return nullptr ;
@@ -225,13 +229,14 @@ static Type checkElementwiseOpType(Operation *op, unsigned targetBitwidth) {
225229 }
226230 }
227231
228- return checkArithType (type, targetBitwidth);
232+ return checkIntType (type, targetBitwidth);
229233}
230234
235+ // / Return union of all operands values ranges.
231236static std::optional<ConstantIntRanges> getOperandsRange (DataFlowSolver &solver,
232- ValueRange results ) {
237+ ValueRange operands ) {
233238 std::optional<ConstantIntRanges> ret;
234- for (Value value : results ) {
239+ for (Value value : operands ) {
235240 auto *maybeInferredRange =
236241 solver.lookupState <IntegerValueRangeLattice>(value);
237242 if (!maybeInferredRange || maybeInferredRange->getValue ().isUninitialized ())
@@ -249,6 +254,8 @@ static std::optional<ConstantIntRanges> getOperandsRange(DataFlowSolver &solver,
249254 return ret;
250255}
251256
257+ // / Return int type truncated to `targetBitwidth`. If `srcType` is shaped,
258+ // / return shaped type as well.
252259static Type getTargetType (Type srcType, unsigned targetBitwidth) {
253260 auto dstType = IntegerType::get (srcType.getContext (), targetBitwidth);
254261 if (auto shaped = dyn_cast<ShapedType>(srcType))
@@ -258,6 +265,7 @@ static Type getTargetType(Type srcType, unsigned targetBitwidth) {
258265 return dstType;
259266}
260267
268+ // / Check privided `range` is inside `smin, smax, umin, umax` bounds.
261269static bool checkRange (const ConstantIntRanges &range, APInt smin, APInt smax,
262270 APInt umin, APInt umax) {
263271 auto sge = [](APInt val1, APInt val2) -> bool {
@@ -300,9 +308,9 @@ static Value doCast(OpBuilder &builder, Location loc, Value src, Type dstType) {
300308
301309 auto srcInt = cast<IntegerType>(srcType);
302310 auto dstInt = cast<IntegerType>(dstType);
303- if (dstInt.getWidth () < srcInt.getWidth ()) {
311+ if (dstInt.getWidth () < srcInt.getWidth ())
304312 return builder.create <arith::TruncIOp>(loc, dstType, src);
305- }
313+
306314 return builder.create <arith::ExtUIOp>(loc, dstType, src);
307315}
308316
@@ -385,7 +393,7 @@ struct NarrowCmpI final : OpRewritePattern<arith::CmpIOp> {
385393 return failure ();
386394
387395 for (unsigned targetBitwidth : targetBitwidths) {
388- Type srcType = checkArithType (lhs.getType (), targetBitwidth);
396+ Type srcType = checkIntType (lhs.getType (), targetBitwidth);
389397 if (!srcType)
390398 continue ;
391399
0 commit comments