Skip to content

Commit 9c08b91

Browse files
committed
[TSAR, Memory] Fix intersection calculation for collapsed memory locations.
1 parent 9d02f66 commit 9c08b91

File tree

1 file changed

+71
-20
lines changed

1 file changed

+71
-20
lines changed

lib/Analysis/Memory/MemoryLocationRange.cpp

Lines changed: 71 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ using namespace tsar;
3636
using namespace llvm;
3737

3838
#undef DEBUG_TYPE
39-
#define DEBUG_TYPE "def-mem"
39+
#define DEBUG_TYPE "inter-mem"
4040

4141
namespace {
4242
typedef int64_t ColumnT;
@@ -253,7 +253,7 @@ void delinearize(const MemoryLocationRange &From, MemoryLocationRange &What) {
253253
What.LowerBound = 0;
254254
What.UpperBound = ElemSize;
255255
What.DimList = std::move(DimList);
256-
What.Kind = LocKind::Collapsed | (What.Kind & LocKind::Hint);
256+
What.Kind = LocKind::Collapsed | (What.Kind & LocKind::Hint) | (What.Kind & LocKind::Auxiliary);
257257
What.SE = SE;
258258
What.AM = From.AM;
259259
}
@@ -355,14 +355,28 @@ struct IntersectVarInfo {
355355
/// returns llvm::None.
356356
llvm::Optional<std::pair<Value *, int64_t>>
357357
parseBoundExpression(const llvm::SCEV *S) {
358-
if (auto *Unknown = dyn_cast<SCEVUnknown>(S))
359-
return std::make_pair(Unknown->getValue(), int64_t(0));
358+
auto GetInnerSCEV = [](const SCEV *S) -> const SCEV* {
359+
while (auto Cast{dyn_cast<SCEVCastExpr>(S)})
360+
S = Cast->getOperand();
361+
return S;
362+
};
363+
if (!isa<SCEVNAryExpr>(S)) {
364+
if (isa<SCEVSignExtendExpr>(S) || isa<SCEVZeroExtendExpr>(S)) {
365+
return parseBoundExpression(GetInnerSCEV(S));
366+
}
367+
if (auto *Unknown = dyn_cast<SCEVUnknown>(S))
368+
return std::make_pair(Unknown->getValue(), int64_t(0));
369+
}
360370
auto *NAry = dyn_cast<SCEVNAryExpr>(S);
361371
if (!NAry || NAry->getSCEVType() != llvm::SCEVTypes::scAddExpr ||
362372
NAry->getNumOperands() != 2)
363373
return llvm::None;
364374
auto *S1 = NAry->getOperand(0);
365375
auto *S2 = NAry->getOperand(1);
376+
if (isa<SCEVSignExtendExpr>(S1) || isa<SCEVZeroExtendExpr>(S1))
377+
S1 = GetInnerSCEV(S1);
378+
if (isa<SCEVSignExtendExpr>(S2) || isa<SCEVZeroExtendExpr>(S2))
379+
S2 = GetInnerSCEV(S2);
366380
auto T1 = static_cast<SCEVTypes>(S1->getSCEVType());
367381
auto T2 = static_cast<SCEVTypes>(S2->getSCEVType());
368382
int64_t Constant = 0;
@@ -380,6 +394,13 @@ parseBoundExpression(const llvm::SCEV *S) {
380394
return std::make_pair(Variable, Constant);
381395
}
382396

397+
void shiftBounds(AssumptionBounds &AB, int64_t Offset) {
398+
if (AB.Lower)
399+
AB.Lower = *AB.Lower + Offset;
400+
if (AB.Upper)
401+
AB.Upper = *AB.Upper + Offset;
402+
}
403+
383404
inline std::function<Dimension& (llvm::SmallVectorImpl<MemoryLocationRange> *)>
384405
getGrowFunction(const MemoryLocationRange &LeftRange, std::size_t DimIdx) {
385406
assert(DimIdx < LeftRange.DimList.size() &&
@@ -485,8 +506,10 @@ IntersectionResult processOneStartOtherEndConst(
485506
auto BNItr = AM->find(BN->first);
486507
if (BPItr == AM->end() || BNItr == AM->end())
487508
return Info.UnknownIntersection;
488-
auto &BoundsP = BPItr->second;
489-
auto &BoundsN = BNItr->second;
509+
auto BoundsP = BPItr->second;
510+
auto BoundsN = BNItr->second;
511+
shiftBounds(BoundsP, BP->second);
512+
shiftBounds(BoundsN, BN->second);
490513
auto MInt = cast<SCEVConstant>(M)->getAPInt().getSExtValue();
491514
auto QInt = cast<SCEVConstant>(Q)->getAPInt().getSExtValue();
492515
if (!BoundsN.Lower || !BoundsN.Upper || !BoundsP.Lower || !BoundsP.Upper)
@@ -633,8 +656,10 @@ IntersectionResult processBothStartConst(const MemoryLocationRange &LeftRange,
633656
auto BQItr = AM->find(BQ->first);
634657
if (BNItr == AM->end() || BQItr == AM->end())
635658
return Info.UnknownIntersection;
636-
auto &BoundsN = BNItr->second;
637-
auto &BoundsQ = BQItr->second;
659+
auto BoundsN = BNItr->second;
660+
auto BoundsQ = BQItr->second;
661+
shiftBounds(BoundsN, BN->second);
662+
shiftBounds(BoundsQ, BN->second);
638663
if (BoundsN.Upper && PInt > *BoundsN.Upper) {
639664
return Info.EmptyIntersection;
640665
} else if (BoundsN.Lower && PInt <= *BoundsN.Lower) {
@@ -729,8 +754,10 @@ IntersectionResult processBothEndConst(const MemoryLocationRange &LeftRange,
729754
auto BPItr = AM->find(BP->first);
730755
if (BMItr == AM->end() || BPItr == AM->end())
731756
return Info.UnknownIntersection;
732-
auto &BoundsM = BMItr->second;
733-
auto &BoundsP = BPItr->second;
757+
auto BoundsM = BMItr->second;
758+
auto BoundsP = BPItr->second;
759+
shiftBounds(BoundsM, BM->second);
760+
shiftBounds(BoundsP, BP->second);
734761
if (BoundsM.Lower && QInt < *BoundsM.Lower) {
735762
return Info.EmptyIntersection;
736763
} else if (BoundsM.Upper && QInt >= *BoundsM.Upper) {
@@ -795,7 +822,8 @@ IntersectionResult processOneConstOtherSemiconst(
795822
auto BQItr = AM->find(BQ->first);
796823
if (BQItr == AM->end())
797824
return Info.UnknownIntersection;
798-
auto &BoundsQ = BQItr->second;
825+
auto BoundsQ = BQItr->second;
826+
shiftBounds(BoundsQ, BQ->second);
799827
if (BoundsQ.Lower && NInt < *BoundsQ.Lower) {
800828
// [max(m, p), n]
801829
Intersection.Start = MInt > PInt ? M : P;
@@ -858,7 +886,8 @@ IntersectionResult processOneConstOtherSemiconst(
858886
auto BPItr = AM->find(BP->first);
859887
if (BPItr == AM->end())
860888
return Info.UnknownIntersection;
861-
auto &BoundsP = BPItr->second;
889+
auto BoundsP = BPItr->second;
890+
shiftBounds(BoundsP, BP->second);
862891
if (BoundsP.Upper && MInt > *BoundsP.Upper) {
863892
// [m, min(n, q)]
864893
Intersection.Start = M;
@@ -961,8 +990,10 @@ IntersectionResult processOneVariableOtherSemiconst(
961990
auto BNItr = AM->find(BN->first);
962991
if (BMItr == AM->end() || BNItr == AM->end())
963992
return Info.UnknownIntersection;
964-
auto &BoundsM = BMItr->second;
965-
auto &BoundsN = BNItr->second;
993+
auto BoundsM = BMItr->second;
994+
auto BoundsN = BNItr->second;
995+
shiftBounds(BoundsM, BM->second);
996+
shiftBounds(BoundsN, BN->second);
966997
if (isa<SCEVConstant>(P)) {
967998
auto PInt = cast<SCEVConstant>(P)->getAPInt().getSExtValue();
968999
auto BQ = parseBoundExpression(Q);
@@ -971,7 +1002,8 @@ IntersectionResult processOneVariableOtherSemiconst(
9711002
auto BQItr = AM->find(BQ->first);
9721003
if (BQItr == AM->end())
9731004
return Info.UnknownIntersection;
974-
auto &BoundsQ = BQItr->second;
1005+
auto BoundsQ = BQItr->second;
1006+
shiftBounds(BoundsQ, BQ->second);
9751007
if (BoundsN.Upper && *BoundsN.Upper < PInt ||
9761008
BoundsQ.Upper && BoundsM.Lower && *BoundsQ.Upper < *BoundsM.Lower ||
9771009
CmpMQ && *CmpMQ > 0)
@@ -1046,6 +1078,21 @@ IntersectionResult processOneVariableOtherSemiconst(
10461078
Dim2.Start = addOneToSCEV(N, SE);
10471079
Dim2.End = Q;
10481080
}
1081+
} else if (CmpNQ && *CmpNQ <= 0) {
1082+
Intersection.Start = M;
1083+
Intersection.End = N;
1084+
if (SC) {
1085+
// [p, M-1]
1086+
auto &Dim1 = Grow(SC);
1087+
Dim1.Start = P;
1088+
Dim1.End = subtractOneFromSCEV(M, SE);
1089+
if (*CmpNQ < 0) {
1090+
// [N+1, Q]
1091+
auto &Dim2 = Grow(SC);
1092+
Dim2.Start = addOneToSCEV(N, SE);
1093+
Dim2.End = Q;
1094+
}
1095+
}
10491096
} else {
10501097
return Info.UnknownIntersection;
10511098
}
@@ -1060,7 +1107,8 @@ IntersectionResult processOneVariableOtherSemiconst(
10601107
auto BPItr = AM->find(BP->first);
10611108
if (BPItr == AM->end())
10621109
return Info.UnknownIntersection;
1063-
auto &BoundsP = BPItr->second;
1110+
auto BoundsP = BPItr->second;
1111+
shiftBounds(BoundsP, BP->second);
10641112
if (BoundsM.Lower && *BoundsM.Lower > QInt ||
10651113
BoundsP.Lower && BoundsN.Upper && *BoundsP.Lower > *BoundsN.Upper ||
10661114
CmpNP && *CmpNP < 0)
@@ -1175,8 +1223,10 @@ IntersectionResult processOneConstOtherVariable(
11751223
auto BQItr = AM->find(BQ->first);
11761224
if (BPItr == AM->end() || BQItr == AM->end())
11771225
return Info.UnknownIntersection;
1178-
auto &BoundsP = BPItr->second;
1179-
auto &BoundsQ = BQItr->second;
1226+
auto BoundsP = BPItr->second;
1227+
auto BoundsQ = BQItr->second;
1228+
shiftBounds(BoundsP, BP->second);
1229+
shiftBounds(BoundsQ, BQ->second);
11801230
auto MInt = cast<SCEVConstant>(M)->getAPInt().getSExtValue();
11811231
auto NInt = cast<SCEVConstant>(N)->getAPInt().getSExtValue();
11821232
if (BoundsQ.Upper && *BoundsQ.Upper < MInt ||
@@ -1364,8 +1414,9 @@ llvm::Optional<MemoryLocationRange> intersect(
13641414
LHS.DimList == RHS.DimList)
13651415
return LHS;
13661416
MemoryLocationRange Int(LHS);
1367-
assert(LHS.SE && RHS.SE && LHS.SE == RHS.SE &&
1368-
"ScalarEvolution must be specified!");
1417+
assert(LHS.SE && RHS.SE && "ScalarEvolution must be specified!");
1418+
assert(LHS.SE == RHS.SE
1419+
&& "ScalarEvolution must be common for memory locations!");
13691420
auto SE = LHS.SE;
13701421
for (std::size_t I = 0; I < LHS.DimList.size(); ++I) {
13711422
auto &Left = LHS.DimList[I];

0 commit comments

Comments
 (0)