Skip to content

Commit 1f8358b

Browse files
committed
comments
1 parent b427553 commit 1f8358b

File tree

1 file changed

+15
-7
lines changed

1 file changed

+15
-7
lines changed

mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,8 @@ struct DeleteTrivialRem : public OpRewritePattern<RemOp> {
195195
DataFlowSolver &solver;
196196
};
197197

198-
static Type checkArithType(Type type, unsigned targetBitwidth) {
198+
/// Check if `type` is index or integer type with `getWidth() > targetBitwidth`.
199+
static Type checkIntType(Type type, unsigned targetBitwidth) {
199200
type = getElementTypeOrSelf(type);
200201
if (isa<IndexType>(type))
201202
return type;
@@ -207,6 +208,9 @@ static Type checkArithType(Type type, unsigned targetBitwidth) {
207208
return nullptr;
208209
}
209210

211+
/// Check if op have same type for all operands and results and this type
212+
/// is suitable for truncation.
213+
/// Retuns args type or empty.
210214
static Type checkElementwiseOpType(Operation *op, unsigned targetBitwidth) {
211215
if (op->getNumOperands() == 0 || op->getNumResults() == 0)
212216
return nullptr;
@@ -225,13 +229,14 @@ static Type checkElementwiseOpType(Operation *op, unsigned targetBitwidth) {
225229
}
226230
}
227231

228-
return checkArithType(type, targetBitwidth);
232+
return checkIntType(type, targetBitwidth);
229233
}
230234

235+
/// Return union of all operands values ranges.
231236
static std::optional<ConstantIntRanges> getOperandsRange(DataFlowSolver &solver,
232-
ValueRange results) {
237+
ValueRange operands) {
233238
std::optional<ConstantIntRanges> ret;
234-
for (Value value : results) {
239+
for (Value value : operands) {
235240
auto *maybeInferredRange =
236241
solver.lookupState<IntegerValueRangeLattice>(value);
237242
if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
@@ -249,6 +254,8 @@ static std::optional<ConstantIntRanges> getOperandsRange(DataFlowSolver &solver,
249254
return ret;
250255
}
251256

257+
/// Return int type truncated to `targetBitwidth`. If `srcType` is shaped,
258+
/// return shaped type as well.
252259
static Type getTargetType(Type srcType, unsigned targetBitwidth) {
253260
auto dstType = IntegerType::get(srcType.getContext(), targetBitwidth);
254261
if (auto shaped = dyn_cast<ShapedType>(srcType))
@@ -258,6 +265,7 @@ static Type getTargetType(Type srcType, unsigned targetBitwidth) {
258265
return dstType;
259266
}
260267

268+
/// Check privided `range` is inside `smin, smax, umin, umax` bounds.
261269
static bool checkRange(const ConstantIntRanges &range, APInt smin, APInt smax,
262270
APInt umin, APInt umax) {
263271
auto sge = [](APInt val1, APInt val2) -> bool {
@@ -300,9 +308,9 @@ static Value doCast(OpBuilder &builder, Location loc, Value src, Type dstType) {
300308

301309
auto srcInt = cast<IntegerType>(srcType);
302310
auto dstInt = cast<IntegerType>(dstType);
303-
if (dstInt.getWidth() < srcInt.getWidth()) {
311+
if (dstInt.getWidth() < srcInt.getWidth())
304312
return builder.create<arith::TruncIOp>(loc, dstType, src);
305-
}
313+
306314
return builder.create<arith::ExtUIOp>(loc, dstType, src);
307315
}
308316

@@ -385,7 +393,7 @@ struct NarrowCmpI final : OpRewritePattern<arith::CmpIOp> {
385393
return failure();
386394

387395
for (unsigned targetBitwidth : targetBitwidths) {
388-
Type srcType = checkArithType(lhs.getType(), targetBitwidth);
396+
Type srcType = checkIntType(lhs.getType(), targetBitwidth);
389397
if (!srcType)
390398
continue;
391399

0 commit comments

Comments
 (0)