@@ -194,8 +194,8 @@ class AffineIntegerRangeAnalysis
194194 // / affine.parallel (%i) = (0) to (10) {
195195 // / // getBoundsFromAffineParallel(op, 0) returns {0, 9}
196196 // / }
197- std::pair<APInt, APInt>
198- getBoundsFromAffineParallel (affine::AffineParallelOp loop, size_t idx) {
197+ ConstantIntRanges getBoundsFromAffineParallel (affine::AffineParallelOp loop,
198+ size_t idx) {
199199 SmallVector<AffineExpr> lbounds (
200200 loop.getLowerBoundsMap ().getResults ().begin (),
201201 loop.getLowerBoundsMap ().getResults ().end ());
@@ -225,11 +225,12 @@ class AffineIntegerRangeAnalysis
225225
226226 if (lb && ub) {
227227 // Create APInt values with 64 bit.
228- return {APInt (/* numBits=*/ 64 , lb.getValue (), /* isSigned=*/ true ),
229- APInt (/* numBits=*/ 64 , ub.getValue () - 1 , /* isSigned=*/ true )};
228+ return ConstantIntRanges::fromSigned (
229+ APInt (/* numBits=*/ 64 , lb.getValue (), /* isSigned=*/ true ),
230+ APInt (/* numBits=*/ 64 , ub.getValue () - 1 , /* isSigned=*/ true ));
230231 }
231232 // Return sentinel values if bounds cannot be determined
232- return { APInt::getSignedMinValue (64 ), APInt::getSignedMaxValue ( 64 )} ;
233+ return ConstantIntRanges::maxRange (64 );
233234 }
234235};
235236
@@ -279,9 +280,9 @@ void AffineIntegerRangeAnalysis::visitNonControlFlowArguments(
279280 // not expose all the necessary interfaces/methods.
280281 if (auto loop = dyn_cast<affine::AffineParallelOp>(op)) {
281282 for (Value iv : loop.getIVs ()) {
282- auto [min, max] = getBoundsFromAffineParallel (loop, 0 );
283+ ConstantIntRanges ivRange = getBoundsFromAffineParallel (
284+ loop, cast<BlockArgument>(iv).getArgNumber ());
283285 IntegerValueRangeLattice *ivEntry = getLatticeElement (iv);
284- auto ivRange = ConstantIntRanges::fromSigned (min, max);
285286 propagateIfChanged (ivEntry, ivEntry->join (IntegerValueRange{ivRange}));
286287 }
287288 return ;
@@ -614,26 +615,184 @@ struct CanonicalizeLoopsPass
614615 range.getValue ().getConstantValue ();
615616 if (!constantRangeValue.has_value ())
616617 return ;
617- if (constantRangeValue->eq (cstRhs)) {
618+ b.setInsertionPoint (cmpiOp);
619+ auto cst = b.create <arith::ConstantOp>(
620+ cmpiOp.getLoc (), b.getI1Type (),
621+ IntegerAttr::get (b.getI1Type (), !constantRangeValue->eq (cstRhs)));
622+ cmpiOp.getResult ().replaceAllUsesWith (cst);
623+ }
624+ if (pred == arith::CmpIPredicate::eq) {
625+ std::optional<APInt> constantRangeValue =
626+ range.getValue ().getConstantValue ();
627+ if (!constantRangeValue.has_value ())
628+ return ;
629+ b.setInsertionPoint (cmpiOp);
630+ auto cst = b.create <arith::ConstantOp>(
631+ cmpiOp.getLoc (), b.getI1Type (),
632+ IntegerAttr::get (b.getI1Type (), constantRangeValue->eq (cstRhs)));
633+ cmpiOp.getResult ().replaceAllUsesWith (cst);
634+ }
635+ if (pred == arith::CmpIPredicate::ult) {
636+ const APInt umax = cstRange.umax ();
637+ const APInt umin = cstRange.umin ();
638+ if (umax.ult (cstRhs)) {
639+ // Condition always true.
640+ b.setInsertionPoint (cmpiOp);
641+ auto cst = b.create <arith::ConstantOp>(
642+ cmpiOp.getLoc (), b.getI1Type (),
643+ IntegerAttr::get (b.getI1Type (), true ));
644+ cmpiOp.getResult ().replaceAllUsesWith (cst);
645+ }
646+ // range < cst -> !(range >= cst)
647+ if (umin.uge (cstRhs)) {
648+ // Condition always false.
618649 b.setInsertionPoint (cmpiOp);
619650 auto cst = b.create <arith::ConstantOp>(
620651 cmpiOp.getLoc (), b.getI1Type (),
621652 IntegerAttr::get (b.getI1Type (), false ));
622653 cmpiOp.getResult ().replaceAllUsesWith (cst);
623654 }
624655 }
625- if (pred == arith::CmpIPredicate::ult) {
656+ if (pred == arith::CmpIPredicate::ule) {
657+ const APInt umax = cstRange.umax ();
658+ const APInt umin = cstRange.umin ();
659+ if (umax.ule (cstRhs)) {
660+ // Condition always true.
661+ b.setInsertionPoint (cmpiOp);
662+ auto cst = b.create <arith::ConstantOp>(
663+ cmpiOp.getLoc (), b.getI1Type (),
664+ IntegerAttr::get (b.getI1Type (), true ));
665+ cmpiOp.getResult ().replaceAllUsesWith (cst);
666+ }
667+ // range <= cst -> !(range > cst)
668+ if (umin.ugt (cstRhs)) {
669+ // Condition always false.
670+ b.setInsertionPoint (cmpiOp);
671+ auto cst = b.create <arith::ConstantOp>(
672+ cmpiOp.getLoc (), b.getI1Type (),
673+ IntegerAttr::get (b.getI1Type (), false ));
674+ cmpiOp.getResult ().replaceAllUsesWith (cst);
675+ }
676+ }
677+ if (pred == arith::CmpIPredicate::ugt) {
626678 const APInt umax = cstRange.umax ();
627679 const APInt umin = cstRange.umin ();
628- if (umax.ult (cstRhs) && umin.ult (cstRhs)) {
680+ if (umax.ugt (cstRhs)) {
681+ // Condition always true.
682+ b.setInsertionPoint (cmpiOp);
683+ auto cst = b.create <arith::ConstantOp>(
684+ cmpiOp.getLoc (), b.getI1Type (),
685+ IntegerAttr::get (b.getI1Type (), true ));
686+ cmpiOp.getResult ().replaceAllUsesWith (cst);
687+ }
688+ // range > cst -> !(range <= cst)
689+ if (umin.ule (cstRhs)) {
690+ // Condition always false.
691+ b.setInsertionPoint (cmpiOp);
692+ auto cst = b.create <arith::ConstantOp>(
693+ cmpiOp.getLoc (), b.getI1Type (),
694+ IntegerAttr::get (b.getI1Type (), false ));
695+ cmpiOp.getResult ().replaceAllUsesWith (cst);
696+ }
697+ }
698+ if (pred == arith::CmpIPredicate::uge) {
699+ const APInt umax = cstRange.umax ();
700+ const APInt umin = cstRange.umin ();
701+ if (umax.uge (cstRhs)) {
702+ // Condition always true.
703+ b.setInsertionPoint (cmpiOp);
704+ auto cst = b.create <arith::ConstantOp>(
705+ cmpiOp.getLoc (), b.getI1Type (),
706+ IntegerAttr::get (b.getI1Type (), true ));
707+ cmpiOp.getResult ().replaceAllUsesWith (cst);
708+ }
709+ // range >= cst -> !(range < cst)
710+ if (umin.ult (cstRhs)) {
711+ // Condition always false.
712+ b.setInsertionPoint (cmpiOp);
713+ auto cst = b.create <arith::ConstantOp>(
714+ cmpiOp.getLoc (), b.getI1Type (),
715+ IntegerAttr::get (b.getI1Type (), false ));
716+ cmpiOp.getResult ().replaceAllUsesWith (cst);
717+ }
718+ }
719+
720+ if (pred == arith::CmpIPredicate::slt) {
721+ const APInt smax = cstRange.smax ();
722+ const APInt smin = cstRange.smin ();
723+ if (smax.slt (cstRhs)) {
724+ // Condition always true.
725+ b.setInsertionPoint (cmpiOp);
726+ auto cst = b.create <arith::ConstantOp>(
727+ cmpiOp.getLoc (), b.getI1Type (),
728+ IntegerAttr::get (b.getI1Type (), true ));
729+ cmpiOp.getResult ().replaceAllUsesWith (cst);
730+ }
731+ // range < cst -> !(range >= cst)
732+ if (smin.sge (cstRhs)) {
733+ // Condition always false.
734+ b.setInsertionPoint (cmpiOp);
735+ auto cst = b.create <arith::ConstantOp>(
736+ cmpiOp.getLoc (), b.getI1Type (),
737+ IntegerAttr::get (b.getI1Type (), false ));
738+ cmpiOp.getResult ().replaceAllUsesWith (cst);
739+ }
740+ }
741+ if (pred == arith::CmpIPredicate::sle) {
742+ const APInt smax = cstRange.smax ();
743+ const APInt smin = cstRange.smin ();
744+ if (smax.sle (cstRhs)) {
745+ // Condition always true.
746+ b.setInsertionPoint (cmpiOp);
747+ auto cst = b.create <arith::ConstantOp>(
748+ cmpiOp.getLoc (), b.getI1Type (),
749+ IntegerAttr::get (b.getI1Type (), true ));
750+ cmpiOp.getResult ().replaceAllUsesWith (cst);
751+ }
752+ // range <= cst -> !(range > cst)
753+ if (smin.sgt (cstRhs)) {
754+ // Condition always false.
755+ b.setInsertionPoint (cmpiOp);
756+ auto cst = b.create <arith::ConstantOp>(
757+ cmpiOp.getLoc (), b.getI1Type (),
758+ IntegerAttr::get (b.getI1Type (), false ));
759+ cmpiOp.getResult ().replaceAllUsesWith (cst);
760+ }
761+ }
762+ if (pred == arith::CmpIPredicate::sgt) {
763+ const APInt smax = cstRange.smax ();
764+ const APInt smin = cstRange.smin ();
765+ if (smax.sgt (cstRhs)) {
766+ // Condition always true.
767+ b.setInsertionPoint (cmpiOp);
768+ auto cst = b.create <arith::ConstantOp>(
769+ cmpiOp.getLoc (), b.getI1Type (),
770+ IntegerAttr::get (b.getI1Type (), true ));
771+ cmpiOp.getResult ().replaceAllUsesWith (cst);
772+ }
773+ // range > cst -> !(range <= cst)
774+ if (smin.sle (cstRhs)) {
775+ // Condition always false.
776+ b.setInsertionPoint (cmpiOp);
777+ auto cst = b.create <arith::ConstantOp>(
778+ cmpiOp.getLoc (), b.getI1Type (),
779+ IntegerAttr::get (b.getI1Type (), false ));
780+ cmpiOp.getResult ().replaceAllUsesWith (cst);
781+ }
782+ }
783+ if (pred == arith::CmpIPredicate::sge) {
784+ const APInt smax = cstRange.smax ();
785+ const APInt smin = cstRange.smin ();
786+ if (smax.sge (cstRhs)) {
629787 // Condition always true.
630788 b.setInsertionPoint (cmpiOp);
631789 auto cst = b.create <arith::ConstantOp>(
632790 cmpiOp.getLoc (), b.getI1Type (),
633791 IntegerAttr::get (b.getI1Type (), true ));
634792 cmpiOp.getResult ().replaceAllUsesWith (cst);
635793 }
636- if (!umax.ult (cstRhs) && !umin.ult (cstRhs)) {
794+ // range >= cst -> !(range < cst)
795+ if (smin.slt (cstRhs)) {
637796 // Condition always false.
638797 b.setInsertionPoint (cmpiOp);
639798 auto cst = b.create <arith::ConstantOp>(
0 commit comments