1414// ===----------------------------------------------------------------------===//
1515
1616#include " llvm/SYCLLowerIR/CheckNDRangeSYCLNativeCPU.h"
17+ #include " llvm/ADT/PriorityWorklist.h"
18+ #include " llvm/ADT/SmallPtrSet.h"
1719#include " llvm/IR/CallingConv.h"
1820#include " llvm/IR/Constants.h"
1921#include " llvm/IR/DerivedTypes.h"
22+ #include " llvm/IR/Function.h"
2023#include " llvm/IR/InstrTypes.h"
24+ #include " llvm/IR/Instruction.h"
2125#include " llvm/IR/Instructions.h"
2226#include " llvm/IR/Metadata.h"
27+ #include " llvm/SYCLLowerIR/UtilsSYCLNativeCPU.h"
28+ #include " llvm/Support/Casting.h"
2329
2430using namespace llvm ;
2531
26- static std::array<const char *, 13 > NdFunctions {
32+ static std::array<const char *, 13 > NdBuiltins {
2733 " _Z23__spirv_WorkgroupSize_xv" , " _Z23__spirv_WorkgroupSize_yv" ,
2834 " _Z23__spirv_WorkgroupSize_zv" , " _Z23__spirv_NumWorkgroups_xv" ,
2935 " _Z23__spirv_NumWorkgroups_yv" , " _Z23__spirv_NumWorkgroups_zv" ,
@@ -42,6 +48,55 @@ static void addNDRangeMetadata(Function &F, bool Value) {
4248PreservedAnalyses
4349CheckNDRangeSYCLNativeCPUPass::run (Module &M, ModuleAnalysisManager &MAM) {
4450 bool ModuleChanged = false ;
51+ SmallPtrSet<Function *, 5 > NdFuncs; // Functions that use NDRange features
52+ SmallPtrSet<Function *, 5 > Visited;
53+ SmallPriorityWorklist<Function *, 5 > WorkList;
54+
55+ // Add builtins to the set of functions that may use NDRange features
56+ for (auto &FName : NdBuiltins) {
57+ auto F = M.getFunction (FName);
58+ if (F == nullptr )
59+ continue ;
60+ WorkList.insert (F);
61+ NdFuncs.insert (F);
62+ }
63+
64+ // Add users of local AS global var to the set of functions that may use
65+ // NDRange features
66+ for (auto &GV : M.globals ()) {
67+ if (GV.getAddressSpace () != sycl::utils::SyclNativeCpuLocalAS)
68+ continue ;
69+
70+ for (auto U : GV.users ()) {
71+ if (auto I = dyn_cast<Instruction>(U)) {
72+ auto F = I->getFunction ();
73+ if (F != nullptr && NdFuncs.insert (F).second ) {
74+ WorkList.insert (F);
75+ NdFuncs.insert (F);
76+ }
77+ }
78+ }
79+ }
80+
81+ // Traverse the use chain to find Functions that may use NDRange features
82+ // (or, recursively, Functions that call Functions that may use NDRange
83+ // features)
84+ while (!WorkList.empty ()) {
85+ auto F = WorkList.pop_back_val ();
86+
87+ for (User *U : F->users ()) {
88+ if (auto CI = dyn_cast<CallInst>(U)) {
89+ auto Caller = CI->getFunction ();
90+ if (!Caller)
91+ continue ;
92+ if (!Visited.contains (Caller)) {
93+ WorkList.insert (Caller);
94+ NdFuncs.insert (Caller);
95+ }
96+ }
97+ }
98+ Visited.insert (F);
99+ }
45100
46101 for (auto &F : M) {
47102 if (F.getCallingConv () == llvm::CallingConv::SPIR_KERNEL) {
@@ -55,23 +110,11 @@ CheckNDRangeSYCLNativeCPUPass::run(Module &M, ModuleAnalysisManager &MAM) {
55110 }
56111 }
57112
58- for (auto &BB : F) {
59- for (auto &I : BB) {
60- if (auto CI = dyn_cast<CallInst>(&I)) {
61- auto CalleeName = CI->getCalledFunction ()->getName ();
62- if (std::find (NdFunctions.begin (), NdFunctions.end (), CalleeName) !=
63- NdFunctions.end ()) {
64- IsNDRange = true ;
65- break ;
66- }
67- }
68- }
69- if (IsNDRange) {
70- break ;
71- }
72- }
113+ // Check if the kernel calls one of the ND Range builtins
114+ IsNDRange |= NdFuncs.contains (&F);
73115
74116 addNDRangeMetadata (F, IsNDRange);
117+ ModuleChanged = true ;
75118 }
76119 }
77120 return ModuleChanged ? PreservedAnalyses::none () : PreservedAnalyses::all ();
0 commit comments