|
47 | 47 | #include "llvm/Transforms/Utils/ModuleUtils.h" |
48 | 48 | #include <cstdint> |
49 | 49 | #include <optional> |
| 50 | +#include "llvm/Analysis/ScalarEvolutionExpressions.h" |
| 51 | +#include "llvm/Analysis/ScalarEvolution.h" |
50 | 52 |
|
51 | 53 | using namespace llvm; |
52 | 54 |
|
@@ -168,12 +170,12 @@ class GPUSanImpl final { |
168 | 170 | PtrOrigin PO); |
169 | 171 | Value *instrumentAllocaInst(LoopInfo &LI, AllocaInst &AI); |
170 | 172 | void instrumentAccess(LoopInfo &LI, Instruction &I, int PtrIdx, |
171 | | - Type &AccessTy, bool IsRead); |
| 173 | + Type &AccessTy, bool IsRead, SmallVector<GetElementPtrInst *> &GEPs); |
172 | 174 | void instrumentMultipleAccessPerBasicBlock( |
173 | 175 | LoopInfo &LI, |
174 | 176 | SmallVector<Instruction *> &AccessCausingInstructionInABasicBlock); |
175 | | - void instrumentLoadInst(LoopInfo &LI, LoadInst &LoadI); |
176 | | - void instrumentStoreInst(LoopInfo &LI, StoreInst &StoreI); |
| 177 | + void instrumentLoadInst(LoopInfo &LI, LoadInst &LoadI, SmallVector<GetElementPtrInst *> &GEPs); |
| 178 | + void instrumentStoreInst(LoopInfo &LI, StoreInst &StoreI, SmallVector<GetElementPtrInst *> &GEPs); |
177 | 179 | void instrumentGEPInst(LoopInfo &LI, GetElementPtrInst &GEP); |
178 | 180 | bool instrumentCallInst(LoopInfo &LI, CallInst &CI); |
179 | 181 | void |
@@ -914,7 +916,7 @@ Value *GPUSanImpl::replaceUserGlobals(IRBuilder<> &IRB, |
914 | 916 | } |
915 | 917 |
|
916 | 918 | void GPUSanImpl::instrumentAccess(LoopInfo &LI, Instruction &I, int PtrIdx, |
917 | | - Type &AccessTy, bool IsRead) { |
| 919 | + Type &AccessTy, bool IsRead, SmallVector<GetElementPtrInst *> &GEPs) { |
918 | 920 | Value *PtrOp = I.getOperand(PtrIdx); |
919 | 921 | const Value *Object = nullptr; |
920 | 922 | PtrOrigin PO = getPtrOrigin(LI, PtrOp, &Object); |
@@ -943,35 +945,159 @@ void GPUSanImpl::instrumentAccess(LoopInfo &LI, Instruction &I, int PtrIdx, |
943 | 945 | } |
944 | 946 |
|
945 | 947 | if (Loop *L = LI.getLoopFor(I.getParent())) { |
946 | | - auto &SE = FAM.getResult<ScalarEvolutionAnalysis>(*I.getFunction()); |
947 | | - const auto &LD = SE.getLoopDisposition(SE.getSCEVAtScope(PtrOp, L), L); |
948 | | - } |
| 948 | + auto &SE = FAM.getResult<ScalarEvolutionAnalysis>(*I.getFunction()); |
| 949 | + auto *PtrOpScev = SE.getSCEVAtScope(PtrOp, L); |
| 950 | + const auto &LD = SE.getLoopDisposition(PtrOpScev, L); |
| 951 | + SmallVector<const SCEVPredicate *, 4> Preds; |
| 952 | + SmallPtrSet< const SCEVPredicate *, 4> PredsSet; |
| 953 | + for (auto *Pred : Preds) |
| 954 | + PredsSet.insert(Pred); |
| 955 | + auto *Ex = SE.getPredicatedBackedgeTakenCount(L, Preds); |
| 956 | + |
| 957 | + errs() << "Loop Disposition: " << LD << "\n"; |
| 958 | + errs() << "ABS Expression: " << SE.getSmallConstantTripCount(L) << "\n"; |
| 959 | + const SCEVAddRecExpr *AR = SE.convertSCEVToAddRecWithPredicates(PtrOpScev, L, PredsSet); |
| 960 | + |
| 961 | + const SCEV *ScStart = AR->getStart(); |
| 962 | + const SCEV *ScEnd = AR->evaluateAtIteration(Ex, SE); |
| 963 | + const SCEV *Step = AR->getStepRecurrence(SE); |
| 964 | + |
| 965 | + // For expressions with negative step, the upper bound is ScStart and the |
| 966 | + // lower bound is ScEnd. |
| 967 | + if (const SCEVConstant *CStep = dyn_cast<const SCEVConstant>(Step)) { |
| 968 | + if (CStep->getValue()->isNegative()) |
| 969 | + std::swap(ScStart, ScEnd); |
| 970 | + } else { |
| 971 | + // Fallback case: the step is not constant, but we can still |
| 972 | + // get the upper and lower bounds of the interval by using min/max |
| 973 | + // expressions. |
| 974 | + ScStart = SE.getUMinExpr(ScStart, ScEnd); |
| 975 | + ScEnd = SE.getUMaxExpr(AR->getStart(), ScEnd); |
| 976 | + } |
| 977 | + |
| 978 | + errs() << "SC step: " << *Step << "\n"; |
| 979 | + errs() << "Sc start: " << *ScStart << "\n"; |
| 980 | + errs() << "Sc end: " << *ScEnd << "\n"; |
| 981 | + ScEnd->print(errs()); |
| 982 | + errs() << "\n"; |
| 983 | + ScEnd->dump(); |
| 984 | + errs() << "\n"; |
| 985 | + |
| 986 | + ArrayRef< const SCEV * > Ops = ScEnd->operands(); |
| 987 | + errs() << "\n"; |
| 988 | + for (auto *Op : Ops){ |
| 989 | + errs() << "Operand: " << *Op << "\n"; |
| 990 | + errs() << "Operand Scev Type: " << Op->getSCEVType() << "\n"; |
| 991 | + errs() << "Operand Type: " << *Op->getType() << "\n"; |
| 992 | + } |
| 993 | + errs() << "\n"; |
| 994 | + |
| 995 | + errs() << "Scev Type: " << ScEnd->getSCEVType() << "\n"; |
| 996 | + errs() << "Type: " << *ScEnd->getType() << "\n"; |
| 997 | + errs() << "Is Non Constant Negative: " << ScEnd->isNonConstantNegative() << "\n"; |
| 998 | + errs() << "PtrOp: " << *PtrOp << "\n"; |
| 999 | + |
| 1000 | + if (Ops.size() == 2){ |
| 1001 | + const SCEV *First = Ops[0]; |
| 1002 | + // Ideally this would be obtained from the SCEV analysis, but converting from const to non-const seems to be an issue there. |
| 1003 | + GetElementPtrInst *PointerOpGEP = cast<GetElementPtrInst>(PtrOp); |
| 1004 | + Value *BasePointer = PointerOpGEP->getPointerOperand(); |
| 1005 | + |
| 1006 | + errs() << "Print Base Pointer Op: " << *BasePointer << "\n"; |
| 1007 | + |
| 1008 | + const SCEVConstant *SC = dyn_cast<SCEVConstant>(First); |
| 1009 | + ConstantInt *OffsetValue = SC->getValue(); |
| 1010 | + errs() << "Constant Int value " << *OffsetValue << "\n"; |
| 1011 | + |
| 1012 | + //Create GEP |
| 1013 | + Value *GEPOutsideBB= IRB.CreateGEP(BasePointer->getType(), BasePointer, {OffsetValue}); |
| 1014 | + |
| 1015 | + GetElementPtrInst* GEPInst = dyn_cast<GetElementPtrInst>(GEPOutsideBB); |
| 1016 | + GEPs.push_back(GEPInst); |
| 1017 | + Instruction *BasePointerInst = dyn_cast<Instruction>(BasePointer); |
| 1018 | + |
| 1019 | + GetElementPtrInst *GEPToRemove = dyn_cast<GetElementPtrInst>(PtrOp); |
| 1020 | + auto It = std::find(GEPs.begin(), GEPs.end(), GEPToRemove); |
| 1021 | + if (It != GEPs.end()){ |
| 1022 | + GEPs.erase(It); |
| 1023 | + } |
| 1024 | + |
| 1025 | + GEPInst->removeFromParent(); |
| 1026 | + auto *BB= BasePointerInst->getParent(); |
| 1027 | + auto Terminator = BB->end(); |
| 1028 | + GEPInst->insertInto(BB, --Terminator); |
| 1029 | + |
| 1030 | + static int32_t ReadAccessId = -1; |
| 1031 | + static int32_t WriteAccessId = 1; |
| 1032 | + const int32_t &AccessId = IsRead ? ReadAccessId-- : WriteAccessId++; |
| 1033 | + |
| 1034 | + auto TySize = DL.getTypeStoreSize(&AccessTy); |
| 1035 | + assert(!TySize.isScalable()); |
| 1036 | + Value *Size = ConstantInt::get(Int64Ty, TySize.getFixedValue()); |
| 1037 | + |
| 1038 | + Value *PlainPtrOp = IRB.CreatePointerBitCastOrAddrSpaceCast(GEPInst, getPtrTy(PO)); |
| 1039 | + |
| 1040 | + errs() << "Print Plain Ptr Op: " << *PlainPtrOp << "\n"; |
| 1041 | + |
| 1042 | + Instruction *OpInst = dyn_cast<Instruction>(PlainPtrOp); |
| 1043 | + CallInst *CB; |
| 1044 | + Value *PCVal = getPC(IRB); |
| 1045 | + Instruction *PCInst = dyn_cast<Instruction>(PCVal); |
| 1046 | + PCInst->removeFromParent(); |
| 1047 | + PCInst->insertBefore(OpInst); |
| 1048 | + if (Start) { |
| 1049 | + CB = createCall(IRB, getCheckWithBaseFn(PO), |
| 1050 | + {PlainPtrOp, Start, Length, Tag, Size, |
| 1051 | + ConstantInt::get(Int64Ty, AccessId), getSourceIndex(I), |
| 1052 | + PCInst}, |
| 1053 | + I.getName() + ".san"); |
| 1054 | + } else { |
| 1055 | + CB = createCall(IRB, getCheckFn(PO), |
| 1056 | + {PlainPtrOp, Size, ConstantInt::get(Int64Ty, AccessId), |
| 1057 | + getSourceIndex(I), PCInst}, |
| 1058 | + I.getName() + ".san"); |
| 1059 | + } |
| 1060 | + |
| 1061 | + CB->removeFromParent(); |
| 1062 | + CB->insertAfter(OpInst); |
| 1063 | + |
| 1064 | + // I.setOperand(PtrIdx, |
| 1065 | + // IRB.CreatePointerBitCastOrAddrSpaceCast(CB, PtrOp->getType())); |
949 | 1066 |
|
950 | | - static int32_t ReadAccessId = -1; |
951 | | - static int32_t WriteAccessId = 1; |
952 | | - const int32_t &AccessId = IsRead ? ReadAccessId-- : WriteAccessId++; |
| 1067 | + } |
| 1068 | + |
| 1069 | + } |
| 1070 | + else{ |
953 | 1071 |
|
954 | | - auto TySize = DL.getTypeStoreSize(&AccessTy); |
955 | | - assert(!TySize.isScalable()); |
956 | | - Value *Size = ConstantInt::get(Int64Ty, TySize.getFixedValue()); |
| 1072 | + static int32_t ReadAccessId = -1; |
| 1073 | + static int32_t WriteAccessId = 1; |
| 1074 | + const int32_t &AccessId = IsRead ? ReadAccessId-- : WriteAccessId++; |
957 | 1075 |
|
958 | | - Value *PlainPtrOp = |
959 | | - IRB.CreatePointerBitCastOrAddrSpaceCast(PtrOp, getPtrTy(PO)); |
960 | | - CallInst *CB; |
961 | | - if (Start) { |
962 | | - CB = createCall(IRB, getCheckWithBaseFn(PO), |
| 1076 | + auto TySize = DL.getTypeStoreSize(&AccessTy); |
| 1077 | + assert(!TySize.isScalable()); |
| 1078 | + Value *Size = ConstantInt::get(Int64Ty, TySize.getFixedValue()); |
| 1079 | + |
| 1080 | + Value *PlainPtrOp = IRB.CreatePointerBitCastOrAddrSpaceCast(PtrOp, getPtrTy(PO)); |
| 1081 | + |
| 1082 | + errs() << "Print Plain Ptr Op: " << *PlainPtrOp << "\n"; |
| 1083 | + |
| 1084 | + CallInst *CB; |
| 1085 | + if (Start) { |
| 1086 | + CB = createCall(IRB, getCheckWithBaseFn(PO), |
963 | 1087 | {PlainPtrOp, Start, Length, Tag, Size, |
964 | 1088 | ConstantInt::get(Int64Ty, AccessId), getSourceIndex(I), |
965 | 1089 | getPC(IRB)}, |
966 | 1090 | I.getName() + ".san"); |
967 | | - } else { |
968 | | - CB = createCall(IRB, getCheckFn(PO), |
| 1091 | + } else { |
| 1092 | + CB = createCall(IRB, getCheckFn(PO), |
969 | 1093 | {PlainPtrOp, Size, ConstantInt::get(Int64Ty, AccessId), |
970 | 1094 | getSourceIndex(I), getPC(IRB)}, |
971 | 1095 | I.getName() + ".san"); |
972 | | - } |
973 | | - I.setOperand(PtrIdx, |
| 1096 | + } |
| 1097 | + |
| 1098 | + I.setOperand(PtrIdx, |
974 | 1099 | IRB.CreatePointerBitCastOrAddrSpaceCast(CB, PtrOp->getType())); |
| 1100 | + } |
975 | 1101 | } |
976 | 1102 |
|
977 | 1103 | void GPUSanImpl::instrumentMultipleAccessPerBasicBlock( |
@@ -1039,7 +1165,57 @@ void GPUSanImpl::instrumentMultipleAccessPerBasicBlock( |
1039 | 1165 |
|
1040 | 1166 | if (Loop *L = LI.getLoopFor(I->getParent())) { |
1041 | 1167 | auto &SE = FAM.getResult<ScalarEvolutionAnalysis>(*I->getFunction()); |
1042 | | - const auto &LD = SE.getLoopDisposition(SE.getSCEVAtScope(PtrOp, L), L); |
| 1168 | + auto *PtrOpScev = SE.getSCEVAtScope(PtrOp, L); |
| 1169 | + const auto &LD = SE.getLoopDisposition(PtrOpScev, L); |
| 1170 | + SmallVector<const SCEVPredicate *, 4> Preds; |
| 1171 | + SmallPtrSet< const SCEVPredicate *, 4> PredsSet; |
| 1172 | + for (auto *Pred : Preds) |
| 1173 | + PredsSet.insert(Pred); |
| 1174 | + auto *Ex = SE.getPredicatedBackedgeTakenCount(L, Preds); |
| 1175 | + |
| 1176 | + errs() << "Loop Disposition: " << LD << "\n"; |
| 1177 | + errs() << "ABS Expression: " << SE.getSmallConstantTripCount(L) << "\n"; |
| 1178 | + const SCEVAddRecExpr *AR = SE.convertSCEVToAddRecWithPredicates(PtrOpScev, L, PredsSet); |
| 1179 | + |
| 1180 | + const SCEV *ScStart = AR->getStart(); |
| 1181 | + const SCEV *ScEnd = AR->evaluateAtIteration(Ex, SE); |
| 1182 | + const SCEV *Step = AR->getStepRecurrence(SE); |
| 1183 | + |
| 1184 | + // For expressions with negative step, the upper bound is ScStart and the |
| 1185 | + // lower bound is ScEnd. |
| 1186 | + if (const SCEVConstant *CStep = dyn_cast<const SCEVConstant>(Step)) { |
| 1187 | + if (CStep->getValue()->isNegative()) |
| 1188 | + std::swap(ScStart, ScEnd); |
| 1189 | + } else { |
| 1190 | + // Fallback case: the step is not constant, but we can still |
| 1191 | + // get the upper and lower bounds of the interval by using min/max |
| 1192 | + // expressions. |
| 1193 | + ScStart = SE.getUMinExpr(ScStart, ScEnd); |
| 1194 | + ScEnd = SE.getUMaxExpr(AR->getStart(), ScEnd); |
| 1195 | + } |
| 1196 | + |
| 1197 | + errs() << "SC step: " << *Step << "\n"; |
| 1198 | + errs() << "Sc start: " << *ScStart << "\n"; |
| 1199 | + errs() << "Sc end: " << *ScEnd << "\n"; |
| 1200 | + ScEnd->print(errs()); |
| 1201 | + errs() << "\n"; |
| 1202 | + ScEnd->dump(); |
| 1203 | + errs() << "\n"; |
| 1204 | + |
| 1205 | + ArrayRef< const SCEV * > Ops = ScEnd->operands(); |
| 1206 | + errs() << "\n"; |
| 1207 | + for (auto *Op : Ops){ |
| 1208 | + errs() << "Operand: " << *Op << "\n"; |
| 1209 | + errs() << "Operand Scev Type: " << Op->getSCEVType() << "\n"; |
| 1210 | + errs() << "Operand Type: " << *Op->getType() << "\n"; |
| 1211 | + } |
| 1212 | + errs() << "\n"; |
| 1213 | + |
| 1214 | + errs() << "Scev Type: " << ScEnd->getSCEVType() << "\n"; |
| 1215 | + errs() << "Type: " << *ScEnd->getType() << "\n"; |
| 1216 | + errs() << "Is Non Constant Negative: " << ScEnd->isNonConstantNegative() << "\n"; |
| 1217 | + errs() << "PtrOp: " << *PtrOp << "\n"; |
| 1218 | + |
1043 | 1219 | } |
1044 | 1220 |
|
1045 | 1221 | static int32_t ReadAccessId = -1; |
@@ -1269,15 +1445,15 @@ void GPUSanImpl::instrumentMultipleAccessPerBasicBlock( |
1269 | 1445 | } |
1270 | 1446 | } |
1271 | 1447 |
|
1272 | | -void GPUSanImpl::instrumentLoadInst(LoopInfo &LI, LoadInst &LoadI) { |
| 1448 | +void GPUSanImpl::instrumentLoadInst(LoopInfo &LI, LoadInst &LoadI, SmallVector<GetElementPtrInst *> &GEPs) { |
1273 | 1449 | instrumentAccess(LI, LoadI, LoadInst::getPointerOperandIndex(), |
1274 | 1450 | *LoadI.getType(), |
1275 | | - /*IsRead=*/true); |
| 1451 | + /*IsRead=*/true, GEPs); |
1276 | 1452 | } |
1277 | 1453 |
|
1278 | | -void GPUSanImpl::instrumentStoreInst(LoopInfo &LI, StoreInst &StoreI) { |
| 1454 | +void GPUSanImpl::instrumentStoreInst(LoopInfo &LI, StoreInst &StoreI, SmallVector<GetElementPtrInst *> &GEPs) { |
1279 | 1455 | instrumentAccess(LI, StoreI, StoreInst::getPointerOperandIndex(), |
1280 | | - *StoreI.getValueOperand()->getType(), /*IsRead=*/false); |
| 1456 | + *StoreI.getValueOperand()->getType(), /*IsRead=*/false, GEPs); |
1281 | 1457 | } |
1282 | 1458 |
|
1283 | 1459 | void GPUSanImpl::instrumentGEPInst(LoopInfo &LI, GetElementPtrInst &GEP) { |
@@ -1389,49 +1565,49 @@ bool GPUSanImpl::instrumentFunction(Function &Fn) { |
1389 | 1565 | } |
1390 | 1566 |
|
1391 | 1567 | // Hoist all address computation in a basic block |
1392 | | - auto GEPCopy = GEPs; |
1393 | | - while (!GEPCopy.empty()) { |
1394 | | - auto *Inst = GEPCopy.pop_back_val(); |
1395 | | - Instruction *LatestDependency = &*Inst->getParent()->begin(); |
1396 | | - for (auto *It = Inst->op_begin(); It != Inst->op_end(); It++) { |
| 1568 | + // auto GEPCopy = GEPs; |
| 1569 | + // while (!GEPCopy.empty()) { |
| 1570 | + // auto *Inst = GEPCopy.pop_back_val(); |
| 1571 | + // Instruction *LatestDependency = &*Inst->getParent()->begin(); |
| 1572 | + // for (auto *It = Inst->op_begin(); It != Inst->op_end(); It++) { |
1397 | 1573 |
|
1398 | | - if (Instruction *ToInstruction = dyn_cast<Instruction>(It)) { |
| 1574 | + // if (Instruction *ToInstruction = dyn_cast<Instruction>(It)) { |
1399 | 1575 |
|
1400 | | - if (!LatestDependency) { |
1401 | | - LatestDependency = ToInstruction; |
1402 | | - continue; |
1403 | | - } |
| 1576 | + // if (!LatestDependency) { |
| 1577 | + // LatestDependency = ToInstruction; |
| 1578 | + // continue; |
| 1579 | + // } |
1404 | 1580 |
|
1405 | | - if (ToInstruction->getParent() != Inst->getParent()) |
1406 | | - continue; |
| 1581 | + // if (ToInstruction->getParent() != Inst->getParent()) |
| 1582 | + // continue; |
1407 | 1583 |
|
1408 | | - if (LatestDependency->comesBefore(ToInstruction)) |
1409 | | - LatestDependency = ToInstruction; |
1410 | | - } |
1411 | | - } |
| 1584 | + // if (LatestDependency->comesBefore(ToInstruction)) |
| 1585 | + // LatestDependency = ToInstruction; |
| 1586 | + // } |
| 1587 | + // } |
1412 | 1588 |
|
1413 | | - Inst->moveAfter(LatestDependency); |
1414 | | - } |
| 1589 | + // Inst->moveAfter(LatestDependency); |
| 1590 | + // } |
1415 | 1591 |
|
1416 | | - bool CanMergeChecks = true; |
1417 | | - for (auto *GEP : GEPs) { |
| 1592 | + // bool CanMergeChecks = true; |
| 1593 | + // for (auto *GEP : GEPs) { |
1418 | 1594 |
|
1419 | | - if (GEP->comesBefore(LoadsStores.front())) { |
1420 | | - CanMergeChecks = CanMergeChecks && true; |
1421 | | - } else { |
1422 | | - CanMergeChecks = CanMergeChecks && false; |
1423 | | - } |
1424 | | - } |
| 1595 | + // if (GEP->comesBefore(LoadsStores.front())) { |
| 1596 | + // CanMergeChecks = CanMergeChecks && true; |
| 1597 | + // } else { |
| 1598 | + // CanMergeChecks = CanMergeChecks && false; |
| 1599 | + // } |
| 1600 | + // } |
1425 | 1601 |
|
1426 | 1602 | // check if you can merge various pointer checks. |
1427 | | - if (CanMergeChecks) { |
1428 | | - instrumentMultipleAccessPerBasicBlock(LI, LoadsStores); |
1429 | | - } else { |
| 1603 | + //if (CanMergeChecks) { |
| 1604 | + // instrumentMultipleAccessPerBasicBlock(LI, LoadsStores); |
| 1605 | + //} else { |
1430 | 1606 | for (auto *Load : Loads) |
1431 | | - instrumentLoadInst(LI, *Load); |
| 1607 | + instrumentLoadInst(LI, *Load, GEPs); |
1432 | 1608 | for (auto *Store : Stores) |
1433 | | - instrumentStoreInst(LI, *Store); |
1434 | | - } |
| 1609 | + instrumentStoreInst(LI, *Store, GEPs); |
| 1610 | + //} |
1435 | 1611 |
|
1436 | 1612 | for (auto *GEP : GEPs) |
1437 | 1613 | instrumentGEPInst(LI, *GEP); |
|
0 commit comments