@@ -6666,6 +6666,34 @@ class FreeFunctionPrinter {
66666666 FD->getTemplateSpecializationArgs ());
66676667 }
66686668
6669+ // / Emits free function kernel info specialization for shimN.
6670+ // / \param ShimCounter The counter for the shim function.
6671+ // / \param KParamsSize The number of kernel free function arguments.
6672+ // / \param KName The name of the kernel free function.
6673+ void printFreeFunctionKernelInfo (const unsigned ShimCounter,
6674+ const size_t KParamsSize,
6675+ std::string_view KName) {
6676+ O << " \n " ;
6677+ O << " namespace sycl {\n " ;
6678+ O << " inline namespace _V1 {\n " ;
6679+ O << " namespace detail {\n " ;
6680+ O << " //Free Function Kernel info specialization for shim" << ShimCounter
6681+ << " \n " ;
6682+ O << " template <> struct FreeFunctionInfoData<__sycl_shim" << ShimCounter
6683+ << " ()> {\n " ;
6684+ O << " __SYCL_DLL_LOCAL\n " ;
6685+ O << " static constexpr unsigned getNumParams() { return " << KParamsSize
6686+ << " ; }\n " ;
6687+ O << " __SYCL_DLL_LOCAL\n " ;
6688+ O << " static constexpr const char *getFunctionName() { return " ;
6689+ O << " \" " << KName << " \" ; }\n " ;
6690+ O << " };\n " ;
6691+ O << " } // namespace detail\n "
6692+ << " } // namespace _V1\n "
6693+ << " } // namespace sycl\n " ;
6694+ O << " \n " ;
6695+ }
6696+
66696697private:
66706698 // / Helper method to get string with template types
66716699 // / \param TAL The template argument list.
@@ -6915,6 +6943,11 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
69156943 O << " \"\" ,\n " ;
69166944 O << " };\n\n " ;
69176945
6946+ O << " static constexpr unsigned kernel_args_sizes[] = {" ;
6947+ for (unsigned I = 0 ; I < KernelDescs.size (); I++) {
6948+ O << KernelDescs[I].Params .size () << " , " ;
6949+ }
6950+ O << " };\n\n " ;
69186951 O << " // array representing signatures of all kernels defined in the\n " ;
69196952 O << " // corresponding source\n " ;
69206953 O << " static constexpr\n " ;
@@ -7127,6 +7160,7 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
71277160 FFPrinter.printFreeFunctionShim (K.SyclKernel , ShimCounter, ParmList);
71287161 O << " ;\n " ;
71297162 O << " }\n " ;
7163+ FFPrinter.printFreeFunctionKernelInfo (ShimCounter, K.Params .size (), K.Name );
71307164 Policy.SuppressDefaultTemplateArgs = true ;
71317165 Policy.EnforceDefaultTemplateArgs = false ;
71327166
@@ -7156,22 +7190,21 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
71567190
71577191 if (FreeFunctionCount > 0 ) {
71587192 O << " \n #include <sycl/kernel_bundle.hpp>\n " ;
7159- }
7160- ShimCounter = 1 ;
7161- for (const KernelDesc &K : KernelDescs) {
7162- if (!S.isFreeFunction (K.SyclKernel ))
7163- continue ;
7164-
7165- O << " \n // Definition of kernel_id of " << K.Name << " \n " ;
7193+ O << " #include <sycl/detail/kernel_global_info.hpp>\n " ;
71667194 O << " namespace sycl {\n " ;
7167- O << " template <>\n " ;
7168- O << " inline kernel_id ext::oneapi::experimental::get_kernel_id<__sycl_shim"
7169- << ShimCounter << " ()>() {\n " ;
7170- O << " return sycl::detail::get_kernel_id_impl(std::string_view{\" "
7171- << K.Name << " \" });\n " ;
7172- O << " }\n " ;
7173- O << " }\n " ;
7174- ++ShimCounter;
7195+ O << " inline namespace _V1 {\n " ;
7196+ O << " namespace detail {\n " ;
7197+ O << " struct GlobalMapUpdater {\n " ;
7198+ O << " GlobalMapUpdater() {\n " ;
7199+ O << " sycl::detail::free_function_info_map::add("
7200+ << " sycl::detail::kernel_names, sycl::detail::kernel_args_sizes, "
7201+ << KernelDescs.size () << " );\n " ;
7202+ O << " }\n " ;
7203+ O << " };\n " ;
7204+ O << " static GlobalMapUpdater updater;\n " ;
7205+ O << " } // namespace detail\n " ;
7206+ O << " } // namespace _V1\n " ;
7207+ O << " } // namespace sycl\n " ;
71757208 }
71767209}
71777210
0 commit comments