Skip to content

Commit 668a2f8

Browse files
committed
Recognize free function kernels that use local work group memory and mark them with the appropriate attribute
1 parent b05f808 commit 668a2f8

File tree

1 file changed

+26
-3
lines changed

1 file changed

+26
-3
lines changed

llvm/lib/SYCLLowerIR/LowerWGLocalMemory.cpp

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,10 +71,10 @@ sycl::getKernelNamesUsingImplicitLocalMem(const Module &M) {
7171
return -1;
7272
};
7373
llvm::for_each(M.functions(), [&](const Function &F) {
74-
if (F.getCallingConv() == CallingConv::SPIR_KERNEL &&
75-
F.hasFnAttribute(WORK_GROUP_STATIC_ATTR)) {
74+
if (F.getCallingConv() == CallingConv::SPIR_KERNEL) {
7675
int ArgPos = GetArgumentPos(F);
77-
SPIRKernelNames.emplace_back(F.getName(), ArgPos);
76+
if (ArgPos >= 0)
77+
SPIRKernelNames.emplace_back(F.getName(), ArgPos);
7878
}
7979
});
8080
}
@@ -184,11 +184,34 @@ lowerDynamicLocalMemCallDirect(CallInst *CI, Triple TT,
184184

185185
static void lowerLocalMemCall(Function *LocalMemAllocFunc,
186186
std::function<void(CallInst *CI)> TransformCall) {
187+
static SmallPtrSet<Function *, 16> FuncsCache;
187188
SmallVector<CallInst *, 4> DelCalls;
188189
for (User *U : LocalMemAllocFunc->users()) {
189190
auto *CI = cast<CallInst>(U);
190191
TransformCall(CI);
191192
DelCalls.push_back(CI);
193+
if (!FuncsCache.insert(CI->getFunction()).second)
194+
continue; // We have already traversed call graph from this function
195+
196+
SmallVector<Function *, 8> WorkList;
197+
WorkList.push_back(CI->getFunction());
198+
while (!WorkList.empty()) {
199+
auto *F = WorkList.back();
200+
WorkList.pop_back();
201+
202+
// Mark kernel as using scrach memory if it isn't marked already
203+
if (F->getCallingConv() == CallingConv::SPIR_KERNEL &&
204+
!F->hasFnAttribute(WORK_GROUP_STATIC_ATTR))
205+
F->addFnAttr(WORK_GROUP_STATIC_ATTR);
206+
207+
for (auto *FU : F->users()) {
208+
if (auto *UCI = dyn_cast<CallInst>(FU)) {
209+
if (FuncsCache.insert(UCI->getFunction()).second)
210+
WorkList.push_back(UCI->getFunction());
211+
} // Even though there could be other uses of a Function, we don't
212+
// care about them because we are only concerned about call graph
213+
}
214+
}
192215
}
193216

194217
for (auto *CI : DelCalls) {

0 commit comments

Comments
 (0)