@@ -360,45 +360,48 @@ def format_of(ty):
360360 ptr_info.valid = false;
361361 return ptr_info;
362362}}
363+
363364// start sycl
364365template <class T>
365366static inline void set_scalar_arg(sycl::handler &cgh, int index, const void *value) {{
366- cgh.set_arg(index, *static_cast<const T *>(value));
367+ cgh.set_arg(index, *static_cast<const T *>(value));
367368}}
369+
368370static void sycl_kernel_launch(uint32_t gridX, uint32_t gridY, uint32_t gridZ, int num_warps, int threads_per_warp, int shared_memory, sycl::queue& stream, sycl::kernel& kernel_ptr { ', ' + arg_decls if len (arg_decls ) > 0 else '' } ) {{
369371
370- std::string kernel_name = kernel_ptr.get_info<sycl::info::kernel::function_name>();
371- void *params[] = {{ { ', ' .join (f"&arg{ i } " for i , ty in signature .items () if i not in constants and ty != "none" )} }};
372- uint32_t num_params = sizeof(params)/sizeof(params[0]);
373- uint32_t expected_num_params = kernel_ptr.get_info<sycl::info::kernel::num_args>();
374- size_t global_range_x = gridX*threads_per_warp*num_warps;
375- size_t global_range_y = gridY;
376- size_t global_range_z = gridZ;
377- size_t local_range_x = num_warps*threads_per_warp;
378- size_t local_range_y = 1;
379- size_t local_range_z = 1;
380- sycl::range<3> global_range(global_range_z, global_range_y, global_range_x);
381- sycl::range<3> local_range(local_range_z, local_range_y, local_range_x);
382- sycl::nd_range<3> parallel_work_size(global_range, local_range);
383- if (shared_memory) {{
384- expected_num_params -= 1;
385- }}
386- assert(num_params == expected_num_params && "number of kernel param not matched");
387- // Submit the imported kernel.
388- auto cgf = [&](sycl::handler &cgh) {{
389- { " " .join (f'set_scalar_arg<{ ty_to_cpp (item )} >(cgh, { idx } , params[{ idx } ]);' for idx , item in enumerate ([signature [i ] for i in signature if i not in constants and signature [i ] != "none" ]))}
372+ std::string kernel_name = kernel_ptr.get_info<sycl::info::kernel::function_name>();
373+ void *params[] = {{ { ', ' .join (f"&arg{ i } " for i , ty in signature .items () if i not in constants and ty != "none" )} }};
374+ uint32_t num_params = sizeof(params)/sizeof(params[0]);
375+ uint32_t expected_num_params = kernel_ptr.get_info<sycl::info::kernel::num_args>();
376+ size_t global_range_x = gridX*threads_per_warp*num_warps;
377+ size_t global_range_y = gridY;
378+ size_t global_range_z = gridZ;
379+ size_t local_range_x = num_warps*threads_per_warp;
380+ size_t local_range_y = 1;
381+ size_t local_range_z = 1;
382+ sycl::range<3> global_range(global_range_z, global_range_y, global_range_x);
383+ sycl::range<3> local_range(local_range_z, local_range_y, local_range_x);
384+ sycl::nd_range<3> parallel_work_size(global_range, local_range);
390385 if (shared_memory) {{
391- using share_mem_t = sycl::local_accessor<int8_t, 1>;
392- share_mem_t local_buffer = share_mem_t(shared_memory, cgh);
393- cgh.set_arg(num_params, local_buffer);
394- cgh.parallel_for(parallel_work_size, kernel_ptr);
395- }} else {{
396- cgh.parallel_for(parallel_work_size, kernel_ptr);
386+ expected_num_params -= 1;
397387 }}
398- }};
399- auto event = stream.submit(cgf);
388+ assert(num_params == expected_num_params && "number of kernel param not matched");
389+ // Submit the imported kernel.
390+ auto cgf = [&](sycl::handler &cgh) {{
391+ { " " .join (f'set_scalar_arg<{ ty_to_cpp (item )} >(cgh, { idx } , params[{ idx } ]);' for idx , item in enumerate ([signature [i ] for i in signature if i not in constants and signature [i ] != "none" ]))}
392+ if (shared_memory) {{
393+ using share_mem_t = sycl::local_accessor<int8_t, 1>;
394+ share_mem_t local_buffer = share_mem_t(shared_memory, cgh);
395+ cgh.set_arg(num_params, local_buffer);
396+ cgh.parallel_for(parallel_work_size, kernel_ptr);
397+ }} else {{
398+ cgh.parallel_for(parallel_work_size, kernel_ptr);
399+ }}
400+ }};
401+ auto event = stream.submit(cgf);
400402}}
401403// end sycl
404+
402405static PyObject* launch(PyObject* self, PyObject* args) {{
403406
404407 int gridX, gridY, gridZ;
0 commit comments