@@ -218,9 +218,10 @@ static Type checkElementwiseOpType(Operation *op, unsigned targetBitwidth) {
218218 if (!type) {
219219 type = val.getType ();
220220 continue ;
221- } else if (type != val.getType ()) {
222- return nullptr ;
223221 }
222+
223+ if (type != val.getType ())
224+ return nullptr ;
224225 }
225226 }
226227
@@ -301,13 +302,11 @@ static Value doCast(OpBuilder &builder, Location loc, Value src, Type dstType) {
301302 auto dstInt = cast<IntegerType>(dstType);
302303 if (dstInt.getWidth () < srcInt.getWidth ()) {
303304 return builder.create <arith::TruncIOp>(loc, dstType, src);
304- } else {
305- return builder.create <arith::ExtUIOp>(loc, dstType, src);
306305 }
306+ return builder.create <arith::ExtUIOp>(loc, dstType, src);
307307}
308308
309- struct NarrowElementwise final
310- : public OpTraitRewritePattern<OpTrait::Elementwise> {
309+ struct NarrowElementwise final : OpTraitRewritePattern<OpTrait::Elementwise> {
311310 NarrowElementwise (MLIRContext *context, DataFlowSolver &s,
312311 ArrayRef<unsigned > target)
313312 : OpTraitRewritePattern<OpTrait::Elementwise>(context), solver(s),
@@ -316,7 +315,6 @@ struct NarrowElementwise final
316315 using OpTraitRewritePattern::OpTraitRewritePattern;
317316 LogicalResult matchAndRewrite (Operation *op,
318317 PatternRewriter &rewriter) const override {
319-
320318 std::optional<ConstantIntRanges> range =
321319 getOperandsRange (solver, op->getResults ());
322320 if (!range)
@@ -370,8 +368,8 @@ struct NarrowElementwise final
370368 SmallVector<unsigned , 4 > targetBitwidths;
371369};
372370
373- struct NarrowCmpi final : public OpRewritePattern<arith::CmpIOp> {
374- NarrowCmpi (MLIRContext *context, PatternBenefit benefit, DataFlowSolver &s,
371+ struct NarrowCmpI final : OpRewritePattern<arith::CmpIOp> {
372+ NarrowCmpI (MLIRContext *context, PatternBenefit benefit, DataFlowSolver &s,
375373 ArrayRef<unsigned > target)
376374 : OpRewritePattern(context, benefit), solver(s), targetBitwidths(target) {
377375 }
@@ -421,8 +419,8 @@ struct NarrowCmpi final : public OpRewritePattern<arith::CmpIOp> {
421419 SmallVector<unsigned , 4 > targetBitwidths;
422420};
423421
424- struct IntRangeOptimizationsPass
425- : public arith::impl::ArithIntRangeOptsBase<IntRangeOptimizationsPass> {
422+ struct IntRangeOptimizationsPass final
423+ : arith::impl::ArithIntRangeOptsBase<IntRangeOptimizationsPass> {
426424
427425 void runOnOperation () override {
428426 Operation *op = getOperation ();
@@ -446,8 +444,8 @@ struct IntRangeOptimizationsPass
446444 }
447445};
448446
449- struct IntRangeNarrowingPass
450- : public arith::impl::ArithIntRangeNarrowingBase<IntRangeNarrowingPass> {
447+ struct IntRangeNarrowingPass final
448+ : arith::impl::ArithIntRangeNarrowingBase<IntRangeNarrowingPass> {
451449 using ArithIntRangeNarrowingBase::ArithIntRangeNarrowingBase;
452450
453451 void runOnOperation () override {
@@ -482,9 +480,9 @@ void mlir::arith::populateIntRangeOptimizationsPatterns(
482480void mlir::arith::populateIntRangeNarrowingPatterns (
483481 RewritePatternSet &patterns, DataFlowSolver &solver,
484482 ArrayRef<unsigned > bitwidthsSupported) {
485- // Cmpi uses args ranges instead of results, run it with higher benefit,
483+ // CmpI uses args ranges instead of results, run it with higher benefit,
486484 // as its argumens can be potentially replaced.
487- patterns.add <NarrowCmpi >(patterns.getContext (), /* benefit*/ 10 , solver,
485+ patterns.add <NarrowCmpI >(patterns.getContext (), /* benefit*/ 10 , solver,
488486 bitwidthsSupported);
489487
490488 patterns.add <NarrowElementwise>(patterns.getContext (), solver,
0 commit comments