@@ -36,7 +36,7 @@ using namespace tsar;
3636using namespace llvm ;
3737
3838#undef DEBUG_TYPE
39- #define DEBUG_TYPE " def -mem"
39+ #define DEBUG_TYPE " inter -mem"
4040
4141namespace {
4242typedef int64_t ColumnT;
@@ -253,7 +253,7 @@ void delinearize(const MemoryLocationRange &From, MemoryLocationRange &What) {
253253 What.LowerBound = 0 ;
254254 What.UpperBound = ElemSize;
255255 What.DimList = std::move (DimList);
256- What.Kind = LocKind::Collapsed | (What.Kind & LocKind::Hint);
256+ What.Kind = LocKind::Collapsed | (What.Kind & LocKind::Hint) | (What. Kind & LocKind::Auxiliary) ;
257257 What.SE = SE;
258258 What.AM = From.AM ;
259259}
@@ -355,14 +355,28 @@ struct IntersectVarInfo {
355355// / returns llvm::None.
356356llvm::Optional<std::pair<Value *, int64_t >>
357357parseBoundExpression (const llvm::SCEV *S) {
358- if (auto *Unknown = dyn_cast<SCEVUnknown>(S))
359- return std::make_pair (Unknown->getValue (), int64_t (0 ));
358+ auto GetInnerSCEV = [](const SCEV *S) -> const SCEV* {
359+ while (auto Cast{dyn_cast<SCEVCastExpr>(S)})
360+ S = Cast->getOperand ();
361+ return S;
362+ };
363+ if (!isa<SCEVNAryExpr>(S)) {
364+ if (isa<SCEVSignExtendExpr>(S) || isa<SCEVZeroExtendExpr>(S)) {
365+ return parseBoundExpression (GetInnerSCEV (S));
366+ }
367+ if (auto *Unknown = dyn_cast<SCEVUnknown>(S))
368+ return std::make_pair (Unknown->getValue (), int64_t (0 ));
369+ }
360370 auto *NAry = dyn_cast<SCEVNAryExpr>(S);
361371 if (!NAry || NAry->getSCEVType () != llvm::SCEVTypes::scAddExpr ||
362372 NAry->getNumOperands () != 2 )
363373 return llvm::None;
364374 auto *S1 = NAry->getOperand (0 );
365375 auto *S2 = NAry->getOperand (1 );
376+ if (isa<SCEVSignExtendExpr>(S1) || isa<SCEVZeroExtendExpr>(S1))
377+ S1 = GetInnerSCEV (S1);
378+ if (isa<SCEVSignExtendExpr>(S2) || isa<SCEVZeroExtendExpr>(S2))
379+ S2 = GetInnerSCEV (S2);
366380 auto T1 = static_cast <SCEVTypes>(S1->getSCEVType ());
367381 auto T2 = static_cast <SCEVTypes>(S2->getSCEVType ());
368382 int64_t Constant = 0 ;
@@ -380,6 +394,13 @@ parseBoundExpression(const llvm::SCEV *S) {
380394 return std::make_pair (Variable, Constant);
381395}
382396
397+ void shiftBounds (AssumptionBounds &AB, int64_t Offset) {
398+ if (AB.Lower )
399+ AB.Lower = *AB.Lower + Offset;
400+ if (AB.Upper )
401+ AB.Upper = *AB.Upper + Offset;
402+ }
403+
383404inline std::function<Dimension& (llvm::SmallVectorImpl<MemoryLocationRange> *)>
384405getGrowFunction (const MemoryLocationRange &LeftRange, std::size_t DimIdx) {
385406 assert (DimIdx < LeftRange.DimList .size () &&
@@ -485,8 +506,10 @@ IntersectionResult processOneStartOtherEndConst(
485506 auto BNItr = AM->find (BN->first );
486507 if (BPItr == AM->end () || BNItr == AM->end ())
487508 return Info.UnknownIntersection ;
488- auto &BoundsP = BPItr->second ;
489- auto &BoundsN = BNItr->second ;
509+ auto BoundsP = BPItr->second ;
510+ auto BoundsN = BNItr->second ;
511+ shiftBounds (BoundsP, BP->second );
512+ shiftBounds (BoundsN, BN->second );
490513 auto MInt = cast<SCEVConstant>(M)->getAPInt ().getSExtValue ();
491514 auto QInt = cast<SCEVConstant>(Q)->getAPInt ().getSExtValue ();
492515 if (!BoundsN.Lower || !BoundsN.Upper || !BoundsP.Lower || !BoundsP.Upper )
@@ -633,8 +656,10 @@ IntersectionResult processBothStartConst(const MemoryLocationRange &LeftRange,
633656 auto BQItr = AM->find (BQ->first );
634657 if (BNItr == AM->end () || BQItr == AM->end ())
635658 return Info.UnknownIntersection ;
636- auto &BoundsN = BNItr->second ;
637- auto &BoundsQ = BQItr->second ;
659+ auto BoundsN = BNItr->second ;
660+ auto BoundsQ = BQItr->second ;
661+ shiftBounds (BoundsN, BN->second );
662+ shiftBounds (BoundsQ, BN->second );
638663 if (BoundsN.Upper && PInt > *BoundsN.Upper ) {
639664 return Info.EmptyIntersection ;
640665 } else if (BoundsN.Lower && PInt <= *BoundsN.Lower ) {
@@ -729,8 +754,10 @@ IntersectionResult processBothEndConst(const MemoryLocationRange &LeftRange,
729754 auto BPItr = AM->find (BP->first );
730755 if (BMItr == AM->end () || BPItr == AM->end ())
731756 return Info.UnknownIntersection ;
732- auto &BoundsM = BMItr->second ;
733- auto &BoundsP = BPItr->second ;
757+ auto BoundsM = BMItr->second ;
758+ auto BoundsP = BPItr->second ;
759+ shiftBounds (BoundsM, BM->second );
760+ shiftBounds (BoundsP, BP->second );
734761 if (BoundsM.Lower && QInt < *BoundsM.Lower ) {
735762 return Info.EmptyIntersection ;
736763 } else if (BoundsM.Upper && QInt >= *BoundsM.Upper ) {
@@ -795,7 +822,8 @@ IntersectionResult processOneConstOtherSemiconst(
795822 auto BQItr = AM->find (BQ->first );
796823 if (BQItr == AM->end ())
797824 return Info.UnknownIntersection ;
798- auto &BoundsQ = BQItr->second ;
825+ auto BoundsQ = BQItr->second ;
826+ shiftBounds (BoundsQ, BQ->second );
799827 if (BoundsQ.Lower && NInt < *BoundsQ.Lower ) {
800828 // [max(m, p), n]
801829 Intersection.Start = MInt > PInt ? M : P;
@@ -858,7 +886,8 @@ IntersectionResult processOneConstOtherSemiconst(
858886 auto BPItr = AM->find (BP->first );
859887 if (BPItr == AM->end ())
860888 return Info.UnknownIntersection ;
861- auto &BoundsP = BPItr->second ;
889+ auto BoundsP = BPItr->second ;
890+ shiftBounds (BoundsP, BP->second );
862891 if (BoundsP.Upper && MInt > *BoundsP.Upper ) {
863892 // [m, min(n, q)]
864893 Intersection.Start = M;
@@ -961,8 +990,10 @@ IntersectionResult processOneVariableOtherSemiconst(
961990 auto BNItr = AM->find (BN->first );
962991 if (BMItr == AM->end () || BNItr == AM->end ())
963992 return Info.UnknownIntersection ;
964- auto &BoundsM = BMItr->second ;
965- auto &BoundsN = BNItr->second ;
993+ auto BoundsM = BMItr->second ;
994+ auto BoundsN = BNItr->second ;
995+ shiftBounds (BoundsM, BM->second );
996+ shiftBounds (BoundsN, BN->second );
966997 if (isa<SCEVConstant>(P)) {
967998 auto PInt = cast<SCEVConstant>(P)->getAPInt ().getSExtValue ();
968999 auto BQ = parseBoundExpression (Q);
@@ -971,7 +1002,8 @@ IntersectionResult processOneVariableOtherSemiconst(
9711002 auto BQItr = AM->find (BQ->first );
9721003 if (BQItr == AM->end ())
9731004 return Info.UnknownIntersection ;
974- auto &BoundsQ = BQItr->second ;
1005+ auto BoundsQ = BQItr->second ;
1006+ shiftBounds (BoundsQ, BQ->second );
9751007 if (BoundsN.Upper && *BoundsN.Upper < PInt ||
9761008 BoundsQ.Upper && BoundsM.Lower && *BoundsQ.Upper < *BoundsM.Lower ||
9771009 CmpMQ && *CmpMQ > 0 )
@@ -1046,6 +1078,21 @@ IntersectionResult processOneVariableOtherSemiconst(
10461078 Dim2.Start = addOneToSCEV (N, SE);
10471079 Dim2.End = Q;
10481080 }
1081+ } else if (CmpNQ && *CmpNQ <= 0 ) {
1082+ Intersection.Start = M;
1083+ Intersection.End = N;
1084+ if (SC) {
1085+ // [p, M-1]
1086+ auto &Dim1 = Grow (SC);
1087+ Dim1.Start = P;
1088+ Dim1.End = subtractOneFromSCEV (M, SE);
1089+ if (*CmpNQ < 0 ) {
1090+ // [N+1, Q]
1091+ auto &Dim2 = Grow (SC);
1092+ Dim2.Start = addOneToSCEV (N, SE);
1093+ Dim2.End = Q;
1094+ }
1095+ }
10491096 } else {
10501097 return Info.UnknownIntersection ;
10511098 }
@@ -1060,7 +1107,8 @@ IntersectionResult processOneVariableOtherSemiconst(
10601107 auto BPItr = AM->find (BP->first );
10611108 if (BPItr == AM->end ())
10621109 return Info.UnknownIntersection ;
1063- auto &BoundsP = BPItr->second ;
1110+ auto BoundsP = BPItr->second ;
1111+ shiftBounds (BoundsP, BP->second );
10641112 if (BoundsM.Lower && *BoundsM.Lower > QInt ||
10651113 BoundsP.Lower && BoundsN.Upper && *BoundsP.Lower > *BoundsN.Upper ||
10661114 CmpNP && *CmpNP < 0 )
@@ -1175,8 +1223,10 @@ IntersectionResult processOneConstOtherVariable(
11751223 auto BQItr = AM->find (BQ->first );
11761224 if (BPItr == AM->end () || BQItr == AM->end ())
11771225 return Info.UnknownIntersection ;
1178- auto &BoundsP = BPItr->second ;
1179- auto &BoundsQ = BQItr->second ;
1226+ auto BoundsP = BPItr->second ;
1227+ auto BoundsQ = BQItr->second ;
1228+ shiftBounds (BoundsP, BP->second );
1229+ shiftBounds (BoundsQ, BQ->second );
11801230 auto MInt = cast<SCEVConstant>(M)->getAPInt ().getSExtValue ();
11811231 auto NInt = cast<SCEVConstant>(N)->getAPInt ().getSExtValue ();
11821232 if (BoundsQ.Upper && *BoundsQ.Upper < MInt ||
@@ -1364,8 +1414,9 @@ llvm::Optional<MemoryLocationRange> intersect(
13641414 LHS.DimList == RHS.DimList )
13651415 return LHS;
13661416 MemoryLocationRange Int (LHS);
1367- assert (LHS.SE && RHS.SE && LHS.SE == RHS.SE &&
1368- " ScalarEvolution must be specified!" );
1417+ assert (LHS.SE && RHS.SE && " ScalarEvolution must be specified!" );
1418+ assert (LHS.SE == RHS.SE
1419+ && " ScalarEvolution must be common for memory locations!" );
13691420 auto SE = LHS.SE ;
13701421 for (std::size_t I = 0 ; I < LHS.DimList .size (); ++I) {
13711422 auto &Left = LHS.DimList [I];
0 commit comments