Skip to content

Commit 2e4e4cb

Browse files
authored
Merge pull request #1416 from IntelPython/port/parfor_kernel_templates_to_new_API
Port the parfor range kernel template to new API.
2 parents 5ff654f + 09a1cca commit 2e4e4cb

File tree

3 files changed

+27
-14
lines changed

3 files changed

+27
-14
lines changed

numba_dpex/core/descriptor.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,10 @@ def typing_context(self):
9696
"""
9797
return self._toplevel_typing_context
9898

99+
@property
100+
def target_name(self):
101+
return self._target_name
102+
99103

100104
class DpexTarget(TargetDescriptor):
101105
"""

numba_dpex/core/parfors/kernel_builder.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,12 @@
2121
rename_labels,
2222
replace_var_names,
2323
)
24+
from numba.core.target_extension import target_override
2425
from numba.core.typing import signature
2526
from numba.parfors import parfor
2627

2728
from numba_dpex.core import config
29+
from numba_dpex.core.types.kernel_api.index_space_ids import ItemType
2830
from numba_dpex.kernel_api_impl.spirv import spirv_generator
2931

3032
from ..descriptor import dpex_kernel_target
@@ -66,18 +68,18 @@ def _print_body(body_dict):
6668
def _compile_kernel_parfor(
6769
sycl_queue, kernel_name, func_ir, argtypes, debug=False
6870
):
69-
70-
cres = compile_numba_ir_with_dpex(
71-
pyfunc=func_ir,
72-
pyfunc_name=kernel_name,
73-
args=argtypes,
74-
return_type=None,
75-
debug=debug,
76-
is_kernel=True,
77-
typing_context=dpex_kernel_target.typing_context,
78-
target_context=dpex_kernel_target.target_context,
79-
extra_compile_flags=None,
80-
)
71+
with target_override(dpex_kernel_target.target_context.target_name):
72+
cres = compile_numba_ir_with_dpex(
73+
pyfunc=func_ir,
74+
pyfunc_name=kernel_name,
75+
args=argtypes,
76+
return_type=None,
77+
debug=debug,
78+
is_kernel=True,
79+
typing_context=dpex_kernel_target.typing_context,
80+
target_context=dpex_kernel_target.target_context,
81+
extra_compile_flags=None,
82+
)
8183
cres.library.inline_threshold = config.INLINE_THRESHOLD
8284
cres.library._optimize_final_module()
8385
func = cres.library.get_function(cres.fndesc.llvm_func_name)
@@ -420,6 +422,13 @@ def create_kernel_for_parfor(
420422
print("kernel_ir after remove dead")
421423
kernel_ir.dump()
422424

425+
# The first argument to a range kernel is a kernel_api.Item object. The
426+
# ``Item`` object is used by the kernel_api.spirv backend to generate the
427+
# correct SPIR-V indexing instructions. Since, the argument is not something
428+
# available originally in the kernel_param_types, we add it at this point to
429+
# make sure the kernel signature matches the actual generated code.
430+
ty_item = ItemType(parfor_dim)
431+
kernel_param_types = (ty_item, *kernel_param_types)
423432
kernel_sig = signature(types.none, *kernel_param_types)
424433

425434
if config.DEBUG_ARRAY_OPT:

numba_dpex/core/parfors/kernel_templates/range_kernel_template.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,15 +63,15 @@ def _generate_kernel_stub_as_string(self):
6363

6464
# Create the dpex kernel function.
6565
kernel_txt += "def " + self._kernel_name
66-
kernel_txt += "(" + (", ".join(self._kernel_params)) + "):\n"
66+
kernel_txt += "(item, " + (", ".join(self._kernel_params)) + "):\n"
6767
global_id_dim = 0
6868
for_loop_dim = self._kernel_rank
6969
global_id_dim = self._kernel_rank
7070

7171
for dim in range(global_id_dim):
7272
dimstr = str(dim)
7373
kernel_txt += (
74-
f" {self._ivar_names[dim]} = dpex.get_global_id({dimstr})\n"
74+
f" {self._ivar_names[dim]} = item.get_id({dimstr})\n"
7575
)
7676

7777
for dim in range(global_id_dim, for_loop_dim):

0 commit comments

Comments
 (0)