Skip to content

Commit 7212c4a

Browse files
committed
fixes
Signed-off-by: Anatoly Myachev <[email protected]>
1 parent 8caeb60 commit 7212c4a

File tree

1 file changed

+32
-29
lines changed

1 file changed

+32
-29
lines changed

third_party/intel/backend/driver.py

Lines changed: 32 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -360,45 +360,48 @@ def format_of(ty):
360360
ptr_info.valid = false;
361361
return ptr_info;
362362
}}
363+
363364
// start sycl
364365
template <class T>
365366
static inline void set_scalar_arg(sycl::handler &cgh, int index, const void *value) {{
366-
cgh.set_arg(index, *static_cast<const T *>(value));
367+
cgh.set_arg(index, *static_cast<const T *>(value));
367368
}}
369+
368370
static void sycl_kernel_launch(uint32_t gridX, uint32_t gridY, uint32_t gridZ, int num_warps, int threads_per_warp, int shared_memory, sycl::queue& stream, sycl::kernel& kernel_ptr {', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
369371
370-
std::string kernel_name = kernel_ptr.get_info<sycl::info::kernel::function_name>();
371-
void *params[] = {{ {', '.join(f"&arg{i}" for i, ty in signature.items() if i not in constants and ty != "none")} }};
372-
uint32_t num_params = sizeof(params)/sizeof(params[0]);
373-
uint32_t expected_num_params = kernel_ptr.get_info<sycl::info::kernel::num_args>();
374-
size_t global_range_x = gridX*threads_per_warp*num_warps;
375-
size_t global_range_y = gridY;
376-
size_t global_range_z = gridZ;
377-
size_t local_range_x = num_warps*threads_per_warp;
378-
size_t local_range_y = 1;
379-
size_t local_range_z = 1;
380-
sycl::range<3> global_range(global_range_z, global_range_y, global_range_x);
381-
sycl::range<3> local_range(local_range_z, local_range_y, local_range_x);
382-
sycl::nd_range<3> parallel_work_size(global_range, local_range);
383-
if (shared_memory) {{
384-
expected_num_params -= 1;
385-
}}
386-
assert(num_params == expected_num_params && "number of kernel param not matched");
387-
// Submit the imported kernel.
388-
auto cgf = [&](sycl::handler &cgh) {{
389-
{" ".join(f'set_scalar_arg<{ty_to_cpp(item)}>(cgh, {idx}, params[{idx}]);' for idx, item in enumerate([signature[i] for i in signature if i not in constants and signature[i] != "none"]))}
372+
std::string kernel_name = kernel_ptr.get_info<sycl::info::kernel::function_name>();
373+
void *params[] = {{ {', '.join(f"&arg{i}" for i, ty in signature.items() if i not in constants and ty != "none")} }};
374+
uint32_t num_params = sizeof(params)/sizeof(params[0]);
375+
uint32_t expected_num_params = kernel_ptr.get_info<sycl::info::kernel::num_args>();
376+
size_t global_range_x = gridX*threads_per_warp*num_warps;
377+
size_t global_range_y = gridY;
378+
size_t global_range_z = gridZ;
379+
size_t local_range_x = num_warps*threads_per_warp;
380+
size_t local_range_y = 1;
381+
size_t local_range_z = 1;
382+
sycl::range<3> global_range(global_range_z, global_range_y, global_range_x);
383+
sycl::range<3> local_range(local_range_z, local_range_y, local_range_x);
384+
sycl::nd_range<3> parallel_work_size(global_range, local_range);
390385
if (shared_memory) {{
391-
using share_mem_t = sycl::local_accessor<int8_t, 1>;
392-
share_mem_t local_buffer = share_mem_t(shared_memory, cgh);
393-
cgh.set_arg(num_params, local_buffer);
394-
cgh.parallel_for(parallel_work_size, kernel_ptr);
395-
}} else {{
396-
cgh.parallel_for(parallel_work_size, kernel_ptr);
386+
expected_num_params -= 1;
397387
}}
398-
}};
399-
auto event = stream.submit(cgf);
388+
assert(num_params == expected_num_params && "number of kernel param not matched");
389+
// Submit the imported kernel.
390+
auto cgf = [&](sycl::handler &cgh) {{
391+
{" ".join(f'set_scalar_arg<{ty_to_cpp(item)}>(cgh, {idx}, params[{idx}]);' for idx, item in enumerate([signature[i] for i in signature if i not in constants and signature[i] != "none"]))}
392+
if (shared_memory) {{
393+
using share_mem_t = sycl::local_accessor<int8_t, 1>;
394+
share_mem_t local_buffer = share_mem_t(shared_memory, cgh);
395+
cgh.set_arg(num_params, local_buffer);
396+
cgh.parallel_for(parallel_work_size, kernel_ptr);
397+
}} else {{
398+
cgh.parallel_for(parallel_work_size, kernel_ptr);
399+
}}
400+
}};
401+
auto event = stream.submit(cgf);
400402
}}
401403
// end sycl
404+
402405
static PyObject* launch(PyObject* self, PyObject* args) {{
403406
404407
int gridX, gridY, gridZ;

0 commit comments

Comments
 (0)