Skip to content

Commit 2c3efed

Browse files
author
Vidush Singhal
committed
save work
1 parent 86abd53 commit 2c3efed

File tree

1 file changed

+235
-59
lines changed

1 file changed

+235
-59
lines changed

llvm/lib/Transforms/Instrumentation/GPUSan.cpp

Lines changed: 235 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@
4747
#include "llvm/Transforms/Utils/ModuleUtils.h"
4848
#include <cstdint>
4949
#include <optional>
50+
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
51+
#include "llvm/Analysis/ScalarEvolution.h"
5052

5153
using namespace llvm;
5254

@@ -168,12 +170,12 @@ class GPUSanImpl final {
168170
PtrOrigin PO);
169171
Value *instrumentAllocaInst(LoopInfo &LI, AllocaInst &AI);
170172
void instrumentAccess(LoopInfo &LI, Instruction &I, int PtrIdx,
171-
Type &AccessTy, bool IsRead);
173+
Type &AccessTy, bool IsRead, SmallVector<GetElementPtrInst *> &GEPs);
172174
void instrumentMultipleAccessPerBasicBlock(
173175
LoopInfo &LI,
174176
SmallVector<Instruction *> &AccessCausingInstructionInABasicBlock);
175-
void instrumentLoadInst(LoopInfo &LI, LoadInst &LoadI);
176-
void instrumentStoreInst(LoopInfo &LI, StoreInst &StoreI);
177+
void instrumentLoadInst(LoopInfo &LI, LoadInst &LoadI, SmallVector<GetElementPtrInst *> &GEPs);
178+
void instrumentStoreInst(LoopInfo &LI, StoreInst &StoreI, SmallVector<GetElementPtrInst *> &GEPs);
177179
void instrumentGEPInst(LoopInfo &LI, GetElementPtrInst &GEP);
178180
bool instrumentCallInst(LoopInfo &LI, CallInst &CI);
179181
void
@@ -914,7 +916,7 @@ Value *GPUSanImpl::replaceUserGlobals(IRBuilder<> &IRB,
914916
}
915917

