@@ -71,10 +71,10 @@ sycl::getKernelNamesUsingImplicitLocalMem(const Module &M) {
7171 return -1 ;
7272 };
7373 llvm::for_each (M.functions (), [&](const Function &F) {
74- if (F.getCallingConv () == CallingConv::SPIR_KERNEL &&
75- F.hasFnAttribute (WORK_GROUP_STATIC_ATTR)) {
74+ if (F.getCallingConv () == CallingConv::SPIR_KERNEL) {
7675 int ArgPos = GetArgumentPos (F);
77- SPIRKernelNames.emplace_back (F.getName (), ArgPos);
76+ if (ArgPos >= 0 )
77+ SPIRKernelNames.emplace_back (F.getName (), ArgPos);
7878 }
7979 });
8080 }
@@ -184,11 +184,34 @@ lowerDynamicLocalMemCallDirect(CallInst *CI, Triple TT,
184184
185185static void lowerLocalMemCall (Function *LocalMemAllocFunc,
186186 std::function<void (CallInst *CI)> TransformCall) {
187+ static SmallPtrSet<Function *, 16 > FuncsCache;
187188 SmallVector<CallInst *, 4 > DelCalls;
188189 for (User *U : LocalMemAllocFunc->users ()) {
189190 auto *CI = cast<CallInst>(U);
190191 TransformCall (CI);
191192 DelCalls.push_back (CI);
193+ if (!FuncsCache.insert (CI->getFunction ()).second )
194+ continue ; // We have already traversed call graph from this function
195+
196+ SmallVector<Function *, 8 > WorkList;
197+ WorkList.push_back (CI->getFunction ());
198+ while (!WorkList.empty ()) {
199+ auto *F = WorkList.back ();
200+ WorkList.pop_back ();
201+
202+ // Mark kernel as using scratch memory if it isn't marked already
203+ if (F->getCallingConv () == CallingConv::SPIR_KERNEL &&
204+ !F->hasFnAttribute (WORK_GROUP_STATIC_ATTR))
205+ F->addFnAttr (WORK_GROUP_STATIC_ATTR);
206+
207+ for (auto *FU : F->users ()) {
208+ if (auto *UCI = dyn_cast<CallInst>(FU)) {
209+ if (FuncsCache.insert (UCI->getFunction ()).second )
210+ WorkList.push_back (UCI->getFunction ());
211+ } // Even though there could be other uses of a Function, we don't
212+ // care about them because we are only concerned about call graph
213+ }
214+ }
192215 }
193216
194217 for (auto *CI : DelCalls) {
0 commit comments