Skip to content

Commit 41b5d25

Browse files
committed
Add dedicated parfor injection pass for kernels
1 parent eb0acbb commit 41b5d25

File tree

10 files changed

+217
-428
lines changed

10 files changed

+217
-428
lines changed

numba_dpex/core/descriptor.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ class DpexTargetOptions(CPUTargetOptions):
4848
no_compile = _option_mapping("no_compile")
4949
inline_threshold = _option_mapping("inline_threshold")
5050
_compilation_mode = _option_mapping("_compilation_mode")
51+
# TODO: create separate parfor kernel target
52+
_parfor_body_args = _option_mapping("_parfor_body_args")
5153

5254
def finalize(self, flags, options):
5355
super().finalize(flags, options)
@@ -63,6 +65,7 @@ def finalize(self, flags, options):
6365
_inherit_if_not_set(
6466
flags, options, "_compilation_mode", CompilationMode.KERNEL
6567
)
68+
_inherit_if_not_set(flags, options, "_parfor_body_args", None)
6669

6770

6871
class DpexKernelTarget(TargetDescriptor):

numba_dpex/core/parfors/kernel_builder.py

Lines changed: 25 additions & 223 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,17 @@
2626
from numba.parfors import parfor
2727

2828
from numba_dpex.core import config
29+
from numba_dpex.core.decorators import kernel
30+
from numba_dpex.core.parfors.parfor_sentinel_replace_pass import (
31+
ParforBodyArguments,
32+
)
2933
from numba_dpex.core.types.kernel_api.index_space_ids import ItemType
34+
from numba_dpex.core.utils.call_kernel_builder import SPIRVKernelModule
3035
from numba_dpex.kernel_api_impl.spirv import spirv_generator
36+
from numba_dpex.kernel_api_impl.spirv.dispatcher import (
37+
SPIRVKernelDispatcher,
38+
_SPIRVKernelCompileResult,
39+
)
3140

3241
from ..descriptor import dpex_kernel_target
3342
from ..types import DpnpNdArray
@@ -38,79 +47,19 @@
3847
class ParforKernel:
3948
def __init__(
4049
self,
41-
name,
42-
kernel,
4350
signature,
4451
kernel_args,
4552
kernel_arg_types,
46-
queue: dpctl.SyclQueue,
4753
local_accessors=None,
4854
work_group_size=None,
55+
kernel_module=None,
4956
):
50-
self.name = name
51-
self.kernel = kernel
5257
self.signature = signature
5358
self.kernel_args = kernel_args
5459
self.kernel_arg_types = kernel_arg_types
55-
self.queue = queue
5660
self.local_accessors = local_accessors
5761
self.work_group_size = work_group_size
58-
59-
60-
def _print_block(block):
61-
for i, inst in enumerate(block.body):
62-
print(" ", i, inst)
63-
64-
65-
def _print_body(body_dict):
66-
"""Pretty-print a set of IR blocks."""
67-
for label, block in body_dict.items():
68-
print("label: ", label)
69-
_print_block(block)
70-
71-
72-
def _compile_kernel_parfor(
73-
sycl_queue, kernel_name, func_ir, argtypes, debug=False
74-
):
75-
with target_override(dpex_kernel_target.target_context.target_name):
76-
cres = compile_numba_ir_with_dpex(
77-
pyfunc=func_ir,
78-
pyfunc_name=kernel_name,
79-
args=argtypes,
80-
return_type=None,
81-
debug=debug,
82-
is_kernel=True,
83-
typing_context=dpex_kernel_target.typing_context,
84-
target_context=dpex_kernel_target.target_context,
85-
extra_compile_flags=None,
86-
)
87-
cres.library.inline_threshold = config.INLINE_THRESHOLD
88-
cres.library._optimize_final_module()
89-
func = cres.library.get_function(cres.fndesc.llvm_func_name)
90-
kernel = dpex_kernel_target.target_context.prepare_spir_kernel(
91-
func, cres.signature.args
92-
)
93-
spirv_module = spirv_generator.llvm_to_spirv(
94-
dpex_kernel_target.target_context,
95-
kernel.module.__str__(),
96-
kernel.module.as_bitcode(),
97-
)
98-
99-
dpctl_create_program_from_spirv_flags = []
100-
if debug or config.DPEX_OPT == 0:
101-
# if debug is ON we need to pass additional flags to igc.
102-
dpctl_create_program_from_spirv_flags = ["-g", "-cl-opt-disable"]
103-
104-
# create a sycl::kernel_bundle
105-
kernel_bundle = dpctl_prog.create_program_from_spirv(
106-
sycl_queue,
107-
spirv_module,
108-
" ".join(dpctl_create_program_from_spirv_flags),
109-
)
110-
# create a sycl::kernel
111-
sycl_kernel = kernel_bundle.get_sycl_kernel(kernel.name)
112-
113-
return sycl_kernel
62+
self.kernel_module = kernel_module
11463

