@@ -6666,6 +6666,34 @@ class FreeFunctionPrinter {
6666
6666
FD->getTemplateSpecializationArgs ());
6667
6667
}
6668
6668
6669
+ // / Emits free function kernel info specialization for shimN.
6670
+ // / \param ShimCounter The counter for the shim function.
6671
+ // / \param KParamsSize The number of kernel free function arguments.
6672
+ // / \param KName The name of the kernel free function.
6673
+ void printFreeFunctionKernelInfo (const unsigned ShimCounter,
6674
+ const size_t KParamsSize,
6675
+ std::string_view KName) {
6676
+ O << " \n " ;
6677
+ O << " namespace sycl {\n " ;
6678
+ O << " inline namespace _V1 {\n " ;
6679
+ O << " namespace detail {\n " ;
6680
+ O << " //Free Function Kernel info specialization for shim" << ShimCounter
6681
+ << " \n " ;
6682
+ O << " template <> struct FreeFunctionInfoData<__sycl_shim" << ShimCounter
6683
+ << " ()> {\n " ;
6684
+ O << " __SYCL_DLL_LOCAL\n " ;
6685
+ O << " static constexpr unsigned getNumParams() { return " << KParamsSize
6686
+ << " ; }\n " ;
6687
+ O << " __SYCL_DLL_LOCAL\n " ;
6688
+ O << " static constexpr const char *getFunctionName() { return " ;
6689
+ O << " \" " << KName << " \" ; }\n " ;
6690
+ O << " };\n " ;
6691
+ O << " } // namespace detail\n "
6692
+ << " } // namespace _V1\n "
6693
+ << " } // namespace sycl\n " ;
6694
+ O << " \n " ;
6695
+ }
6696
+
6669
6697
private:
6670
6698
// / Helper method to get string with template types
6671
6699
// / \param TAL The template argument list.
@@ -6915,6 +6943,11 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
6915
6943
O << " \"\" ,\n " ;
6916
6944
O << " };\n\n " ;
6917
6945
6946
+ O << " static constexpr unsigned kernel_args_sizes[] = {" ;
6947
+ for (unsigned I = 0 ; I < KernelDescs.size (); I++) {
6948
+ O << KernelDescs[I].Params .size () << " , " ;
6949
+ }
6950
+ O << " };\n\n " ;
6918
6951
O << " // array representing signatures of all kernels defined in the\n " ;
6919
6952
O << " // corresponding source\n " ;
6920
6953
O << " static constexpr\n " ;
@@ -7127,6 +7160,7 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
7127
7160
FFPrinter.printFreeFunctionShim (K.SyclKernel , ShimCounter, ParmList);
7128
7161
O << " ;\n " ;
7129
7162
O << " }\n " ;
7163
+ FFPrinter.printFreeFunctionKernelInfo (ShimCounter, K.Params .size (), K.Name );
7130
7164
Policy.SuppressDefaultTemplateArgs = true ;
7131
7165
Policy.EnforceDefaultTemplateArgs = false ;
7132
7166
@@ -7156,22 +7190,21 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
7156
7190
7157
7191
if (FreeFunctionCount > 0 ) {
7158
7192
O << " \n #include <sycl/kernel_bundle.hpp>\n " ;
7159
- }
7160
- ShimCounter = 1 ;
7161
- for (const KernelDesc &K : KernelDescs) {
7162
- if (!S.isFreeFunction (K.SyclKernel ))
7163
- continue ;
7164
-
7165
- O << " \n // Definition of kernel_id of " << K.Name << " \n " ;
7193
+ O << " #include <sycl/detail/kernel_global_info.hpp>\n " ;
7166
7194
O << " namespace sycl {\n " ;
7167
- O << " template <>\n " ;
7168
- O << " inline kernel_id ext::oneapi::experimental::get_kernel_id<__sycl_shim"
7169
- << ShimCounter << " ()>() {\n " ;
7170
- O << " return sycl::detail::get_kernel_id_impl(std::string_view{\" "
7171
- << K.Name << " \" });\n " ;
7172
- O << " }\n " ;
7173
- O << " }\n " ;
7174
- ++ShimCounter;
7195
+ O << " inline namespace _V1 {\n " ;
7196
+ O << " namespace detail {\n " ;
7197
+ O << " struct GlobalMapUpdater {\n " ;
7198
+ O << " GlobalMapUpdater() {\n " ;
7199
+ O << " sycl::detail::free_function_info_map::add("
7200
+ << " sycl::detail::kernel_names, sycl::detail::kernel_args_sizes, "
7201
+ << KernelDescs.size () << " );\n " ;
7202
+ O << " }\n " ;
7203
+ O << " };\n " ;
7204
+ O << " static GlobalMapUpdater updater;\n " ;
7205
+ O << " } // namespace detail\n " ;
7206
+ O << " } // namespace _V1\n " ;
7207
+ O << " } // namespace sycl\n " ;
7175
7208
}
7176
7209
}
7177
7210
0 commit comments