916918
void GPUSanImpl::instrumentAccess(LoopInfo &LI, Instruction &I, int PtrIdx,
917-
Type &AccessTy, bool IsRead) {
919+
Type &AccessTy, bool IsRead, SmallVector<GetElementPtrInst *> &GEPs) {
918920
Value *PtrOp = I.getOperand(PtrIdx);
919921
const Value *Object = nullptr;
920922
PtrOrigin PO = getPtrOrigin(LI, PtrOp, &Object);
@@ -943,35 +945,159 @@ void GPUSanImpl::instrumentAccess(LoopInfo &LI, Instruction &I, int PtrIdx,
943945
}
944946

945947
if (Loop *L = LI.getLoopFor(I.getParent())) {
946-
auto &SE = FAM.getResult<ScalarEvolutionAnalysis>(*I.getFunction());
947-
const auto &LD = SE.getLoopDisposition(SE.getSCEVAtScope(PtrOp, L), L);
948-
}
948+
auto &SE = FAM.getResult<ScalarEvolutionAnalysis>(*I.getFunction());
949+
auto *PtrOpScev = SE.getSCEVAtScope(PtrOp, L);
950+
const auto &LD = SE.getLoopDisposition(PtrOpScev, L);
951+
SmallVector<const SCEVPredicate *, 4> Preds;
952+
SmallPtrSet< const SCEVPredicate *, 4> PredsSet;
953+
for (auto *Pred : Preds)
954+
PredsSet.insert(Pred);
955+
auto *Ex = SE.getPredicatedBackedgeTakenCount(L, Preds);
956+
957+
errs() << "Loop Disposition: " << LD << "\n";
958+
errs() << "ABS Expression: " << SE.getSmallConstantTripCount(L) << "\n";
959+
const SCEVAddRecExpr *AR = SE.convertSCEVToAddRecWithPredicates(PtrOpScev, L, PredsSet);
960+
961+
const SCEV *ScStart = AR->getStart();
962+
const SCEV *ScEnd = AR->evaluateAtIteration(Ex, SE);
963+
const SCEV *Step = AR->getStepRecurrence(SE);
964+
965+
// // For expressions with negative step, the upper bound is ScStart and the
966+
// // lower bound is ScEnd.
967+
if (const SCEVConstant *CStep = dyn_cast<const SCEVConstant>(Step)) {
968+
if (CStep->getValue()->isNegative())
969+
std::swap(ScStart, ScEnd);
970+
} else {
971+
// Fallback case: the step is not constant, but we can still
972+
// get the upper and lower bounds of the interval by using min/max
973+
// expressions.
974+
ScStart = SE.getUMinExpr(ScStart, ScEnd);
975+
ScEnd = SE.getUMaxExpr(AR->getStart(), ScEnd);
976+
}
977+
978+
errs() << "SC step: " << *Step << "\n";
979+
errs() << "Sc start: " << *ScStart << "\n";
980+
errs() << "Sc end: " << *ScEnd << "\n";
981+
ScEnd->print(errs());
982+
errs() << "\n";
983+
ScEnd->dump();
984+
errs() << "\n";
985+
986+
ArrayRef< const SCEV * > Ops = ScEnd->operands();
987+
errs() << "\n";
988+
for (auto *Op : Ops){
989+
errs() << "Operand: " << *Op << "\n";
990+
errs() << "Operand Scev Type: " << Op->getSCEVType() << "\n";
991+
errs() << "Operand Type: " << *Op->getType() << "\n";
992+
}
993+
errs() << "\n";
994+
995+
errs() << "Scev Type: " << ScEnd->getSCEVType() << "\n";
996+
errs() << "Type: " << *ScEnd->getType() << "\n";
997+
errs() << "Is Non Constant Negative: " << ScEnd->isNonConstantNegative() << "\n";
998+
errs() << "PtrOp: " << *PtrOp << "\n";
999+
1000+
if (Ops.size() == 2){
1001+
const SCEV *First = Ops[0];
1002+
// Ideally the base pointer would be obtained from the SCEV analysis, but converting from const to non-const seems to be an issue there.
1003+
GetElementPtrInst *PointerOpGEP = cast<GetElementPtrInst>(PtrOp);
1004+
Value *BasePointer = PointerOpGEP->getPointerOperand();
1005+
1006+
errs() << "Print Base Pointer Op: " << *BasePointer << "\n";
1007+
1008+
const SCEVConstant *SC = dyn_cast<SCEVConstant>(First);
1009+
ConstantInt *OffsetValue = SC->getValue();
1010+
errs() << "Constant Int value " << *OffsetValue << "\n";
1011+
1012+
//Create GEP
1013+
Value *GEPOutsideBB= IRB.CreateGEP(BasePointer->getType(), BasePointer, {OffsetValue});
1014+
1015+
GetElementPtrInst* GEPInst = dyn_cast<GetElementPtrInst>(GEPOutsideBB);
1016+
GEPs.push_back(GEPInst);
1017+
Instruction *BasePointerInst = dyn_cast<Instruction>(BasePointer);
1018+
1019+
GetElementPtrInst *GEPToRemove = dyn_cast<GetElementPtrInst>(PtrOp);
1020+
auto It = std::find(GEPs.begin(), GEPs.end(), GEPToRemove);
1021+
if (It != GEPs.end()){
1022+
GEPs.erase(It);
1023+
}
1024+
1025+
GEPInst->removeFromParent();
1026+
auto *BB= BasePointerInst->getParent();
1027+
auto Terminator = BB->end();
1028+
GEPInst->insertInto(BB, --Terminator);
1029+
1030+
static int32_t ReadAccessId = -1;
1031+
static int32_t WriteAccessId = 1;
1032+
const int32_t &AccessId = IsRead ? ReadAccessId-- : WriteAccessId++;
1033+
1034+
auto TySize = DL.getTypeStoreSize(&AccessTy);
1035+
assert(!TySize.isScalable());
1036+
Value *Size = ConstantInt::get(Int64Ty, TySize.getFixedValue());
1037+
1038+
Value *PlainPtrOp = IRB.CreatePointerBitCastOrAddrSpaceCast(GEPInst, getPtrTy(PO));
1039+
1040+
errs() << "Print Plain Ptr Op: " << *PlainPtrOp << "\n";
1041+
1042+
Instruction *OpInst = dyn_cast<Instruction>(PlainPtrOp);
1043+
CallInst *CB;
1044+
Value *PCVal = getPC(IRB);
1045+
Instruction *PCInst = dyn_cast<Instruction>(PCVal);
1046+
PCInst->removeFromParent();
1047+
PCInst->insertBefore(OpInst);
1048+
if (Start) {
1049+
CB = createCall(IRB, getCheckWithBaseFn(PO),
1050+
{PlainPtrOp, Start, Length, Tag, Size,
1051+
ConstantInt::get(Int64Ty, AccessId), getSourceIndex(I),
1052+
PCInst},
1053+
I.getName() + ".san");
1054+
} else {
1055+
CB = createCall(IRB, getCheckFn(PO),
1056+
{PlainPtrOp, Size, ConstantInt::get(Int64Ty, AccessId),
1057+
getSourceIndex(I), PCInst},
1058+
I.getName() + ".san");
1059+
}
1060+
1061+
CB->removeFromParent();
1062+
CB->insertAfter(OpInst);
1063+
1064+
// I.setOperand(PtrIdx,
1065+
// IRB.CreatePointerBitCastOrAddrSpaceCast(CB, PtrOp->getType()));
9491066

950-
static int32_t ReadAccessId = -1;
951-
static int32_t WriteAccessId = 1;
952-
const int32_t &AccessId = IsRead ? ReadAccessId-- : WriteAccessId++;
1067+
}
1068+
1069+
}
1070+
else{
9531071

954-
auto TySize = DL.getTypeStoreSize(&AccessTy);
955-
assert(!TySize.isScalable());
956-
Value *Size = ConstantInt::get(Int64Ty, TySize.getFixedValue());
1072+
static int32_t ReadAccessId = -1;
1073+
static int32_t WriteAccessId = 1;
1074+
const int32_t &AccessId = IsRead ? ReadAccessId-- : WriteAccessId++;
9571075

958-
Value *PlainPtrOp =
959-
IRB.CreatePointerBitCastOrAddrSpaceCast(PtrOp, getPtrTy(PO));
960-
CallInst *CB;
961-
if (Start) {
962-
CB = createCall(IRB, getCheckWithBaseFn(PO),
1076+
auto TySize = DL.getTypeStoreSize(&AccessTy);
1077+
assert(!TySize.isScalable());
1078+
Value *Size = ConstantInt::get(Int64Ty, TySize.getFixedValue());
1079+
1080+
Value *PlainPtrOp = IRB.CreatePointerBitCastOrAddrSpaceCast(PtrOp, getPtrTy(PO));
1081+
1082+
errs() << "Print Plain Ptr Op: " << *PlainPtrOp << "\n";
1083+
1084+
CallInst *CB;
1085+
if (Start) {
1086+
CB = createCall(IRB, getCheckWithBaseFn(PO),
9631087
{PlainPtrOp, Start, Length, Tag, Size,
9641088
ConstantInt::get(Int64Ty, AccessId), getSourceIndex(I),
9651089
getPC(IRB)},
9661090
I.getName() + ".san");
967-
} else {
968-
CB = createCall(IRB, getCheckFn(PO),
1091+
} else {
1092+
CB = createCall(IRB, getCheckFn(PO),
9691093
{PlainPtrOp, Size, ConstantInt::get(Int64Ty, AccessId),
9701094
getSourceIndex(I), getPC(IRB)},
9711095
I.getName() + ".san");
972-
}
973-
I.setOperand(PtrIdx,
1096+
}
1097+
1098+
I.setOperand(PtrIdx,
9741099
IRB.CreatePointerBitCastOrAddrSpaceCast(CB, PtrOp->getType()));
1100+
}
9751101
}
9761102

9771103
void GPUSanImpl::instrumentMultipleAccessPerBasicBlock(
@@ -1039,7 +1165,57 @@ void GPUSanImpl::instrumentMultipleAccessPerBasicBlock(
10391165

10401166
if (Loop *L = LI.getLoopFor(I->getParent())) {
10411167
auto &SE = FAM.getResult<ScalarEvolutionAnalysis>(*I->getFunction());
1042-
const auto &LD = SE.getLoopDisposition(SE.getSCEVAtScope(PtrOp, L), L);
1168+
auto *PtrOpScev = SE.getSCEVAtScope(PtrOp, L);
1169+
const auto &LD = SE.getLoopDisposition(PtrOpScev, L);
1170+
SmallVector<const SCEVPredicate *, 4> Preds;
1171+
SmallPtrSet< const SCEVPredicate *, 4> PredsSet;
1172+
for (auto *Pred : Preds)
1173+
PredsSet.insert(Pred);
1174+
auto *Ex = SE.getPredicatedBackedgeTakenCount(L, Preds);
1175+
1176+
errs() << "Loop Disposition: " << LD << "\n";
1177+
errs() << "ABS Expression: " << SE.getSmallConstantTripCount(L) << "\n";
1178+
const SCEVAddRecExpr *AR = SE.convertSCEVToAddRecWithPredicates(PtrOpScev, L, PredsSet);
1179+
1180+
const SCEV *ScStart = AR->getStart();
1181+
const SCEV *ScEnd = AR->evaluateAtIteration(Ex, SE);
1182+
const SCEV *Step = AR->getStepRecurrence(SE);
1183+
1184+
// // For expressions with negative step, the upper bound is ScStart and the
1185+
// // lower bound is ScEnd.
1186+
if (const SCEVConstant *CStep = dyn_cast<const SCEVConstant>(Step)) {
1187+
if (CStep->getValue()->isNegative())
1188+
std::swap(ScStart, ScEnd);
1189+
} else {
1190+
// Fallback case: the step is not constant, but we can still
1191+
// get the upper and lower bounds of the interval by using min/max
1192+
// expressions.
1193+
ScStart = SE.getUMinExpr(ScStart, ScEnd);
1194+
ScEnd = SE.getUMaxExpr(AR->getStart(), ScEnd);
1195+
}
1196+
1197+
errs() << "SC step: " << *Step << "\n";
1198+
errs() << "Sc start: " << *ScStart << "\n";
1199+
errs() << "Sc end: " << *ScEnd << "\n";
1200+
ScEnd->print(errs());
1201+
errs() << "\n";
1202+
ScEnd->dump();
1203+
errs() << "\n";
1204+
1205+
ArrayRef< const SCEV * > Ops = ScEnd->operands();
1206+
errs() << "\n";
1207+
for (auto *Op : Ops){
1208+
errs() << "Operand: " << *Op << "\n";
1209+
errs() << "Operand Scev Type: " << Op->getSCEVType() << "\n";
1210+
errs() << "Operand Type: " << *Op->getType() << "\n";
1211+
}
1212+
errs() << "\n";
1213+
1214+
errs() << "Scev Type: " << ScEnd->getSCEVType() << "\n";
1215+
errs() << "Type: " << *ScEnd->getType() << "\n";
1216+
errs() << "Is Non Constant Negative: " << ScEnd->isNonConstantNegative() << "\n";
1217+
errs() << "PtrOp: " << *PtrOp << "\n";
1218+
10431219
}
10441220

10451221
static int32_t ReadAccessId = -1;
@@ -1269,15 +1445,15 @@ void GPUSanImpl::instrumentMultipleAccessPerBasicBlock(
12691445
}
12701446
}
12711447

1272-
void GPUSanImpl::instrumentLoadInst(LoopInfo &LI, LoadInst &LoadI) {
1448+
void GPUSanImpl::instrumentLoadInst(LoopInfo &LI, LoadInst &LoadI, SmallVector<GetElementPtrInst *> &GEPs) {
12731449
instrumentAccess(LI, LoadI, LoadInst::getPointerOperandIndex(),
12741450
*LoadI.getType(),
1275-
/*IsRead=*/true);
1451+
/*IsRead=*/true, GEPs);
12761452
}
12771453

1278-
void GPUSanImpl::instrumentStoreInst(LoopInfo &LI, StoreInst &StoreI) {
1454+
void GPUSanImpl::instrumentStoreInst(LoopInfo &LI, StoreInst &StoreI, SmallVector<GetElementPtrInst *> &GEPs) {
12791455
instrumentAccess(LI, StoreI, StoreInst::getPointerOperandIndex(),
1280-
*StoreI.getValueOperand()->getType(), /*IsRead=*/false);
1456+
*StoreI.getValueOperand()->getType(), /*IsRead=*/false, GEPs);
12811457
}
12821458

12831459
void GPUSanImpl::instrumentGEPInst(LoopInfo &LI, GetElementPtrInst &GEP) {
@@ -1389,49 +1565,49 @@ bool GPUSanImpl::instrumentFunction(Function &Fn) {
13891565
}
13901566

13911567
// Hoist all address computation in a basic block
1392-
auto GEPCopy = GEPs;
1393-
while (!GEPCopy.empty()) {
1394-
auto *Inst = GEPCopy.pop_back_val();
1395-
Instruction *LatestDependency = &*Inst->getParent()->begin();
1396-
for (auto *It = Inst->op_begin(); It != Inst->op_end(); It++) {
1568+
// auto GEPCopy = GEPs;
1569+
// while (!GEPCopy.empty()) {
1570+
// auto *Inst = GEPCopy.pop_back_val();
1571+
// Instruction *LatestDependency = &*Inst->getParent()->begin();
1572+
// for (auto *It = Inst->op_begin(); It != Inst->op_end(); It++) {
13971573

1398-
if (Instruction *ToInstruction = dyn_cast<Instruction>(It)) {
1574+
// if (Instruction *ToInstruction = dyn_cast<Instruction>(It)) {
13991575

1400-
if (!LatestDependency) {
1401-
LatestDependency = ToInstruction;
1402-
continue;
1403-
}
1576+
// if (!LatestDependency) {
1577+
// LatestDependency = ToInstruction;
1578+
// continue;
1579+
// }
14041580

1405-
if (ToInstruction->getParent() != Inst->getParent())
1406-
continue;
1581+
// if (ToInstruction->getParent() != Inst->getParent())
1582+
// continue;
14071583

1408-
if (LatestDependency->comesBefore(ToInstruction))
1409-
LatestDependency = ToInstruction;
1410-
}
1411-
}
1584+
// if (LatestDependency->comesBefore(ToInstruction))
1585+
// LatestDependency = ToInstruction;
1586+
// }
1587+
// }
14121588

1413-
Inst->moveAfter(LatestDependency);
1414-
}
1589+
// Inst->moveAfter(LatestDependency);
1590+
// }
14151591

1416-
bool CanMergeChecks = true;
1417-
for (auto *GEP : GEPs) {
1592+
// bool CanMergeChecks = true;
1593+
// for (auto *GEP : GEPs) {
14181594

1419-
if (GEP->comesBefore(LoadsStores.front())) {
1420-
CanMergeChecks = CanMergeChecks && true;
1421-
} else {
1422-
CanMergeChecks = CanMergeChecks && false;
1423-
}
1424-
}
1595+
// if (GEP->comesBefore(LoadsStores.front())) {
1596+
// CanMergeChecks = CanMergeChecks && true;
1597+
// } else {
1598+
// CanMergeChecks = CanMergeChecks && false;
1599+
// }
1600+
// }
14251601

14261602
// check if you can merge various pointer checks.
1427-
if (CanMergeChecks) {
1428-
instrumentMultipleAccessPerBasicBlock(LI, LoadsStores);
1429-
} else {
1603+
//if (CanMergeChecks) {
1604+
// instrumentMultipleAccessPerBasicBlock(LI, LoadsStores);
1605+
//} else {
14301606
for (auto *Load : Loads)
1431-
instrumentLoadInst(LI, *Load);
1607+
instrumentLoadInst(LI, *Load, GEPs);
14321608
for (auto *Store : Stores)
1433-
instrumentStoreInst(LI, *Store);
1434-
}
1609+
instrumentStoreInst(LI, *Store, GEPs);
1610+
//}
14351611

14361612
for (auto *GEP : GEPs)
14371613
instrumentGEPInst(LI, *GEP);

0 commit comments

Comments
 (0)