@@ -301,8 +301,12 @@ PreservedAnalyses PrepareSYCLNativeCPUPass::run(Module &M,
301301
302302 llvm::Constant *CurrentStatePointerTLS = nullptr ;
303303
304+ // Contains the used builtins and kernels that need to be processed to
305+ // receive a pointer to the state struct.
306+ llvm::SmallVector<std::pair<llvm::Function *, StringRef>>
307+ UsedBuiltinsAndKernels;
308+
304309 // Then we iterate over all the supported builtins, find the used ones
305- llvm::SmallVector<std::pair<llvm::Function *, StringRef>> UsedBuiltins;
306310 for (const auto &Entry : BuiltinNamesMap) {
307311 auto *Glob = M.getFunction (Entry.first );
308312 if (!Glob)
@@ -331,7 +335,7 @@ PreservedAnalyses PrepareSYCLNativeCPUPass::run(Module &M,
331335 }
332336 }
333337 }
334- UsedBuiltins .push_back ({Glob, Entry.second });
338+ UsedBuiltinsAndKernels .push_back ({Glob, Entry.second });
335339 }
336340
337341#ifdef NATIVECPU_USE_OCK
@@ -395,9 +399,10 @@ PreservedAnalyses PrepareSYCLNativeCPUPass::run(Module &M,
395399 OldF->eraseFromParent ();
396400 NewKernels.push_back (NewF);
397401 if (!CurrentStatePointerTLS && NewF->getNumUses () > 0 )
398- // If a thread_local is not used we process called kernels along
399- // with the other builtins.
400- UsedBuiltins.push_back ({NewF, " " });
402+ // If a thread_local is not used we need to keep track of the called
403+ // kernel so we can update its call sites with the pointer to the state
404+ // struct like we do for the called builtins.
405+ UsedBuiltinsAndKernels.push_back ({NewF, " " });
401406 ModuleChanged = true ;
402407 }
403408
@@ -410,7 +415,9 @@ PreservedAnalyses PrepareSYCLNativeCPUPass::run(Module &M,
410415
411416 // Then we iterate over all used builtins and
412417 // replace them with calls to our Native CPU functions.
413- for (const auto &Entry : UsedBuiltins) {
418+ // For the used kernels we need to replace calls to them
419+ // with calls receiving the state pointer argument.
420+ for (const auto &Entry : UsedBuiltinsAndKernels) {
414421 SmallVector<std::pair<Instruction *, Instruction *>> ToRemove;
415422 SmallVector<Function *> ToRemove2;
416423 Function *const Glob = Entry.first ;
0 commit comments