4343#include " llvm/Support/Debug.h"
4444#include " llvm/Support/Path.h"
4545#include " llvm/Support/raw_ostream.h"
46+ #include " llvm/Transforms/Instrumentation/SPIRVSanitizerCommonUtils.h"
4647#include " llvm/Transforms/Utils/EscapeEnumerator.h"
4748#include " llvm/Transforms/Utils/Instrumentation.h"
4849#include " llvm/Transforms/Utils/Local.h"
@@ -52,13 +53,6 @@ using namespace llvm;
5253
5354#define DEBUG_TYPE " tsan"
5455
55- // Spir memory address space
56- static constexpr unsigned kSpirOffloadPrivateAS = 0 ;
57- static constexpr unsigned kSpirOffloadGlobalAS = 1 ;
58- static constexpr unsigned kSpirOffloadConstantAS = 2 ;
59- static constexpr unsigned kSpirOffloadLocalAS = 3 ;
60- static constexpr unsigned kSpirOffloadGenericAS = 4 ;
61-
6256static cl::opt<bool > ClInstrumentMemoryAccesses (
6357 " tsan-instrument-memory-accesses" , cl::init(true ),
6458 cl::desc(" Instrument memory accesses" ), cl::Hidden);
@@ -127,6 +121,8 @@ struct ThreadSanitizerOnSpirv {
127121
128122 void appendDebugInfoToArgs (Instruction *I, SmallVectorImpl<Value *> &Args);
129123
124+ bool isUnsupportedSPIRAccess (Value *Addr, Instruction *Inst);
125+
130126private:
131127 void instrumentGlobalVariables ();
132128
@@ -383,6 +379,38 @@ void ThreadSanitizerOnSpirv::appendDebugInfoToArgs(
383379 Args.push_back (ConstantExpr::getPointerCast (FuncNameGV, ConstASPtrTy));
384380}
385381
382+ bool ThreadSanitizerOnSpirv::isUnsupportedSPIRAccess (Value *Addr,
383+ Instruction *Inst) {
384+ auto *OrigValue = getUnderlyingObject (Addr);
385+ if (OrigValue->getName ().starts_with (" __spirv_BuiltIn" ))
386+ return true ;
387+
388+ // Ignore load/store for target ext type since we can't know exactly what size
389+ // it is.
390+ if (auto *SI = dyn_cast<StoreInst>(Inst))
391+ if (getTargetExtType (SI->getValueOperand ()->getType ()) ||
392+ isJointMatrixAccess (SI->getPointerOperand ()))
393+ return true ;
394+
395+ if (auto *LI = dyn_cast<LoadInst>(Inst))
396+ if (getTargetExtType (Inst->getType ()) ||
397+ isJointMatrixAccess (LI->getPointerOperand ()))
398+ return true ;
399+
400+ auto AddrAS = cast<PointerType>(Addr->getType ()->getScalarType ())
401+ ->getPointerAddressSpace ();
402+ switch (AddrAS) {
403+ case kSpirOffloadPrivateAS :
404+ case kSpirOffloadLocalAS :
405+ case kSpirOffloadConstantAS :
406+ return true ;
407+ case kSpirOffloadGlobalAS :
408+ case kSpirOffloadGenericAS :
409+ return false ;
410+ }
411+ return false ;
412+ }
413+
386414bool ThreadSanitizerOnSpirv::isSupportedSPIRKernel (Function &F) {
387415
388416 if (!F.hasFnAttribute (Attribute::SanitizeThread) ||
@@ -709,30 +737,12 @@ static bool shouldInstrumentReadWriteFromAddress(const Module *M, Value *Addr) {
709737 }
710738 }
711739
712- if (Triple (M->getTargetTriple ()).isSPIROrSPIRV ()) {
713- auto *OrigValue = getUnderlyingObject (Addr);
714- if (OrigValue->getName ().starts_with (" __spirv_BuiltIn" ))
740+ // Do not instrument accesses from different address spaces; we cannot deal
741+ // with them.
742+ if (Addr) {
743+ Type *PtrTy = cast<PointerType>(Addr->getType ()->getScalarType ());
744+ if (PtrTy->getPointerAddressSpace () != 0 )
715745 return false ;
716-
717- auto AddrAS = cast<PointerType>(Addr->getType ()->getScalarType ())
718- ->getPointerAddressSpace ();
719- switch (AddrAS) {
720- case kSpirOffloadPrivateAS :
721- case kSpirOffloadLocalAS :
722- case kSpirOffloadConstantAS :
723- return false ;
724- case kSpirOffloadGlobalAS :
725- case kSpirOffloadGenericAS :
726- return true ;
727- }
728- } else {
729- // Do not instrument accesses from different address spaces; we cannot deal
730- // with them.
731- if (Addr) {
732- Type *PtrTy = cast<PointerType>(Addr->getType ()->getScalarType ());
733- if (PtrTy->getPointerAddressSpace () != 0 )
734- return false ;
735- }
736746 }
737747
738748 return true ;
@@ -781,7 +791,10 @@ void ThreadSanitizer::chooseInstructionsToInstrument(
781791 Value *Addr = IsWrite ? cast<StoreInst>(I)->getPointerOperand ()
782792 : cast<LoadInst>(I)->getPointerOperand ();
783793
784- if (!shouldInstrumentReadWriteFromAddress (I->getModule (), Addr))
794+ if (Spirv) {
795+ if (Spirv->isUnsupportedSPIRAccess (Addr, I))
796+ continue ;
797+ } else if (!shouldInstrumentReadWriteFromAddress (I->getModule (), Addr))
785798 continue ;
786799
787800 if (!IsWrite) {
@@ -890,7 +903,8 @@ bool ThreadSanitizer::sanitizeFunction(Function &F,
890903 else if (isa<LoadInst>(Inst) || isa<StoreInst>(Inst))
891904 LocalLoadsAndStores.push_back (&Inst);
892905 else if (Spirv && isa<AllocaInst>(Inst) &&
893- cast<AllocaInst>(Inst).getAllocatedType ()->isSized ())
906+ cast<AllocaInst>(Inst).getAllocatedType ()->isSized () &&
907+ !getTargetExtType (cast<AllocaInst>(Inst).getAllocatedType ()))
894908 Allocas.push_back (&Inst);
895909 else if ((isa<CallInst>(Inst) && !isa<DbgInfoIntrinsic>(Inst)) ||
896910 isa<InvokeInst>(Inst)) {
0 commit comments