Skip to content

Commit 9b7ad4b

Browse files
author
Vidush Singhal
committed
edits
1 parent 15e6567 commit 9b7ad4b

File tree

2 files changed

+185
-94
lines changed

2 files changed

+185
-94
lines changed

llvm/lib/Transforms/Instrumentation/GPUSan.cpp

Lines changed: 134 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,22 @@ class GPUSanImpl final {
288288
});
289289
}
290290

291+
/// Look up (or declare) the runtime entry point `ompx_check_global_range`.
///
/// The CheckGlobalRangeFn[0] slot is handed to getOrCreateFn together with the
/// function name, a void return type, and the nine parameter types listed
/// below. The parameter order mirrors the argument order used at the call
/// site that emits the hoisted range check.
FunctionCallee getCheckWithRangeGlobal() {
  FunctionCallee &RangeCheckSlot = CheckGlobalRangeFn[0];
  Type *VoidTy = Type::getVoidTy(Ctx);
  return getOrCreateFn(RangeCheckSlot, "ompx_check_global_range", VoidTy,
                       {
                           PtrTy,   // Maximum address computed via SCEV.
                           PtrTy,   // Minimum address computed via SCEV.
                           PtrTy,   // Start address of the allocation.
                           Int64Ty, // Length (size) of the allocation.
                           Int32Ty, // Allocation tag.
                           Int64Ty, // Size of the loaded/stored type.
                           Int64Ty, // Access id (distinguishes read/write).
                           Int64Ty, // Source id of the allocation.
                           Int64Ty  // Program counter (PC).
                       });
}
306+
291307
FunctionCallee getAllocationInfoFn(PtrOrigin PO) {
292308
assert(PO >= LOCAL && PO <= GLOBAL && "Origin does not need handling.");
293309
if (auto *F = M.getFunction("ompx_get_allocation_info" + getSuffix(PO)))
@@ -348,6 +364,11 @@ class GPUSanImpl final {
348364
IntegerType *Int32Ty = Type::getInt32Ty(Ctx);
349365
IntegerType *Int64Ty = Type::getInt64Ty(Ctx);
350366

367+
// Create an 8-bit integer type
368+
Type *Int8Type = llvm::Type::getInt8Ty(Ctx);
369+
// Pointer to i8 in address space 0; the SCEV range bounds are expanded to this type.
370+
Type *Int8PtrType = llvm::PointerType::get(Int8Type, 0);
371+
351372
const DataLayout &DL = M.getDataLayout();
352373

353374
FunctionCallee NewFn[3];
@@ -363,6 +384,7 @@ class GPUSanImpl final {
363384
FunctionCallee LifetimeStartFn;
364385
FunctionCallee FreeNLocalFn;
365386
FunctionCallee ThreadIDFn;
387+
FunctionCallee CheckGlobalRangeFn[1];
366388

367389
StringMap<Value *> GlobalStringMap;
368390
struct AllocationInfoTy {
@@ -949,7 +971,16 @@ void GPUSanImpl::instrumentAccess(LoopInfo &LI, Instruction &I, int PtrIdx,
949971

950972
if (Loop *L = LI.getLoopFor(I.getParent())) {
951973

974+
// goto handleunhoistable;
975+
976+
// TODO: handle for Local access also.
977+
if (PO != GLOBAL)
978+
goto handleunhoistable;
979+
980+
BasicBlock *CurrentBasicBlock = I.getParent();
981+
952982
auto &SE = FAM.getResult<ScalarEvolutionAnalysis>(*I.getFunction());
983+
SCEVExpander Expander = SCEVExpander(SE, DL, "SCEVExpander");
953984
auto *PtrOpScev = SE.getSCEVAtScope(PtrOp, L);
954985
const auto &LD = SE.getLoopDisposition(PtrOpScev, L);
955986

@@ -967,103 +998,112 @@ void GPUSanImpl::instrumentAccess(LoopInfo &LI, Instruction &I, int PtrIdx,
967998
const SCEV *ScEnd = AddRecExpr->evaluateAtIteration(BackEdges, SE);
968999
const SCEV *Step = AddRecExpr->getStepRecurrence(SE);
9691000

970-
if (const SCEVConstant *ConstStep = dyn_cast<const SCEVConstant>(Step)) {
971-
if (ConstStep->getValue()->isNegative()) {
972-
std::swap(ScStart, ScEnd);
973-
} else {
974-
ScStart = SE.getUMinExpr(ScStart, ScEnd);
975-
ScEnd = SE.getUMaxExpr(AddRecExpr->getStart(), ScEnd);
976-
}
977-
}
1001+
// if (const SCEVConstant *ConstStep = dyn_cast<const SCEVConstant>(Step)) {
1002+
// if (ConstStep->getValue()->isNegative()) {
1003+
// std::swap(ScStart, ScEnd);
1004+
// } else {
1005+
// ScStart = SE.getUMinExpr(ScStart, ScEnd);
1006+
// ScEnd = SE.getUMaxExpr(AddRecExpr->getStart(), ScEnd);
1007+
// }
1008+
// }
9781009

979-
ArrayRef<const SCEV *> Operands = ScEnd->operands();
980-
// Assumption: If size of operands is two, it can be decomposed as
981-
// Base Offset and start ptr.
982-
if (Operands.size() == 2) {
983-
const SCEV *First = Operands[0];
984-
const SCEV *Second = Operands[1];
985-
986-
const SCEVConstant *SC = dyn_cast<SCEVConstant>(First);
987-
if (!SC)
988-
goto handleunhoistable;
989-
990-
const SCEVConstant *StepConst = dyn_cast<SCEVConstant>(Step);
991-
if (!StepConst)
992-
goto handleunhoistable;
993-
994-
const SCEVUnknown *SU = dyn_cast<SCEVUnknown>(Second);
995-
if (!SU)
996-
goto handleunhoistable;
997-
998-
Value *BasePointer = SU->getValue();
999-
ConstantInt *BytesValue = SC->getValue();
1000-
ConstantInt *StepValue = StepConst->getValue();
1001-
1002-
uint64_t OffsetInt =
1003-
BytesValue->getZExtValue() / StepValue->getZExtValue();
1004-
ConstantInt *OffsetValue = ConstantInt::get(Ctx, APInt(64, OffsetInt));
1005-
Value *GEPOutsideBB =
1006-
IRB.CreateGEP(BasePointer->getType(), BasePointer, {OffsetValue});
1007-
1008-
GetElementPtrInst *GEPInst = dyn_cast<GetElementPtrInst>(GEPOutsideBB);
1009-
if (!GEPInst)
1010-
goto handleunhoistable;
1011-
1012-
GEPs.push_back(GEPInst);
1013-
1014-
Instruction *BasePointerInst = dyn_cast<Instruction>(BasePointer);
1015-
if (!BasePointerInst)
1016-
goto handleunhoistable;
1017-
1018-
GetElementPtrInst *GEPToRemove = dyn_cast<GetElementPtrInst>(PtrOp);
1019-
if (!GEPToRemove)
1020-
goto handleunhoistable;
1021-
1022-
auto It = std::find(GEPs.begin(), GEPs.end(), GEPToRemove);
1023-
if (It != GEPs.end())
1024-
GEPs.erase(It);
1025-
1026-
GEPInst->removeFromParent();
1027-
auto *BB = BasePointerInst->getParent();
1028-
auto Terminator = BB->end();
1029-
GEPInst->insertInto(BB, --Terminator);
1030-
1031-
static int32_t ReadAccessId = -1;
1032-
static int32_t WriteAccessId = 1;
1033-
const int32_t &AccessId = IsRead ? ReadAccessId-- : WriteAccessId++;
1034-
1035-
auto TySize = DL.getTypeStoreSize(&AccessTy);
1036-
assert(!TySize.isScalable());
1037-
Value *Size = ConstantInt::get(Int64Ty, TySize.getFixedValue());
1038-
1039-
Value *PlainPtrOp =
1040-
IRB.CreatePointerBitCastOrAddrSpaceCast(GEPInst, getPtrTy(PO));
1041-
Instruction *OpInst = dyn_cast<Instruction>(PlainPtrOp);
1042-
if (!OpInst)
1043-
goto handleunhoistable;
1044-
1045-
CallInst *CB;
1046-
Value *PCVal = getPC(IRB);
1047-
Instruction *PCInst = dyn_cast<Instruction>(PCVal);
1048-
PCInst->removeFromParent();
1049-
PCInst->insertBefore(OpInst);
1050-
1051-
if (Start) {
1052-
CB = createCall(IRB, getCheckWithBaseFn(PO),
1053-
{PlainPtrOp, Start, Length, Tag, Size,
1054-
ConstantInt::get(Int64Ty, AccessId), getSourceIndex(I),
1055-
PCInst},
1056-
I.getName() + ".san");
1057-
} else {
1058-
CB = createCall(IRB, getCheckFn(PO),
1059-
{PlainPtrOp, Size, ConstantInt::get(Int64Ty, AccessId),
1060-
getSourceIndex(I), PCInst},
1061-
I.getName() + ".san");
1062-
}
1010+
// // print some debug data.
1011+
// errs() << "ScStart: " << *ScStart << "\n";
1012+
// errs() << "ScEnd: " << *ScEnd << "\n";
1013+
// errs() << "Step: " << *Step << "\n";
1014+
1015+
if (!Expander.isSafeToExpand(ScStart))
1016+
goto handleunhoistable;
1017+
1018+
if (!Expander.isSafeToExpand(ScEnd))
1019+
goto handleunhoistable;
1020+
1021+
// We need to find a suitable Insert Point.
1022+
// Assumption: Current loop has one unique predecessor
1023+
// We can insert at the end of the basic block if it
1024+
// is not a branch instruction.
1025+
BasicBlock *ParentLoop = L->getLoopPredecessor();
10631026

1064-
CB->removeFromParent();
1065-
CB->insertAfter(OpInst);
1027+
// If there is not one unique predecessor, for now give up
1028+
// hoisting the check out of the loop.
1029+
if (!ParentLoop)
1030+
goto handleunhoistable;
1031+
1032+
// Get handle to last instruction.
1033+
auto LoopEnd = --(ParentLoop->end());
1034+
Instruction *LoopEndInst = &*LoopEnd;
1035+
1036+
Type *Int64Ty = Type::getInt64Ty(Ctx);
1037+
Value *LowerBoundCode =
1038+
Expander.expandCodeFor(ScStart, Int8PtrType, LoopEnd);
1039+
1040+
LoopEnd = --(ParentLoop->end());
1041+
Value *UpperBoundCode = Expander.expandCodeFor(ScEnd, Int8PtrType, LoopEnd);
1042+
1043+
static int32_t ReadAccessId = -1;
1044+
static int32_t WriteAccessId = 1;
1045+
const int32_t &AccessId = IsRead ? ReadAccessId-- : WriteAccessId++;
1046+
1047+
auto TySize = DL.getTypeStoreSize(&AccessTy);
1048+
assert(!TySize.isScalable());
1049+
Value *Size = ConstantInt::get(Int64Ty, TySize.getFixedValue());
1050+
1051+
LoopEnd = --(ParentLoop->end());
1052+
CallInst *CB;
1053+
Value *PCVal = getPC(IRB);
1054+
Instruction *PCInst = dyn_cast<Instruction>(PCVal);
1055+
if (!PCInst)
1056+
return;
1057+
1058+
Value *AccessIDVal = ConstantInt::get(Int64Ty, AccessId);
1059+
PCInst->removeFromParent();
1060+
PCInst->insertBefore(&*LoopEnd);
1061+
auto Callee = getCheckWithRangeGlobal();
1062+
1063+
// // print some debug data.
1064+
// errs() << "Print Callee Type: " << *Callee.getFunctionType() << "\n";
1065+
1066+
// errs() << *UpperBoundCode->getType() << "\n";
1067+
// errs() << *LowerBoundCode->getType() << "\n";
1068+
// errs() << *Start->getType() << "\n";
1069+
// errs() << *Length->getType() << "\n";
1070+
// errs() << *Tag->getType() << "\n";
1071+
// errs() << *Size->getType() << "\n";
1072+
// errs() << *AccessIDVal->getType() << "\n";
1073+
// errs() << *PCVal->getType() << "\n";
1074+
1075+
if (Start) {
1076+
CB = createCall(IRB, Callee,
1077+
{UpperBoundCode, LowerBoundCode, Start, Length, Tag, Size,
1078+
AccessIDVal, getSourceIndex(I), PCVal});
1079+
} else {
1080+
CB = createCall(IRB, Callee,
1081+
{UpperBoundCode, LowerBoundCode, Start, Length, Tag, Size,
1082+
AccessIDVal, getSourceIndex(I), PCVal});
10661083
}
1084+
CB->removeFromParent();
1085+
CB->insertAfter(PCInst);
1086+
1087+
// Still need to get the real pointer from the pointer op.
1088+
// Convert fake pointer to real pointer.
1089+
Value *PlainPtrOp =
1090+
IRB.CreatePointerBitCastOrAddrSpaceCast(PtrOp, getPtrTy(PO));
1091+
auto *CBUnpack = createCall(IRB, getUnpackFn(PO), {PlainPtrOp, getPC(IRB)},
1092+
PtrOp->getName() + ".unpack");
1093+
1094+
I.setOperand(PtrIdx, IRB.CreatePointerBitCastOrAddrSpaceCast(
1095+
CBUnpack, PtrOp->getType()));
1096+
1097+
return;
1098+
1099+
// // print some debug info data
1100+
// for (auto It = CB->arg_begin(); It != CB->arg_end(); It++){
1101+
// auto *Op = &*It;
1102+
// if (Value *ValOp = dyn_cast<Value>(Op))
1103+
// errs() << "Print Val Type: " << *ValOp << "\n";
1104+
// }
1105+
1106+
// errs() << "Print Callee Type: " << *Callee.getFunctionType() << "\n";
10671107
}
10681108

10691109
handleunhoistable:

offload/DeviceRTL/src/Sanitizer.cpp

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,57 @@ template <AllocationKind AK> struct AllocationTracker {
169169
return utils::advancePtr(Start, Offset);
170170
}
171171

172+
[[clang::disable_sanitizer_instrumentation]] static void
173+
checkRange(_AS_PTR(void, AK) SCEVMax, _AS_PTR(void, AK) SCEVMin,
174+
_AS_PTR(void, AK) StartAddress, int64_t AllocationLength,
175+
uint32_t Tag, int64_t AccessTypeSize, int64_t AccessId,
176+
int64_t SourceId, uint64_t PC) {
177+
178+
AllocationPtrTy<AK> APSCEVMax = AllocationPtrTy<AK>::get(SCEVMax);
179+
AllocationPtrTy<AK> APSCEVMin = AllocationPtrTy<AK>::get(SCEVMin);
180+
if constexpr (AK == AllocationKind::GLOBAL)
181+
if (APSCEVMax.Magic != SanitizerConfig<AllocationKind::GLOBAL>::MAGIC)
182+
__sanitizer_trap_info_ptr->garbagePointer<AK>(
183+
APSCEVMax, (void *)SCEVMax, SourceId, PC);
184+
185+
if (APSCEVMin.Magic != SanitizerConfig<AllocationKind::GLOBAL>::MAGIC)
186+
__sanitizer_trap_info_ptr->garbagePointer<AK>(APSCEVMin, (void *)SCEVMin,
187+
SourceId, PC);
188+
189+
int64_t MaxOffset = APSCEVMax.Offset;
190+
int64_t MinOffset = APSCEVMin.Offset;
191+
if (OMP_UNLIKELY(MaxOffset > AllocationLength - AccessTypeSize ||
192+
(SanitizerConfig<AK>::useTags() &&
193+
Tag != APSCEVMax.AllocationTag))) {
194+
__sanitizer_trap_info_ptr->accessError<AK>(APSCEVMax, AccessTypeSize,
195+
AccessId, SourceId, PC);
196+
}
197+
198+
AllocationPtrTy<AK> AllocationStart =
199+
AllocationPtrTy<AK>::get(StartAddress);
200+
int AllocationStartOffset = AllocationStart.Offset;
201+
if (OMP_UNLIKELY(MinOffset < AllocationStartOffset ||
202+
(SanitizerConfig<AK>::useTags() &&
203+
Tag != APSCEVMin.AllocationTag))) {
204+
__sanitizer_trap_info_ptr->accessError<AK>(APSCEVMin, AccessTypeSize,
205+
AccessId, SourceId, PC);
206+
}
207+
}
208+
209+
[[clang::disable_sanitizer_instrumentation, gnu::flatten, gnu::always_inline,
210+
gnu::used, gnu::retain]] void
211+
ompx_check_global_range(_AS_PTR(void, AllocationKind::GLOBAL) SCEVMax,
212+
_AS_PTR(void, AllocationKind::GLOBAL) SCEVMin,
213+
_AS_PTR(void, AllocationKind::GLOBAL) StartAddress,
214+
int64_t AllocationLength, uint32_t Tag,
215+
int64_t AccessTypeSize, int64_t AccessId,
216+
int64_t SourceId, uint64_t PC) {
217+
218+
return AllocationTracker<AllocationKind::GLOBAL>::checkRange(
219+
SCEVMax, SCEVMin, StartAddress, AllocationLength, Tag, AccessTypeSize,
220+
AccessId, SourceId, PC);
221+
}
222+
172223
[[clang::disable_sanitizer_instrumentation]] static _AS_PTR(void, AK)
173224
check(_AS_PTR(void, AK) P, int64_t Size, int64_t AccessId,
174225
int64_t SourceId, uint64_t PC) {

0 commit comments

Comments
 (0)