Skip to content

Commit 3f8cdf9

Browse files
authored
Merge pull request #1245 from IntelPython/feature/extend_kernel_launcher
Extend kernel launcher
2 parents 2d6ddb8 + ceb920d commit 3f8cdf9

File tree

2 files changed

+368
-325
lines changed

2 files changed

+368
-325
lines changed

numba_dpex/core/utils/kernel_launcher.py

Lines changed: 282 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,23 +5,38 @@
55
"""Module that contains numba style wrapper around sycl kernel submit."""
66

77
from dataclasses import dataclass
8+
from functools import cached_property
9+
from typing import NamedTuple, Union
810

911
import dpctl
1012
from llvmlite import ir as llvmir
1113
from llvmlite.ir.builder import IRBuilder
1214
from numba.core import cgutils, types
1315
from numba.core.cpu import CPUContext
1416
from numba.core.datamodel import DataModelManager
17+
from numba.core.types.containers import UniTuple
1518

1619
from numba_dpex import config, utils
20+
from numba_dpex.core.exceptions import UnreachableError
1721
from numba_dpex.core.runtime.context import DpexRTContext
1822
from numba_dpex.core.types import DpnpNdArray
23+
from numba_dpex.core.types.range_types import NdRangeType, RangeType
1924
from numba_dpex.dpctl_iface import libsyclinterface_bindings as sycl
2025
from numba_dpex.dpctl_iface._helpers import numba_type_to_dpctl_typenum
26+
from numba_dpex.utils import create_null_ptr
2127

2228
MAX_SIZE_OF_SYCL_RANGE = 3
2329

2430

