Skip to content

Commit b8a3979

Browse files
[SYCL][NFC] Extend getKernelNamesUsingAssert to support general function names (#20488)
Currently, we have getKernelNamesUsingAssert to detect all SPIR kernels which use assert functions via BFS. This PR extend it in 2 points: 1. Support passing a general function name to it for detecting all SPIR kernels using it 2. Support passing a group of special function names for searching SPIR kernels using them. We may need to check other special function and can easily call it instead of adding a new getKernelNamesUsingXXX. --------- Signed-off-by: jinge90 <[email protected]> Co-authored-by: Marcos Maronas <[email protected]>
1 parent bab58c1 commit b8a3979

File tree

1 file changed

+19
-9
lines changed

1 file changed

+19
-9
lines changed

llvm/lib/SYCLPostLink/ComputeModuleRuntimeInfo.cpp

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -66,11 +66,12 @@ bool isModuleUsingTsan(const Module &M) {
6666
// Optional.
6767
// Otherwise, it returns an Optional containing a list of reached
6868
// SPIR kernel function's names.
69-
static std::optional<std::vector<StringRef>>
70-
traverseCGToFindSPIRKernels(const Function *StartingFunction) {
69+
static std::optional<std::vector<StringRef>> traverseCGToFindSPIRKernels(
70+
const std::vector<Function *> &StartingFunctionVec) {
7171
std::queue<const Function *> FunctionsToVisit;
7272
std::unordered_set<const Function *> VisitedFunctions;
73-
FunctionsToVisit.push(StartingFunction);
73+
for (const Function *FPtr : StartingFunctionVec)
74+
FunctionsToVisit.push(FPtr);
7475
std::vector<StringRef> KernelNames;
7576

7677
while (!FunctionsToVisit.empty()) {
@@ -106,13 +107,20 @@ traverseCGToFindSPIRKernels(const Function *StartingFunction) {
106107
return {std::move(KernelNames)};
107108
}
108109

109-
static std::vector<StringRef> getKernelNamesUsingAssert(const Module &M) {
110-
auto *DevicelibAssertFailFunction = M.getFunction("__devicelib_assert_fail");
111-
if (!DevicelibAssertFailFunction)
110+
static std::vector<StringRef>
111+
getKernelNamesUsingSpecialFunctions(const Module &M,
112+
const std::vector<StringRef> &FNames) {
113+
std::vector<Function *> SpecialFunctionVec;
114+
for (const auto Fn : FNames) {
115+
Function *FPtr = M.getFunction(Fn);
116+
if (FPtr)
117+
SpecialFunctionVec.push_back(FPtr);
118+
}
119+
120+
if (SpecialFunctionVec.size() == 0)
112121
return {};
113122

114-
auto TraverseResult =
115-
traverseCGToFindSPIRKernels(DevicelibAssertFailFunction);
123+
auto TraverseResult = traverseCGToFindSPIRKernels(SpecialFunctionVec);
116124

117125
if (TraverseResult.has_value())
118126
return std::move(*TraverseResult);
@@ -442,7 +450,9 @@ PropSetRegTy computeModuleProperties(const Module &M,
442450
PropSet.add(PropSetRegTy::SYCL_MISC_PROP, "optLevel", OptLevel);
443451
}
444452
{
445-
std::vector<StringRef> FuncNames = getKernelNamesUsingAssert(M);
453+
std::vector<StringRef> AssertFuncNames{"__devicelib_assert_fail"};
454+
std::vector<StringRef> FuncNames =
455+
getKernelNamesUsingSpecialFunctions(M, AssertFuncNames);
446456
for (const StringRef &FName : FuncNames)
447457
PropSet.add(PropSetRegTy::SYCL_ASSERT_USED, FName, true);
448458
}

0 commit comments

Comments
 (0)