@@ -778,42 +778,99 @@ bool GPUSanImpl::instrument() {
778778 return false ;
779779 }();
780780
781- for (Function &Fn : M)
781+ SmallVector<Function *> Kernels;
782+ for (Function &Fn : M) {
783+ if (Fn.hasFnAttribute (" kernel" ))
784+ Kernels.push_back (&Fn);
782785 if (!Fn.getName ().contains (" ompx" ) && !Fn.getName ().contains (" __kmpc" ) &&
783786 !Fn.getName ().starts_with (" rpc_" ))
784787 if (!Fn.hasFnAttribute (Attribute::DisableSanitizerInstrumentation))
785788 Changed |= instrumentFunction (Fn);
789+ }
786790
787- SmallVector<std::pair<CallBase *, ConstantInt *>> AmbiguousCallsNumbered;
791+ SmallVector<CallBase *> AmbiguousCallsOrdered;
792+ SmallVector<Constant *> AmbiguousCallsMapping;
793+ if (LocationMap.empty ())
794+ AmbiguousCalls.clear ();
788795 for (size_t I = 0 ; I < AmbiguousCalls.size (); ++I) {
789796 CallBase &CB = *AmbiguousCalls[I];
790- AmbiguousCallsNumbered.push_back ({&CB, getSourceIndex (CB)});
797+ AmbiguousCallsOrdered.push_back (&CB);
798+ AmbiguousCallsMapping.push_back (getSourceIndex (CB));
791799 }
792- IntegerType *ITy = nullptr ;
800+
801+ uint64_t AmbiguousCallsBitWidth =
802+ llvm::PowerOf2Ceil (AmbiguousCalls.size () + 1 );
803+
804+ new GlobalVariable (M, Int64Ty, /* isConstant=*/ true ,
805+ GlobalValue::ExternalLinkage,
806+ ConstantInt::get (Int64Ty, AmbiguousCallsBitWidth),
807+ " __san.num_ambiguous_calls" , nullptr ,
808+ GlobalValue::ThreadLocalMode::NotThreadLocal, 1 );
809+
793810 if (size_t NumAmbiguousCalls = AmbiguousCalls.size ()) {
794- ITy = IntegerType::get (Ctx, llvm::PowerOf2Ceil (NumAmbiguousCalls));
795- auto *ArrayTy = ArrayType::get (ITy, 1024 );
811+ {
812+ auto *ArrayTy = ArrayType::get (Int64Ty, NumAmbiguousCalls);
813+ auto *GV = new GlobalVariable (
814+ M, ArrayTy, /* isConstant=*/ true , GlobalValue::ExternalLinkage,
815+ ConstantArray::get (ArrayTy, AmbiguousCallsMapping),
816+ " __san.ambiguous_calls_mapping" , nullptr ,
817+ GlobalValue::ThreadLocalMode::NotThreadLocal, 4 );
818+ GV->setVisibility (GlobalValue::ProtectedVisibility);
819+ }
820+
821+ auto *ArrayTy = ArrayType::get (Int64Ty, 1024 );
796822 LocationsArray = new GlobalVariable (
797823 M, ArrayTy, /* isConstant=*/ false , GlobalValue::PrivateLinkage,
798824 UndefValue::get (ArrayTy), " __san.calls" , nullptr ,
799825 GlobalValue::ThreadLocalMode::NotThreadLocal, 3 );
800826
827+ auto *OldFn = M.getFunction (" __san_get_location_value" );
828+ if (OldFn)
829+ OldFn->setName (" " );
801830 Function *LocationGetter = Function::Create (
802- FunctionType::get (Int64Ty, false ), llvm:: GlobalValue::ExternalLinkage,
831+ FunctionType::get (Int64Ty, false ), GlobalValue::ExternalLinkage,
803832 " __san_get_location_value" , M);
833+ if (OldFn) {
834+ OldFn->replaceAllUsesWith (LocationGetter);
835+ OldFn->eraseFromParent ();
836+ }
804837 auto *EntryBB = BasicBlock::Create (Ctx, " entry" , LocationGetter);
805838 IRBuilder<> IRB (EntryBB);
806839 Value *Idx = IRB.CreateCall (getThreadIdFn (), {}, " san.gtid" );
807- Value *Ptr = IRB.CreateGEP (ITy, LocationsArray, {Idx});
808- auto *LocationValue = IRB.CreateLoad (ITy, Ptr);
809- IRB.CreateRet (IRB.CreateZExt (LocationValue, Int64Ty));
840+ Value *Ptr = IRB.CreateGEP (Int64Ty, LocationsArray, {Idx});
841+ auto *LocationValue = IRB.CreateLoad (Int64Ty, Ptr);
842+ IRB.CreateRet (LocationValue);
843+ }
844+
845+ Function *InitSharedFn =
846+ Function::Create (FunctionType::get (VoidTy, false ),
847+ GlobalValue::PrivateLinkage, " __san.init_shared" , &M);
848+ auto *EntryBB = BasicBlock::Create (Ctx, " entry" , InitSharedFn);
849+ IRBuilder<> IRB (EntryBB);
850+ if (!AmbiguousCalls.empty ()) {
851+ Value *Idx = IRB.CreateCall (getThreadIdFn (), {}, " san.gtid" );
852+ Value *Ptr = IRB.CreateGEP (Int64Ty, LocationsArray, {Idx});
853+ IRB.CreateStore (ConstantInt::get (Int64Ty, 0 ), Ptr);
854+ }
855+ IRB.CreateRetVoid ();
856+
857+ for (auto *KernelFn : Kernels) {
858+ IRBuilder<> IRB (&*KernelFn->getEntryBlock ().getFirstNonPHIOrDbgOrAlloca ());
859+ IRB.CreateCall (InitSharedFn, {});
810860 }
811861
812- for (auto &It : AmbiguousCallsNumbered ) {
813- IRBuilder<> IRB (It.first );
862+ for (const auto &It : llvm::enumerate (AmbiguousCallsOrdered) ) {
863+ IRBuilder<> IRB (It.value () );
814864 Value *Idx = IRB.CreateCall (getThreadIdFn (), {}, " san.gtid" );
815- Value *Ptr = IRB.CreateGEP (ITy, LocationsArray, {Idx});
816- IRB.CreateStore (It.second , Ptr);
865+ Value *Ptr = IRB.CreateGEP (Int64Ty, LocationsArray, {Idx});
866+ Value *OldVal = IRB.CreateLoad (Int64Ty, Ptr);
867+ Value *OldValShifted = IRB.CreateShl (
868+ OldVal, ConstantInt::get (Int64Ty, AmbiguousCallsBitWidth));
869+ Value *NewVal = IRB.CreateBinOp (Instruction::Or, OldValShifted,
870+ ConstantInt::get (Int64Ty, It.index () + 1 ));
871+ IRB.CreateStore (NewVal, Ptr);
872+ IRB.SetInsertPoint (It.value ()->getNextNode ());
873+ IRB.CreateStore (OldVal, Ptr);
817874 }
818875
819876 auto *NamesTy = ArrayType::get (Int8Ty, ConcatenatedString.size () + 1 );
0 commit comments