|
26 | 26 | from numba.parfors import parfor
|
27 | 27 |
|
28 | 28 | from numba_dpex.core import config
|
| 29 | +from numba_dpex.core.decorators import kernel |
| 30 | +from numba_dpex.core.parfors.parfor_sentinel_replace_pass import ( |
| 31 | + ParforBodyArguments, |
| 32 | +) |
29 | 33 | from numba_dpex.core.types.kernel_api.index_space_ids import ItemType
|
| 34 | +from numba_dpex.core.utils.call_kernel_builder import SPIRVKernelModule |
30 | 35 | from numba_dpex.kernel_api_impl.spirv import spirv_generator
|
| 36 | +from numba_dpex.kernel_api_impl.spirv.dispatcher import ( |
| 37 | + SPIRVKernelDispatcher, |
| 38 | + _SPIRVKernelCompileResult, |
| 39 | +) |
31 | 40 |
|
32 | 41 | from ..descriptor import dpex_kernel_target
|
33 | 42 | from ..types import DpnpNdArray
|
|
38 | 47 | class ParforKernel:
|
39 | 48 | def __init__(
|
40 | 49 | self,
|
41 |
| - name, |
42 |
| - kernel, |
43 | 50 | signature,
|
44 | 51 | kernel_args,
|
45 | 52 | kernel_arg_types,
|
46 |
| - queue: dpctl.SyclQueue, |
47 | 53 | local_accessors=None,
|
48 | 54 | work_group_size=None,
|
| 55 | + kernel_module=None, |
49 | 56 | ):
|
50 |
| - self.name = name |
51 |
| - self.kernel = kernel |
52 | 57 | self.signature = signature
|
53 | 58 | self.kernel_args = kernel_args
|
54 | 59 | self.kernel_arg_types = kernel_arg_types
|
55 |
| - self.queue = queue |
56 | 60 | self.local_accessors = local_accessors
|
57 | 61 | self.work_group_size = work_group_size
|
58 |
| - |
59 |
| - |
60 |
| -def _print_block(block): |
61 |
| - for i, inst in enumerate(block.body): |
62 |
| - print(" ", i, inst) |
63 |
| - |
64 |
| - |
65 |
| -def _print_body(body_dict): |
66 |
| - """Pretty-print a set of IR blocks.""" |
67 |
| - for label, block in body_dict.items(): |
68 |
| - print("label: ", label) |
69 |
| - _print_block(block) |
70 |
| - |
71 |
| - |
72 |
| -def _compile_kernel_parfor( |
73 |
| - sycl_queue, kernel_name, func_ir, argtypes, debug=False |
74 |
| -): |
75 |
| - with target_override(dpex_kernel_target.target_context.target_name): |
76 |
| - cres = compile_numba_ir_with_dpex( |
77 |
| - pyfunc=func_ir, |
78 |
| - pyfunc_name=kernel_name, |
79 |
| - args=argtypes, |
80 |
| - return_type=None, |
81 |
| - debug=debug, |
82 |
| - is_kernel=True, |
83 |
| - typing_context=dpex_kernel_target.typing_context, |
84 |
| - target_context=dpex_kernel_target.target_context, |
85 |
| - extra_compile_flags=None, |
86 |
| - ) |
87 |
| - cres.library.inline_threshold = config.INLINE_THRESHOLD |
88 |
| - cres.library._optimize_final_module() |
89 |
| - func = cres.library.get_function(cres.fndesc.llvm_func_name) |
90 |
| - kernel = dpex_kernel_target.target_context.prepare_spir_kernel( |
91 |
| - func, cres.signature.args |
92 |
| - ) |
93 |
| - spirv_module = spirv_generator.llvm_to_spirv( |
94 |
| - dpex_kernel_target.target_context, |
95 |
| - kernel.module.__str__(), |
96 |
| - kernel.module.as_bitcode(), |
97 |
| - ) |
98 |
| - |
99 |
| - dpctl_create_program_from_spirv_flags = [] |
100 |
| - if debug or config.DPEX_OPT == 0: |
101 |
| - # if debug is ON we need to pass additional flags to igc. |
102 |
| - dpctl_create_program_from_spirv_flags = ["-g", "-cl-opt-disable"] |
103 |
| - |
104 |
| - # create a sycl::kernel_bundle |
105 |
| - kernel_bundle = dpctl_prog.create_program_from_spirv( |
106 |
| - sycl_queue, |
107 |
| - spirv_module, |
108 |
| - " ".join(dpctl_create_program_from_spirv_flags), |
109 |
| - ) |
110 |
| - # create a sycl::kernel |
111 |
| - sycl_kernel = kernel_bundle.get_sycl_kernel(kernel.name) |
112 |
| - |
113 |
| - return sycl_kernel |
| 62 | + self.kernel_module = kernel_module |
114 | 63 |
|
115 | 64 |
|
116 | 65 | def _legalize_names_with_typemap(names, typemap):
|
@@ -189,76 +138,11 @@ def _replace_var_with_array(vars, loop_body, typemap, calltypes):
|
189 | 138 | typemap[v] = types.npytypes.Array(el_typ, 1, "C")
|
190 | 139 |
|
191 | 140 |
|
192 |
| -def _find_setitems_block(setitems, block, typemap): |
193 |
| - for inst in block.body: |
194 |
| - if isinstance(inst, ir.StaticSetItem) or isinstance(inst, ir.SetItem): |
195 |
| - setitems.add(inst.target.name) |
196 |
| - elif isinstance(inst, parfor.Parfor): |
197 |
| - _find_setitems_block(setitems, inst.init_block, typemap) |
198 |
| - _find_setitems_body(setitems, inst.loop_body, typemap) |
199 |
| - |
200 |
| - |
201 |
| -def _find_setitems_body(setitems, loop_body, typemap): |
202 |
| - """ |
203 |
| - Find the arrays that are written into (goes into setitems) |
204 |
| - """ |
205 |
| - for label, block in loop_body.items(): |
206 |
| - _find_setitems_block(setitems, block, typemap) |
207 |
| - |
208 |
| - |
209 |
| -def _replace_sentinel_with_parfor_body(kernel_ir, sentinel_name, loop_body): |
210 |
| - # new label for splitting sentinel block |
211 |
| - new_label = max(loop_body.keys()) + 1 |
212 |
| - |
213 |
| - # Search all the blocks in the kernel function for the sentinel assignment. |
214 |
| - for label, block in kernel_ir.blocks.items(): |
215 |
| - for i, inst in enumerate(block.body): |
216 |
| - if ( |
217 |
| - isinstance(inst, ir.Assign) |
218 |
| - and inst.target.name == sentinel_name |
219 |
| - ): |
220 |
| - # We found the sentinel assignment. |
221 |
| - loc = inst.loc |
222 |
| - scope = block.scope |
223 |
| - # split block across __sentinel__ |
224 |
| - # A new block is allocated for the statements prior to the |
225 |
| - # sentinel but the new block maintains the current block label. |
226 |
| - prev_block = ir.Block(scope, loc) |
227 |
| - prev_block.body = block.body[:i] |
228 |
| - |
229 |
| - # The current block is used for statements after the sentinel. |
230 |
| - block.body = block.body[i + 1 :] # noqa: E203 |
231 |
| - # But the current block gets a new label. |
232 |
| - body_first_label = min(loop_body.keys()) |
233 |
| - |
234 |
| - # The previous block jumps to the minimum labelled block of the |
235 |
| - # parfor body. |
236 |
| - prev_block.append(ir.Jump(body_first_label, loc)) |
237 |
| - # Add all the parfor loop body blocks to the kernel function's |
238 |
| - # IR. |
239 |
| - for loop, b in loop_body.items(): |
240 |
| - kernel_ir.blocks[loop] = b |
241 |
| - body_last_label = max(loop_body.keys()) |
242 |
| - kernel_ir.blocks[new_label] = block |
243 |
| - kernel_ir.blocks[label] = prev_block |
244 |
| - # Add a jump from the last parfor body block to the block |
245 |
| - # containing statements after the sentinel. |
246 |
| - kernel_ir.blocks[body_last_label].append( |
247 |
| - ir.Jump(new_label, loc) |
248 |
| - ) |
249 |
| - break |
250 |
| - else: |
251 |
| - continue |
252 |
| - break |
253 |
| - |
254 |
| - |
255 | 141 | def create_kernel_for_parfor(
|
256 | 142 | lowerer,
|
257 | 143 | parfor_node,
|
258 | 144 | typemap,
|
259 |
| - flags, |
260 | 145 | loop_ranges,
|
261 |
| - has_aliases, |
262 | 146 | races,
|
263 | 147 | parfor_outputs,
|
264 | 148 | ) -> ParforKernel:
|
@@ -367,120 +251,38 @@ def create_kernel_for_parfor(
|
367 | 251 | loop_ranges=loop_ranges,
|
368 | 252 | param_dict=param_dict,
|
369 | 253 | )
|
370 |
| - kernel_ir = kernel_template.kernel_ir |
371 | 254 |
|
372 |
| - if config.DEBUG_ARRAY_OPT: |
373 |
| - print("kernel_ir dump ", type(kernel_ir)) |
374 |
| - kernel_ir.dump() |
375 |
| - print("loop_body dump ", type(loop_body)) |
376 |
| - _print_body(loop_body) |
377 |
| - |
378 |
| - # rename all variables in kernel_ir afresh |
379 |
| - var_table = get_name_var_table(kernel_ir.blocks) |
380 |
| - new_var_dict = {} |
381 |
| - reserved_names = ( |
382 |
| - [sentinel_name] + list(param_dict.values()) + legal_loop_indices |
| 255 | + kernel_dispatcher: SPIRVKernelDispatcher = kernel( |
| 256 | + kernel_template.py_func, |
| 257 | + _parfor_body_args=ParforBodyArguments( |
| 258 | + loop_body=loop_body, |
| 259 | + param_dict=param_dict, |
| 260 | + legal_loop_indices=legal_loop_indices, |
| 261 | + ), |
383 | 262 | )
|
384 |
| - for name, var in var_table.items(): |
385 |
| - if not (name in reserved_names): |
386 |
| - new_var_dict[name] = mk_unique_var(name) |
387 |
| - replace_var_names(kernel_ir.blocks, new_var_dict) |
388 |
| - if config.DEBUG_ARRAY_OPT: |
389 |
| - print("kernel_ir dump after renaming ") |
390 |
| - kernel_ir.dump() |
391 |
| - |
392 |
| - kernel_param_types = param_types |
393 | 263 |
|
394 |
| - if config.DEBUG_ARRAY_OPT: |
395 |
| - print( |
396 |
| - "kernel_param_types = ", |
397 |
| - type(kernel_param_types), |
398 |
| - "\n", |
399 |
| - kernel_param_types, |
400 |
| - ) |
401 |
| - |
402 |
| - kernel_stub_last_label = max(kernel_ir.blocks.keys()) + 1 |
403 |
| - |
404 |
| - # Add kernel stub last label to each parfor.loop_body label to prevent |
405 |
| - # label conflicts. |
406 |
| - loop_body = add_offset_to_labels(loop_body, kernel_stub_last_label) |
407 |
| - |
408 |
| - _replace_sentinel_with_parfor_body(kernel_ir, sentinel_name, loop_body) |
409 |
| - |
410 |
| - if config.DEBUG_ARRAY_OPT: |
411 |
| - print("kernel_ir last dump before renaming") |
412 |
| - kernel_ir.dump() |
413 |
| - |
414 |
| - kernel_ir.blocks = rename_labels(kernel_ir.blocks) |
415 |
| - remove_dels(kernel_ir.blocks) |
416 |
| - |
417 |
| - old_alias = flags.noalias |
418 |
| - if not has_aliases: |
419 |
| - if config.DEBUG_ARRAY_OPT: |
420 |
| - print("No aliases found so adding noalias flag.") |
421 |
| - flags.noalias = True |
422 |
| - |
423 |
| - remove_dead(kernel_ir.blocks, kernel_ir.arg_names, kernel_ir, typemap) |
424 |
| - |
425 |
| - if config.DEBUG_ARRAY_OPT: |
426 |
| - print("kernel_ir after remove dead") |
427 |
| - kernel_ir.dump() |
428 |
| - |
429 |
| - # The first argument to a range kernel is a kernel_api.Item object. The |
430 |
| - # ``Item`` object is used by the kernel_api.spirv backend to generate the |
| 264 | + # The first argument to a range kernel is a kernel_api.Item object. The |
| 265 | + # ``Item`` object is used by the kernel_api.spirv backend to generate the |
431 | 266 | # correct SPIR-V indexing instructions. Since the argument is not something
|
432 | 267 | # available originally in the kernel_param_types, we add it at this point to
|
433 | 268 | # make sure the kernel signature matches the actual generated code.
|
434 | 269 | ty_item = ItemType(parfor_dim)
|
435 |
| - kernel_param_types = (ty_item, *kernel_param_types) |
| 270 | + kernel_param_types = (ty_item, *param_types) |
436 | 271 | kernel_sig = signature(types.none, *kernel_param_types)
|
437 | 272 |
|
438 |
| - if config.DEBUG_ARRAY_OPT: |
439 |
| - sys.stdout.flush() |
440 |
| - |
441 |
| - if config.DEBUG_ARRAY_OPT: |
442 |
| - print("after DUFunc inline".center(80, "-")) |
443 |
| - kernel_ir.dump() |
444 |
| - |
445 |
| - # The ParforLegalizeCFD pass has already ensured that the LHS and RHS |
446 |
| - # arrays are on same device. We can take the queue from the first input |
447 |
| - # array and use that to compile the kernel. |
448 |
| - |
449 |
| - exec_queue: dpctl.SyclQueue = None |
450 |
| - |
451 |
| - for arg in parfor_args: |
452 |
| - obj = typemap[arg] |
453 |
| - if isinstance(obj, DpnpNdArray): |
454 |
| - filter_string = obj.queue.sycl_device |
455 |
| - # FIXME: A better design is required so that we do not have to |
456 |
| - # create a queue every time. |
457 |
| - exec_queue = dpctl.get_device_cached_queue(filter_string) |
458 |
| - |
459 |
| - if not exec_queue: |
460 |
| - raise AssertionError( |
461 |
| - "No execution found for parfor. No way to compile the kernel!" |
462 |
| - ) |
463 |
| - |
464 |
| - sycl_kernel = _compile_kernel_parfor( |
465 |
| - exec_queue, |
466 |
| - kernel_name, |
467 |
| - kernel_ir, |
468 |
| - kernel_param_types, |
469 |
| - debug=flags.debuginfo, |
| 273 | + kcres: _SPIRVKernelCompileResult = kernel_dispatcher.get_compile_result( |
| 274 | + types.void(*kernel_param_types) # kernel signature |
470 | 275 | )
|
471 |
| - |
472 |
| - flags.noalias = old_alias |
| 276 | + kernel_module: SPIRVKernelModule = kcres.kernel_device_ir_module |
473 | 277 |
|
474 | 278 | if config.DEBUG_ARRAY_OPT:
|
475 | 279 | print("kernel_sig = ", kernel_sig)
|
476 | 280 |
|
477 | 281 | return ParforKernel(
|
478 |
| - name=kernel_name, |
479 |
| - kernel=sycl_kernel, |
480 | 282 | signature=kernel_sig,
|
481 | 283 | kernel_args=parfor_args,
|
482 | 284 | kernel_arg_types=func_arg_types,
|
483 |
| - queue=exec_queue, |
| 285 | + kernel_module=kernel_module, |
484 | 286 | )
|
485 | 287 |
|
486 | 288 |
|
|
0 commit comments