11564

11665
def _legalize_names_with_typemap(names, typemap):
@@ -189,76 +138,11 @@ def _replace_var_with_array(vars, loop_body, typemap, calltypes):
189138
typemap[v] = types.npytypes.Array(el_typ, 1, "C")
190139

191140

192-
def _find_setitems_block(setitems, block, typemap):
193-
for inst in block.body:
194-
if isinstance(inst, ir.StaticSetItem) or isinstance(inst, ir.SetItem):
195-
setitems.add(inst.target.name)
196-
elif isinstance(inst, parfor.Parfor):
197-
_find_setitems_block(setitems, inst.init_block, typemap)
198-
_find_setitems_body(setitems, inst.loop_body, typemap)
199-
200-
201-
def _find_setitems_body(setitems, loop_body, typemap):
202-
"""
203-
Find the arrays that are written into (goes into setitems)
204-
"""
205-
for label, block in loop_body.items():
206-
_find_setitems_block(setitems, block, typemap)
207-
208-
209-
def _replace_sentinel_with_parfor_body(kernel_ir, sentinel_name, loop_body):
210-
# new label for splitting sentinel block
211-
new_label = max(loop_body.keys()) + 1
212-
213-
# Search all the block in the kernel function for the sentinel assignment.
214-
for label, block in kernel_ir.blocks.items():
215-
for i, inst in enumerate(block.body):
216-
if (
217-
isinstance(inst, ir.Assign)
218-
and inst.target.name == sentinel_name
219-
):
220-
# We found the sentinel assignment.
221-
loc = inst.loc
222-
scope = block.scope
223-
# split block across __sentinel__
224-
# A new block is allocated for the statements prior to the
225-
# sentinel but the new block maintains the current block label.
226-
prev_block = ir.Block(scope, loc)
227-
prev_block.body = block.body[:i]
228-
229-
# The current block is used for statements after the sentinel.
230-
block.body = block.body[i + 1 :] # noqa: E203
231-
# But the current block gets a new label.
232-
body_first_label = min(loop_body.keys())
233-
234-
# The previous block jumps to the minimum labelled block of the
235-
# parfor body.
236-
prev_block.append(ir.Jump(body_first_label, loc))
237-
# Add all the parfor loop body blocks to the kernel function's
238-
# IR.
239-
for loop, b in loop_body.items():
240-
kernel_ir.blocks[loop] = b
241-
body_last_label = max(loop_body.keys())
242-
kernel_ir.blocks[new_label] = block
243-
kernel_ir.blocks[label] = prev_block
244-
# Add a jump from the last parfor body block to the block
245-
# containing statements after the sentinel.
246-
kernel_ir.blocks[body_last_label].append(
247-
ir.Jump(new_label, loc)
248-
)
249-
break
250-
else:
251-
continue
252-
break
253-
254-
255141
def create_kernel_for_parfor(
256142
lowerer,
257143
parfor_node,
258144
typemap,
259-
flags,
260145
loop_ranges,
261-
has_aliases,
262146
races,
263147
parfor_outputs,
264148
) -> ParforKernel:
@@ -367,120 +251,38 @@ def create_kernel_for_parfor(
367251
loop_ranges=loop_ranges,
368252
param_dict=param_dict,
369253
)
370-
kernel_ir = kernel_template.kernel_ir
371254

