|
21 | 21 | rename_labels,
|
22 | 22 | replace_var_names,
|
23 | 23 | )
|
| 24 | +from numba.core.target_extension import target_override |
24 | 25 | from numba.core.typing import signature
|
25 | 26 | from numba.parfors import parfor
|
26 | 27 |
|
27 | 28 | from numba_dpex.core import config
|
| 29 | +from numba_dpex.core.types.kernel_api.index_space_ids import ItemType |
28 | 30 | from numba_dpex.kernel_api_impl.spirv import spirv_generator
|
29 | 31 |
|
30 | 32 | from ..descriptor import dpex_kernel_target
|
@@ -66,18 +68,18 @@ def _print_body(body_dict):
|
66 | 68 | def _compile_kernel_parfor(
|
67 | 69 | sycl_queue, kernel_name, func_ir, argtypes, debug=False
|
68 | 70 | ):
|
69 |
| - |
70 |
| - cres = compile_numba_ir_with_dpex( |
71 |
| - pyfunc=func_ir, |
72 |
| - pyfunc_name=kernel_name, |
73 |
| - args=argtypes, |
74 |
| - return_type=None, |
75 |
| - debug=debug, |
76 |
| - is_kernel=True, |
77 |
| - typing_context=dpex_kernel_target.typing_context, |
78 |
| - target_context=dpex_kernel_target.target_context, |
79 |
| - extra_compile_flags=None, |
80 |
| - ) |
| 71 | + with target_override(dpex_kernel_target.target_context.target_name): |
| 72 | + cres = compile_numba_ir_with_dpex( |
| 73 | + pyfunc=func_ir, |
| 74 | + pyfunc_name=kernel_name, |
| 75 | + args=argtypes, |
| 76 | + return_type=None, |
| 77 | + debug=debug, |
| 78 | + is_kernel=True, |
| 79 | + typing_context=dpex_kernel_target.typing_context, |
| 80 | + target_context=dpex_kernel_target.target_context, |
| 81 | + extra_compile_flags=None, |
| 82 | + ) |
81 | 83 | cres.library.inline_threshold = config.INLINE_THRESHOLD
|
82 | 84 | cres.library._optimize_final_module()
|
83 | 85 | func = cres.library.get_function(cres.fndesc.llvm_func_name)
|
@@ -420,6 +422,13 @@ def create_kernel_for_parfor(
|
420 | 422 | print("kernel_ir after remove dead")
|
421 | 423 | kernel_ir.dump()
|
422 | 424 |
|
| 425 | + # The first argument to a range kernel is a kernel_api.Item object. The |
| 426 | + # ``Item`` object is used by the kernel_api.spirv backend to generate the |
| 427 | + # correct SPIR-V indexing instructions. Since, the argument is not something |
| 428 | + # available originally in the kernel_param_types, we add it at this point to |
| 429 | + # make sure the kernel signature matches the actual generated code. |
| 430 | + ty_item = ItemType(parfor_dim) |
| 431 | + kernel_param_types = (ty_item, *kernel_param_types) |
423 | 432 | kernel_sig = signature(types.none, *kernel_param_types)
|
424 | 433 |
|
425 | 434 | if config.DEBUG_ARRAY_OPT:
|
|
0 commit comments