@@ -581,6 +581,49 @@ std::pair<AffineMap, AffineMap> FlatLinearConstraints::getLowerAndUpperBound(
581581 return {lbMap, ubMap};
582582}
583583
584+ // / Express the pos^th identifier of `cst` as an affine expression in
585+ // / terms of other identifiers, if they are available in `exprs`, using the
586+ // / equality at position `idx` in `cs`t. Populates `exprs` with such an
587+ // / expression if possible, and return true. Returns false otherwise.
588+ static bool detectAsExpr (const FlatLinearConstraints &cst, unsigned pos,
589+ unsigned idx, MLIRContext *context,
590+ SmallVectorImpl<AffineExpr> &exprs) {
591+ // Initialize with a `0` expression.
592+ auto expr = getAffineConstantExpr (0 , context);
593+
594+ // Traverse `idx`th equality and construct the possible affine expression in
595+ // terms of known identifiers.
596+ unsigned j, e;
597+ for (j = 0 , e = cst.getNumVars (); j < e; ++j) {
598+ if (j == pos)
599+ continue ;
600+ int64_t c = cst.atEq64 (idx, j);
601+ if (c == 0 )
602+ continue ;
603+ // If any of the involved IDs hasn't been found yet, we can't proceed.
604+ if (!exprs[j])
605+ break ;
606+ expr = expr + exprs[j] * c;
607+ }
608+ if (j < e)
609+ // Can't construct expression as it depends on a yet uncomputed
610+ // identifier.
611+ return false ;
612+
613+ // Add constant term to AffineExpr.
614+ expr = expr + cst.atEq64 (idx, cst.getNumVars ());
615+ int64_t vPos = cst.atEq64 (idx, pos);
616+ assert (vPos != 0 && " expected non-zero here" );
617+ if (vPos > 0 )
618+ expr = (-expr).floorDiv (vPos);
619+ else
620+ // vPos < 0.
621+ expr = expr.floorDiv (-vPos);
622+ // Successfully constructed expression.
623+ exprs[pos] = expr;
624+ return true ;
625+ }
626+
584627// / Compute a representation of `num` identifiers starting at `offset` in `cst`
585628// / as affine expressions involving other known identifiers. Each identifier's
586629// / expression (in terms of known identifiers) is populated into `memo`.
@@ -636,41 +679,13 @@ static void computeUnknownVars(const FlatLinearConstraints &cst,
636679
637680 // Detect a variable as an expression of other variables.
638681 std::optional<unsigned > idx;
639- if (!(idx = cst.findConstraintWithNonZeroAt (pos, /* isEq=*/ true ))) {
682+ if (!(idx = cst.findConstraintWithNonZeroAt (pos, /* isEq=*/ true )))
640683 continue ;
641- }
642684
643- // Build AffineExpr solving for variable 'pos' in terms of all others.
644- auto expr = getAffineConstantExpr (0 , context);
645- unsigned j, e;
646- for (j = 0 , e = cst.getNumVars (); j < e; ++j) {
647- if (j == pos)
648- continue ;
649- int64_t c = cst.atEq64 (*idx, j);
650- if (c == 0 )
651- continue ;
652- // If any of the involved IDs hasn't been found yet, we can't proceed.
653- if (!memo[j])
654- break ;
655- expr = expr + memo[j] * c;
656- }
657- if (j < e)
658- // Can't construct expression as it depends on a yet uncomputed
659- // variable.
685+ if (detectAsExpr (cst, pos, *idx, context, memo)) {
686+ changed = true ;
660687 continue ;
661-
662- // Add constant term to AffineExpr.
663- expr = expr + cst.atEq64 (*idx, cst.getNumVars ());
664- int64_t vPos = cst.atEq64 (*idx, pos);
665- assert (vPos != 0 && " expected non-zero here" );
666- if (vPos > 0 )
667- expr = (-expr).floorDiv (vPos);
668- else
669- // vPos < 0.
670- expr = expr.floorDiv (-vPos);
671- // Successfully constructed expression.
672- memo[pos] = expr;
673- changed = true ;
688+ }
674689 }
675690 // This loop is guaranteed to reach a fixed point - since once an
676691 // variable's explicit form is computed (in memo[pos]), it's not updated
@@ -891,6 +906,185 @@ FlatLinearConstraints::computeLocalVars(SmallVectorImpl<AffineExpr> &memo,
891906 llvm::all_of (localExprs, [](AffineExpr expr) { return expr; }));
892907}
893908
909+ // / Given an equality or inequality (`isEquality` used to disambiguate) of `cst`
910+ // / at `idx`, traverse and sum up `AffineExpr`s of all known ids other than the
911+ // / `pos`th. Known `AffineExpr`s are given in `exprs` (unknowns are null). If
912+ // / the equality/inequality contains any unknown id, return None. Otherwise
913+ // / return sum as `AffineExpr`.
914+ static std::optional<AffineExpr> getAsExpr (const FlatLinearConstraints &cst,
915+ unsigned pos, MLIRContext *context,
916+ ArrayRef<AffineExpr> exprs,
917+ unsigned idx, bool isEquality) {
918+ // Initialize with a `0` expression.
919+ auto expr = getAffineConstantExpr (0 , context);
920+
921+ SmallVector<int64_t , 8 > row =
922+ isEquality ? cst.getEquality64 (idx) : cst.getInequality64 (idx);
923+
924+ // Traverse `idx`th equality and construct the possible affine expression in
925+ // terms of known identifiers.
926+ unsigned j, e;
927+ for (j = 0 , e = cst.getNumVars (); j < e; ++j) {
928+ if (j == pos)
929+ continue ;
930+ int64_t c = row[j];
931+ if (c == 0 )
932+ continue ;
933+ // If any of the involved IDs hasn't been found yet, we can't proceed.
934+ if (!exprs[j])
935+ break ;
936+ expr = expr + exprs[j] * c;
937+ }
938+ if (j < e)
939+ // Can't construct expression as it depends on a yet uncomputed
940+ // identifier.
941+ return std::nullopt ;
942+
943+ // Add constant term to AffineExpr.
944+ expr = expr + row[cst.getNumVars ()];
945+ return expr;
946+ }
947+
948+ std::optional<int64_t > FlatLinearConstraints::getConstantBoundOnDimSize (
949+ MLIRContext *context, unsigned pos, AffineMap *lb, AffineMap *ub,
950+ unsigned *minLbPos, unsigned *minUbPos) const {
951+
952+ assert (pos < getNumDimVars () && " Invalid identifier position" );
953+
954+ auto freeOfUnknownLocalVars = [&](ArrayRef<int64_t > cst,
955+ ArrayRef<AffineExpr> whiteListCols) {
956+ for (int i = getNumDimAndSymbolVars (), e = cst.size () - 1 ; i < e; ++i) {
957+ if (whiteListCols[i] && whiteListCols[i].isSymbolicOrConstant ())
958+ continue ;
959+ if (cst[i] != 0 )
960+ return false ;
961+ }
962+ return true ;
963+ };
964+
965+ // Detect the necesary local variables first.
966+ SmallVector<AffineExpr, 8 > memo (getNumVars (), AffineExpr ());
967+ (void )computeLocalVars (memo, context);
968+
969+ // Find an equality for 'pos'^th identifier that equates it to some function
970+ // of the symbolic identifiers (+ constant).
971+ int eqPos = findEqualityToConstant (pos, /* symbolic=*/ true );
972+ // If the equality involves a local var that can not be expressed as a
973+ // symbolic or constant affine expression, we bail out.
974+ if (eqPos != -1 && freeOfUnknownLocalVars (getEquality64 (eqPos), memo)) {
975+ // This identifier can only take a single value.
976+ if (lb && detectAsExpr (*this , pos, eqPos, context, memo)) {
977+ AffineExpr equalityExpr =
978+ simplifyAffineExpr (memo[pos], 0 , getNumSymbolVars ());
979+ *lb = AffineMap::get (/* dimCount=*/ 0 , getNumSymbolVars (), equalityExpr);
980+ if (ub)
981+ *ub = *lb;
982+ }
983+ if (minLbPos)
984+ *minLbPos = eqPos;
985+ if (minUbPos)
986+ *minUbPos = eqPos;
987+ return 1 ;
988+ }
989+
990+ // Positions of constraints that are lower/upper bounds on the variable.
991+ SmallVector<unsigned , 4 > lbIndices, ubIndices;
992+
993+ // Note inequalities that give lower and upper bounds.
994+ getLowerAndUpperBoundIndices (pos, &lbIndices, &ubIndices,
995+ /* eqIndices=*/ nullptr , /* offset=*/ 0 ,
996+ /* num=*/ getNumDimVars ());
997+
998+ std::optional<int64_t > minDiff = std::nullopt ;
999+ unsigned minLbPosition = 0 , minUbPosition = 0 ;
1000+ AffineExpr minLbExpr, minUbExpr;
1001+
1002+ // Traverse each lower bound and upper bound pair, to compute the difference
1003+ // between them.
1004+ for (unsigned ubPos : ubIndices) {
1005+ // Construct sum of all ids other than `pos`th in the given upper bound row.
1006+ std::optional<AffineExpr> maybeUbExpr =
1007+ getAsExpr (*this , pos, context, memo, ubPos, /* isEquality=*/ false );
1008+ if (!maybeUbExpr.has_value () || !(*maybeUbExpr).isSymbolicOrConstant ())
1009+ continue ;
1010+
1011+ // Canonical form of an inequality that constrains the upper bound on
1012+ // an id `x_i` is of the form:
1013+ // `c_1*x_1 + c_2*x_2 + ... + c_0 >= 0`, where `c_i` <= -1.
1014+ // Therefore the upper bound on `x_i` will be
1015+ // `(
1016+ // sum(c_j*x_j) where j != i
1017+ // +
1018+ // c_0
1019+ // )
1020+ // /
1021+ // -(c_i)`. Divison here is a floorDiv.
1022+ AffineExpr ubExpr = maybeUbExpr->floorDiv (-atIneq64 (ubPos, pos));
1023+ assert (-atIneq64 (ubPos, pos) > 0 && " invalid upper bound index" );
1024+
1025+ // Go over each lower bound.
1026+ for (unsigned lbPos : lbIndices) {
1027+ // Construct sum of all ids other than `pos`th in the given lower bound
1028+ // row.
1029+ std::optional<AffineExpr> maybeLbExpr =
1030+ getAsExpr (*this , pos, context, memo, lbPos, /* isEquality=*/ false );
1031+ if (!maybeLbExpr.has_value () || !(*maybeLbExpr).isSymbolicOrConstant ())
1032+ continue ;
1033+
1034+ // Canonical form of an inequality that is constraining the lower bound
1035+ // on an id `x_i is of the form:
1036+ // `c_1*x_1 + c_2*x_2 + ... + c_0 >= 0`, where `c_i` >= 1.
1037+ // Therefore upperBound on `x_i` will be
1038+ // `-(
1039+ // sum(c_j*x_j) where j != i
1040+ // +
1041+ // c_0
1042+ // )
1043+ // /
1044+ // c_i`. Divison here is a ceilDiv.
1045+ int64_t divisor = atIneq64 (lbPos, pos);
1046+ // We convert the `ceilDiv` for floordiv with the formula:
1047+ // `expr ceildiv divisor is (expr + divisor - 1) floordiv divisor`,
1048+ // since uniformly keeping divisons as `floorDiv` helps their
1049+ // simplification.
1050+ AffineExpr lbExpr = (-(*maybeLbExpr) + divisor - 1 ).floorDiv (divisor);
1051+ assert (atIneq64 (lbPos, pos) > 0 && " invalid lower bound index" );
1052+
1053+ AffineExpr difference =
1054+ simplifyAffineExpr (ubExpr - lbExpr + 1 , 0 , getNumSymbolVars ());
1055+ // If the difference is not constant, ignore the lower bound - upper bound
1056+ // pair.
1057+ auto constantDiff = dyn_cast<AffineConstantExpr>(difference);
1058+ if (!constantDiff)
1059+ continue ;
1060+
1061+ int64_t diffValue = constantDiff.getValue ();
1062+ // This bound is non-negative by definition.
1063+ diffValue = std::max<int64_t >(diffValue, 0 );
1064+ if (!minDiff || diffValue < *minDiff) {
1065+ minDiff = diffValue;
1066+ minLbPosition = lbPos;
1067+ minUbPosition = ubPos;
1068+ minLbExpr = lbExpr;
1069+ minUbExpr = ubExpr;
1070+ }
1071+ }
1072+ }
1073+
1074+ // Populate outputs where available and needed.
1075+ if (lb && minDiff) {
1076+ *lb = AffineMap::get (/* dimCount=*/ 0 , getNumSymbolVars (), minLbExpr);
1077+ }
1078+ if (ub)
1079+ *ub = AffineMap::get (/* dimCount=*/ 0 , getNumSymbolVars (), minUbExpr);
1080+ if (minLbPos)
1081+ *minLbPos = minLbPosition;
1082+ if (minUbPos)
1083+ *minUbPos = minUbPosition;
1084+
1085+ return minDiff;
1086+ }
1087+
8941088IntegerSet FlatLinearConstraints::getAsIntegerSet (MLIRContext *context) const {
8951089 if (getNumConstraints () == 0 )
8961090 // Return universal set (always true): 0 == 0.
0 commit comments