372-
if config.DEBUG_ARRAY_OPT:
373-
print("kernel_ir dump ", type(kernel_ir))
374-
kernel_ir.dump()
375-
print("loop_body dump ", type(loop_body))
376-
_print_body(loop_body)
377-
378-
# rename all variables in kernel_ir afresh
379-
var_table = get_name_var_table(kernel_ir.blocks)
380-
new_var_dict = {}
381-
reserved_names = (
382-
[sentinel_name] + list(param_dict.values()) + legal_loop_indices
255+
kernel_dispatcher: SPIRVKernelDispatcher = kernel(
256+
kernel_template.py_func,
257+
_parfor_body_args=ParforBodyArguments(
258+
loop_body=loop_body,
259+
param_dict=param_dict,
260+
legal_loop_indices=legal_loop_indices,
261+
),
383262
)
384-
for name, var in var_table.items():
385-
if not (name in reserved_names):
386-
new_var_dict[name] = mk_unique_var(name)
387-
replace_var_names(kernel_ir.blocks, new_var_dict)
388-
if config.DEBUG_ARRAY_OPT:
389-
print("kernel_ir dump after renaming ")
390-
kernel_ir.dump()
391-
392-
kernel_param_types = param_types
393263

394-
if config.DEBUG_ARRAY_OPT:
395-
print(
396-
"kernel_param_types = ",
397-
type(kernel_param_types),
398-
"\n",
399-
kernel_param_types,
400-
)
401-
402-
kernel_stub_last_label = max(kernel_ir.blocks.keys()) + 1
403-
404-
# Add kernel stub last label to each parfor.loop_body label to prevent
405-
# label conflicts.
406-
loop_body = add_offset_to_labels(loop_body, kernel_stub_last_label)
407-
408-
_replace_sentinel_with_parfor_body(kernel_ir, sentinel_name, loop_body)
409-
410-
if config.DEBUG_ARRAY_OPT:
411-
print("kernel_ir last dump before renaming")
412-
kernel_ir.dump()
413-
414-
kernel_ir.blocks = rename_labels(kernel_ir.blocks)
415-
remove_dels(kernel_ir.blocks)
416-
417-
old_alias = flags.noalias
418-
if not has_aliases:
419-
if config.DEBUG_ARRAY_OPT:
420-
print("No aliases found so adding noalias flag.")
421-
flags.noalias = True
422-
423-
remove_dead(kernel_ir.blocks, kernel_ir.arg_names, kernel_ir, typemap)
424-
425-
if config.DEBUG_ARRAY_OPT:
426-
print("kernel_ir after remove dead")
427-
kernel_ir.dump()
428-
429-
# The first argument to a range kernel is a kernel_api.Item object. The
430-
# ``Item`` object is used by the kernel_api.spirv backend to generate the
264+
# The first argument to a range kernel is a kernel_api.Item object. The
265+
# ``Item`` object is used by the kernel_api.spirv backend to generate the
431266
# correct SPIR-V indexing instructions. Since, the argument is not something
432267
# available originally in the kernel_param_types, we add it at this point to
433268
# make sure the kernel signature matches the actual generated code.
434269
ty_item = ItemType(parfor_dim)
435-
kernel_param_types = (ty_item, *kernel_param_types)
270+
kernel_param_types = (ty_item, *param_types)
436271
kernel_sig = signature(types.none, *kernel_param_types)
437272

438-
if config.DEBUG_ARRAY_OPT:
439-
sys.stdout.flush()
440-
441-
if config.DEBUG_ARRAY_OPT:
442-
print("after DUFunc inline".center(80, "-"))
443-
kernel_ir.dump()
444-
445-
# The ParforLegalizeCFD pass has already ensured that the LHS and RHS
446-
# arrays are on same device. We can take the queue from the first input
447-
# array and use that to compile the kernel.
448-
449-
exec_queue: dpctl.SyclQueue = None
450-
451-
for arg in parfor_args:
452-
obj = typemap[arg]
453-
if isinstance(obj, DpnpNdArray):
454-
filter_string = obj.queue.sycl_device
455-
# FIXME: A better design is required so that we do not have to
456-
# create a queue every time.
457-
exec_queue = dpctl.get_device_cached_queue(filter_string)
458-
459-
if not exec_queue:
460-
raise AssertionError(
461-
"No execution found for parfor. No way to compile the kernel!"
462-
)
463-
464-
sycl_kernel = _compile_kernel_parfor(
465-
exec_queue,
466-
kernel_name,
467-
kernel_ir,
468-
kernel_param_types,
469-
debug=flags.debuginfo,
273+
kcres: _SPIRVKernelCompileResult = kernel_dispatcher.get_compile_result(
274+
types.void(*kernel_param_types) # kernel signature
470275
)
471-
472-
flags.noalias = old_alias
276+
kernel_module: SPIRVKernelModule = kcres.kernel_device_ir_module
473277

474278
if config.DEBUG_ARRAY_OPT:
475279
print("kernel_sig = ", kernel_sig)
476280

477281
return ParforKernel(
478-
name=kernel_name,
479-
kernel=sycl_kernel,
480282
signature=kernel_sig,
481283
kernel_args=parfor_args,
482284
kernel_arg_types=func_arg_types,
483-
queue=exec_queue,
285+
kernel_module=kernel_module,
484286
)
485287

486288

numba_dpex/core/parfors/kernel_templates/kernel_template_iface.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,9 @@ def _generate_kernel_ir(self):
3030
def dump_kernel_string(self):
3131
raise NotImplementedError
3232

33-
@abc.abstractmethod
34-
def dump_kernel_ir(self):
35-
raise NotImplementedError
36-
3733
@property
3834
@abc.abstractmethod
39-
def kernel_ir(self):
35+
def py_func(self):
4036
raise NotImplementedError
4137

4238
@property

numba_dpex/core/parfors/kernel_templates/range_kernel_template.py

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import sys
66

77
import dpnp
8-
from numba.core import compiler
98

109
import numba_dpex as dpex
1110

@@ -51,7 +50,7 @@ def __init__(
5150
self._param_dict = param_dict
5251

5352
self._kernel_txt = self._generate_kernel_stub_as_string()
54-
self._kernel_ir = self._generate_kernel_ir()
53+
self._py_func = self._generate_kernel_ir()
5554

5655
def _generate_kernel_stub_as_string(self):
5756
"""Generates a stub dpex kernel for the parfor as a string.
@@ -109,17 +108,15 @@ def _generate_kernel_ir(self):
109108
globls = {"dpnp": dpnp, "dpex": dpex}
110109
locls = {}
111110
exec(self._kernel_txt, globls, locls)
112-
kernel_fn = locls[self._kernel_name]
113-
114-
return compiler.run_frontend(kernel_fn)
111+
return locls[self._kernel_name]
115112

116113
@property
117-
def kernel_ir(self):
118-
"""Returns the Numba IR generated for a RangeKernelTemplate.
119-
120-
Returns: The Numba functionIR object for the compiled kernel_txt string.
114+
def py_func(self):
115+
"""Returns the Python function generated for a
116+
RangeKernelTemplate.
117+
Returns: The Python function object for the compiled kernel_txt string.
121118
"""
122-
return self._kernel_ir
119+
return self._py_func
123120

124121
@property
125122
def kernel_string(self):
@@ -134,7 +131,3 @@ def dump_kernel_string(self):
134131
"""Helper to print the kernel function string."""
135132
print(self._kernel_txt)
136133
sys.stdout.flush()
137-
138-
def dump_kernel_ir(self):
139-
"""Helper to dump the Numba IR for the RangeKernelTemplate."""
140-
self._kernel_ir.dump()

0 commit comments

Comments
 (0)