Skip to content

Commit 464871d

Browse files
[SYCL] Kernel free function num args functionality (#19517)
This PR adds possibility to get kernel free function number of arguments according to [docs](https://github.com/intel/llvm/blob/sycl/sycl/doc/extensions/proposed/sycl_ext_oneapi_free_function_kernels.asciidoc#behavior-with-kernel-bundle-functions-in-the-core-sycl-specification) Tests were added according to test plan.
1 parent b41ef41 commit 464871d

27 files changed

+1332
-311
lines changed

clang/lib/Sema/SemaSYCL.cpp

Lines changed: 48 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6666,6 +6666,34 @@ class FreeFunctionPrinter {
66666666
FD->getTemplateSpecializationArgs());
66676667
}
66686668

6669+
/// Emits free function kernel info specialization for shimN.
6670+
/// \param ShimCounter The counter for the shim function.
6671+
/// \param KParamsSize The number of kernel free function arguments.
6672+
/// \param KName The name of the kernel free function.
6673+
void printFreeFunctionKernelInfo(const unsigned ShimCounter,
6674+
const size_t KParamsSize,
6675+
std::string_view KName) {
6676+
O << "\n";
6677+
O << "namespace sycl {\n";
6678+
O << "inline namespace _V1 {\n";
6679+
O << "namespace detail {\n";
6680+
O << "//Free Function Kernel info specialization for shim" << ShimCounter
6681+
<< "\n";
6682+
O << "template <> struct FreeFunctionInfoData<__sycl_shim" << ShimCounter
6683+
<< "()> {\n";
6684+
O << " __SYCL_DLL_LOCAL\n";
6685+
O << " static constexpr unsigned getNumParams() { return " << KParamsSize
6686+
<< "; }\n";
6687+
O << " __SYCL_DLL_LOCAL\n";
6688+
O << " static constexpr const char *getFunctionName() { return ";
6689+
O << "\"" << KName << "\"; }\n";
6690+
O << "};\n";
6691+
O << "} // namespace detail\n"
6692+
<< "} // namespace _V1\n"
6693+
<< "} // namespace sycl\n";
6694+
O << "\n";
6695+
}
6696+
66696697
private:
66706698
/// Helper method to get string with template types
66716699
/// \param TAL The template argument list.
@@ -6915,6 +6943,11 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
69156943
O << " \"\",\n";
69166944
O << "};\n\n";
69176945

6946+
O << "static constexpr unsigned kernel_args_sizes[] = {";
6947+
for (unsigned I = 0; I < KernelDescs.size(); I++) {
6948+
O << KernelDescs[I].Params.size() << ", ";
6949+
}
6950+
O << "};\n\n";
69186951
O << "// array representing signatures of all kernels defined in the\n";
69196952
O << "// corresponding source\n";
69206953
O << "static constexpr\n";
@@ -7127,6 +7160,7 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
71277160
FFPrinter.printFreeFunctionShim(K.SyclKernel, ShimCounter, ParmList);
71287161
O << ";\n";
71297162
O << "}\n";
7163+
FFPrinter.printFreeFunctionKernelInfo(ShimCounter, K.Params.size(), K.Name);
71307164
Policy.SuppressDefaultTemplateArgs = true;
71317165
Policy.EnforceDefaultTemplateArgs = false;
71327166

@@ -7156,22 +7190,21 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
71567190

71577191
if (FreeFunctionCount > 0) {
71587192
O << "\n#include <sycl/kernel_bundle.hpp>\n";
7159-
}
7160-
ShimCounter = 1;
7161-
for (const KernelDesc &K : KernelDescs) {
7162-
if (!S.isFreeFunction(K.SyclKernel))
7163-
continue;
7164-
7165-
O << "\n// Definition of kernel_id of " << K.Name << "\n";
7193+
O << "#include <sycl/detail/kernel_global_info.hpp>\n";
71667194
O << "namespace sycl {\n";
7167-
O << "template <>\n";
7168-
O << "inline kernel_id ext::oneapi::experimental::get_kernel_id<__sycl_shim"
7169-
<< ShimCounter << "()>() {\n";
7170-
O << " return sycl::detail::get_kernel_id_impl(std::string_view{\""
7171-
<< K.Name << "\"});\n";
7172-
O << "}\n";
7173-
O << "}\n";
7174-
++ShimCounter;
7195+
O << "inline namespace _V1 {\n";
7196+
O << "namespace detail {\n";
7197+
O << "struct GlobalMapUpdater {\n";
7198+
O << " GlobalMapUpdater() {\n";
7199+
O << " sycl::detail::free_function_info_map::add("
7200+
<< "sycl::detail::kernel_names, sycl::detail::kernel_args_sizes, "
7201+
<< KernelDescs.size() << ");\n";
7202+
O << " }\n";
7203+
O << "};\n";
7204+
O << "static GlobalMapUpdater updater;\n";
7205+
O << "} // namespace detail\n";
7206+
O << "} // namespace _V1\n";
7207+
O << "} // namespace sycl\n";
71757208
}
71767209
}
71777210

0 commit comments

Comments
 (0)