31+
# TODO: probably not best place for it. Should be in kernel_dispatcher once we
32+
# get merge experimental. Right now it will cause cyclic import
33+
class SPIRVKernelModule(NamedTuple):
34+
"""Represents SPIRV binary code and function name in this binary"""
35+
36+
kernel_name: str
37+
kernel_bitcode: bytes
38+
39+
2540
@dataclass
2641
class _KernelLaunchIRArguments: # pylint: disable=too-many-instance-attributes
2742
"""List of kernel launch arguments used in sycl.dpctl_queue_submit_range and
@@ -62,6 +77,22 @@ def to_list(self):
6277
return res
6378

6479

80+
@dataclass
81+
class _KernelLaunchIRCachedArguments:
82+
"""Arguments that are being used in KernelLaunchIRBuilder that are either
83+
intermediate structure of the KernelLaunchIRBuilder like llvm IR array
84+
stored as a python array of llvm IR values or llvm IR values that may be
85+
used as an input for builder functions.
86+
87+
Main goal is to prevent passing same argument during build process several
88+
times and to avoid passing output of the builder as an argument for another
89+
build method."""
90+
91+
arg_list: list[llvmir.Instruction] = None
92+
arg_ty_list: list[types.Type] = None
93+
device_event_ref: llvmir.Instruction = None
94+
95+
6596
class KernelLaunchIRBuilder:
6697
"""
6798
KernelLaunchIRBuilder(lowerer, cres)
@@ -86,9 +117,16 @@ def __init__(
86117
"""
87118
self.context = context
88119
self.builder = builder
89-
self.rtctx = DpexRTContext(self.context)
90120
self.arguments = _KernelLaunchIRArguments()
121+
self.cached_arguments = _KernelLaunchIRCachedArguments()
91122
self.kernel_dmm = kernel_dmm
123+
self._cleanups = []
124+
125+
@cached_property
126+
def dpexrt(self):
127+
"""Dpex runtime context."""
128+
129+
return DpexRTContext(self.context)
92130

93131
def _build_nullptr(self):
94132
"""Builds the LLVM IR to represent a null pointer.
@@ -329,7 +367,7 @@ def get_queue(self, exec_queue: dpctl.SyclQueue) -> llvmir.Instruction:
329367
# Store the queue returned by DPEXRTQueue_CreateFromFilterString in a
330368
# local variable
331369
self.builder.store(
332-
self.rtctx.get_queue_from_filter_string(
370+
self.dpexrt.get_queue_from_filter_string(
333371
builder=self.builder, device=device
334372
),
335373
sycl_queue_val,
@@ -413,10 +451,76 @@ def set_kernel(self, sycl_kernel_ref: llvmir.Instruction):
413451
"""Sets kernel to the argument list."""
414452
self.arguments.sycl_kernel_ref = sycl_kernel_ref
415453

454+
def set_kernel_from_spirv(self, kernel_module: SPIRVKernelModule):
455+
"""Sets kernel to the argument list from the SPIRV bytecode.
456+
457+
It pastes bytecode as a constant string and create kernel bundle from it
458+
using SYCL API. It caches kernel, so it won't be sent to device second
459+
time.
460+
"""
461+
# Inserts a global constant byte string in the current LLVM module to
462+
# store the passed in SPIR-V binary blob.
463+
queue_ref = self.arguments.sycl_queue_ref
464+
465+
kernel_bc_byte_str = self.context.insert_const_bytes(
466+
self.builder.module,
467+
bytes=kernel_module.kernel_bitcode,
468+
)
469+
470+
kernel_name = self.context.insert_const_string(
471+
self.builder.module, kernel_module.kernel_name
472+
)
473+
474+
context_ref = sycl.dpctl_queue_get_context(self.builder, queue_ref)
475+
device_ref = sycl.dpctl_queue_get_device(self.builder, queue_ref)
476+
477+
# build_or_get_kernel steals reference to context and device cause it
478+
# needs to keep them alive for keys.
479+
kernel_ref = self.dpexrt.build_or_get_kernel(
480+
self.builder,
481+
[
482+
context_ref,
483+
device_ref,
484+
llvmir.Constant(
485+
llvmir.IntType(64), hash(kernel_module.kernel_bitcode)
486+
),
487+
kernel_bc_byte_str,
488+
llvmir.Constant(
489+
llvmir.IntType(64), len(kernel_module.kernel_bitcode)
490+
),
491+
self.builder.load(create_null_ptr(self.builder, self.context)),
492+
kernel_name,
493+
],
494+
)
495+
496+
self._cleanups.append(self._clean_kernel_ref)
497+
self.set_kernel(kernel_ref)
498+
499+
def _clean_kernel_ref(self):
500+
sycl.dpctl_kernel_delete(self.builder, self.arguments.sycl_kernel_ref)
501+
self.arguments.sycl_kernel_ref = None
502+
416503
def set_queue(self, sycl_queue_ref: llvmir.Instruction):
417504
"""Sets queue to the argument list."""
418505
self.arguments.sycl_queue_ref = sycl_queue_ref
419506

507+
def set_queue_from_arguments(
508+
self,
509+
):
510+
"""Sets the sycl queue from the first DpnpNdArray argument provided
511+
earlier."""
512+
queue_ref = get_queue_from_llvm_values(
513+
self.context,
514+
self.builder,
515+
self.cached_arguments.arg_ty_list,
516+
self.cached_arguments.arg_list,
517+
)
518+
519+
if queue_ref is None:
520+
raise ValueError("There are no arguments that contain queue")
521+
522+
self.set_queue(queue_ref)
523+
420524
def set_range(
421525
self,
422526
global_range: list,
@@ -430,10 +534,52 @@ def set_range(
430534
types.uintp, len(global_range)
431535
)
432536

537+
def set_range_from_indexer(
538+
self,
539+
ty_indexer_arg: Union[RangeType, NdRangeType],
540+
ll_index_arg: llvmir.BaseStructType,
541+
):
542+
"""Returns two lists of LLVM IR Values that hold the unboxed extents of
543+
a Python Range or NdRange object.
544+
"""
545+
ndim = ty_indexer_arg.ndim
546+
global_range_extents = []
547+
local_range_extents = []
548+
indexer_datamodel = self.context.data_model_manager.lookup(
549+
ty_indexer_arg
550+
)
551+
552+
if isinstance(ty_indexer_arg, RangeType):
553+
for dim_num in range(ndim):
554+
dim_pos = indexer_datamodel.get_field_position(
555+
"dim" + str(dim_num)
556+
)
557+
global_range_extents.append(
558+
self.builder.extract_value(ll_index_arg, dim_pos)
559+
)
560+
elif isinstance(ty_indexer_arg, NdRangeType):
561+
for dim_num in range(ndim):
562+
gdim_pos = indexer_datamodel.get_field_position(
563+
"gdim" + str(dim_num)
564+
)
565+
global_range_extents.append(
566+
self.builder.extract_value(ll_index_arg, gdim_pos)
567+
)
568+
ldim_pos = indexer_datamodel.get_field_position(
569+
"ldim" + str(dim_num)
570+
)
571+
local_range_extents.append(
572+
self.builder.extract_value(ll_index_arg, ldim_pos)
573+
)
574+
else:
575+
raise UnreachableError
576+
577+
self.set_range(global_range_extents, local_range_extents)
578+
433579
def set_arguments(
434580
self,
435-
ty_kernel_args: list,
436-
kernel_args: list,
581+
ty_kernel_args: list[types.Type],
582+
kernel_args: list[llvmir.Instruction],
437583
):
438584
"""Sets flattened kernel args, kernel arg types and number of those
439585
arguments to the argument list."""
@@ -443,6 +589,9 @@ def set_arguments(
443589
"DPEX-DEBUG: Populating kernel args and arg type arrays.\n",
444590
)
445591

592+
self.cached_arguments.arg_ty_list = ty_kernel_args
593+
self.cached_arguments.arg_list = kernel_args
594+
446595
num_flattened_kernel_args = self._get_num_flattened_kernel_args(
447596
kernel_argtys=ty_kernel_args,
448597
)
@@ -475,6 +624,34 @@ def set_arguments(
475624
types.uintp, num_flattened_kernel_args
476625
)
477626

627+
def _extract_arguments_from_tuple(
628+
self,
629+
ty_kernel_args_tuple: UniTuple,
630+
ll_kernel_args_tuple: llvmir.Instruction,
631+
) -> list[llvmir.Instruction]:
632+
"""Extracts LLVM IR values from llvm tuple into python array."""
633+
634+
kernel_args = []
635+
for pos in range(len(ty_kernel_args_tuple)):
636+
kernel_args.append(
637+
self.builder.extract_value(ll_kernel_args_tuple, pos)
638+
)
639+
640+
return kernel_args
641+
642+
def set_arguments_form_tuple(
643+
self,
644+
ty_kernel_args_tuple: UniTuple,
645+
ll_kernel_args_tuple: llvmir.Instruction,
646+
):
647+
"""Sets flattened kernel args, kernel arg types and number of those
648+
arguments to the argument list based on the arguments stored in tuple.
649+
"""
650+
kernel_args = self._extract_arguments_from_tuple(
651+
ty_kernel_args_tuple, ll_kernel_args_tuple
652+
)
653+
self.set_arguments(ty_kernel_args_tuple, kernel_args)
654+
478655
def set_dependant_event_list(self, dep_events: list[llvmir.Instruction]):
479656
"""Sets dependant events to the argument list."""
480657
if self.arguments.dep_events is not None:
@@ -499,11 +676,86 @@ def submit(self) -> llvmir.Instruction:
499676
args = self.arguments.to_list()
500677

501678
if self.arguments.local_range is None:
502-
eref = sycl.dpctl_queue_submit_range(self.builder, *args)
679+
event_ref = sycl.dpctl_queue_submit_range(self.builder, *args)
503680
else:
504-
eref = sycl.dpctl_queue_submit_ndrange(self.builder, *args)
681+
event_ref = sycl.dpctl_queue_submit_ndrange(self.builder, *args)
682+
683+
self.cached_arguments.device_event_ref = event_ref
505684

506-
return eref
685+
for cleanup in self._cleanups:
686+
cleanup()
687+
688+
return event_ref
689+
690+
def _allocate_meminfo_array(
691+
self,
692+
) -> tuple[int, list[llvmir.Instruction]]:
693+
"""Allocates an LLVM array value to store each memory info from all
694+
kernel arguments. The array is the populated with the LLVM value for
695+
every meminfo of the kernel arguments.
696+
"""
697+
kernel_args = self.cached_arguments.arg_list
698+
kernel_argtys = self.cached_arguments.arg_ty_list
699+
700+
meminfos = []
701+
for arg_num, argtype in enumerate(kernel_argtys):
702+
llvm_val = kernel_args[arg_num]
703+
704+
meminfos += [
705+
meminfo
706+
for ty, meminfo in self.context.nrt.get_meminfos(
707+
self.builder, argtype, llvm_val
708+
)
709+
]
710+
711+
meminfo_list = cgutils.alloca_once(
712+
self.builder,
713+
utils.get_llvm_type(context=self.context, type=types.voidptr),
714+
size=self.context.get_constant(types.uintp, len(meminfos)),
715+
)
716+
717+
for meminfo_num, meminfo in enumerate(meminfos):
718+
meminfo_arg_dst = self.builder.gep(
719+
meminfo_list,
720+
[self.context.get_constant(types.int32, meminfo_num)],
721+
)
722+
meminfo_ptr = self.builder.bitcast(
723+
meminfo,
724+
utils.get_llvm_type(context=self.context, type=types.voidptr),
725+
)
726+
self.builder.store(meminfo_ptr, meminfo_arg_dst)
727+
728+
return len(meminfos), meminfo_list
729+
730+
def acquire_meminfo_and_submit_release(
731+
self,
732+
) -> llvmir.Instruction:
733+
"""Schedule sycl host task to release nrt meminfo of the arguments used
734+
to run job. Use it to keep arguments alive during kernel execution."""
735+
queue_ref = self.arguments.sycl_queue_ref
736+
event_ref = self.cached_arguments.device_event_ref
737+
738+
total_meminfos, meminfo_list = self._allocate_meminfo_array()
739+
740+
event_ref_ptr = self.builder.alloca(event_ref.type)
741+
self.builder.store(event_ref, event_ref_ptr)
742+
743+
status_ptr = cgutils.alloca_once(
744+
self.builder, self.context.get_value_type(types.uint64)
745+
)
746+
host_eref = self.dpexrt.acquire_meminfo_and_schedule_release(
747+
self.builder,
748+
[
749+
self.context.nrt.get_nrt_api(self.builder),
750+
queue_ref,
751+
meminfo_list,
752+
self.context.get_constant(types.uintp, total_meminfos),
753+
event_ref_ptr,
754+
self.context.get_constant(types.uintp, 1),
755+
status_ptr,
756+
],
757+
)
758+
return host_eref
507759

508760
def _get_num_flattened_kernel_args(
509761
self,
@@ -571,3 +823,26 @@ def _populate_kernel_args_and_args_ty_arrays(
571823
kernel_arg_num,
572824
)
573825
kernel_arg_num += 1
826+
827+
828+
def get_queue_from_llvm_values(
829+
ctx: CPUContext,
830+
builder: IRBuilder,
831+
ty_kernel_args: list[types.Type],
832+
ll_kernel_args: list[llvmir.Instruction],
833+
):
834+
"""
835+
Get the sycl queue from the first DpnpNdArray argument. Prior passes
836+
before lowering make sure that compute-follows-data is enforceable
837+
for a specific call to a kernel. As such, at the stage of lowering
838+
the queue from the first DpnpNdArray argument can be extracted.
839+
"""
840+
for arg_num, argty in enumerate(ty_kernel_args):
841+
if isinstance(argty, DpnpNdArray):
842+
llvm_val = ll_kernel_args[arg_num]
843+
datamodel = ctx.data_model_manager.lookup(argty)
844+
sycl_queue_attr_pos = datamodel.get_field_position("sycl_queue")
845+
queue_ref = builder.extract_value(llvm_val, sycl_queue_attr_pos)
846+
break
847+
848+
return queue_ref

0 commit comments

Comments
 (0)