Skip to content

Commit 6b76fd3

Browse files
author
Diptorup Deb
authored
Merge pull request #1249 from IntelPython/feature/dependant_events
Add dependent events to async call
2 parents 72d8a32 + 3cfa80b commit 6b76fd3

File tree

6 files changed

+214
-104
lines changed

6 files changed

+214
-104
lines changed

numba_dpex/core/parfors/parfor_lowerer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ def _submit_parfor_kernel(
185185
kl_builder.set_arguments(
186186
kernel_fn.kernel_arg_types, kernel_args=kernel_args
187187
)
188-
kl_builder.set_dependant_event_list([])
188+
kl_builder.set_dependent_events([])
189189
event_ref = kl_builder.submit()
190190

191191
sycl.dpctl_event_wait(lowerer.builder, event_ref)

numba_dpex/core/runtime/experimental/nrt_reserve_meminfo.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ extern "C"
2727
* @param QRef Queue reference,
2828
* @param meminfo_array Array of meminfo pointers to perform actions on,
2929
* @param meminfo_array_size Length of meminfo_array,
30-
* @param depERefs Array of dependant events for the host task,
30+
* @param depERefs Array of dependent events for the host task,
3131
* @param nDepERefs Length of depERefs,
3232
* @param status Variable to write status to. Same style as
3333
* dpctl,

numba_dpex/core/utils/kernel_launcher.py

Lines changed: 113 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -374,78 +374,95 @@ def get_queue(self, exec_queue: dpctl.SyclQueue) -> llvmir.Instruction:
374374
)
375375
return self.builder.load(sycl_queue_val)
376376

377-
def _allocate_kernel_arg_array(self, num_kernel_args):
378-
"""Allocates an array to store the LLVM Value for every kernel argument.
377+
def _allocate_array(
378+
self, numba_type: types.Type, size: int
379+
) -> llvmir.Instruction:
380+
"""Allocates an LLVM array of given type and size.
379381
380382
Args:
381-
num_kernel_args (int): The number of kernel arguments that
382-
determines the size of args array to allocate.
383+
numba_type: type of the array to allocate,
384+
size: The size of the array to allocate.
383385
384-
Returns: An LLVM IR value pointing to an array to store the kernel
385-
arguments.
386+
Returns: An LLVM IR value pointing to the array.
386387
"""
387-
args_list = cgutils.alloca_once(
388+
return cgutils.alloca_once(
388389
self.builder,
389-
utils.LLVMTypes.byte_ptr_t,
390-
size=self.context.get_constant(types.uintp, num_kernel_args),
390+
self.context.get_value_type(numba_type),
391+
size=self.context.get_constant(types.uintp, size),
391392
)
392393

393-
return args_list
394+
def _populate_array_from_python_list(
395+
self,
396+
numba_type: types.Type,
397+
py_array: list[llvmir.Instruction],
398+
ll_array: llvmir.Instruction,
399+
force_cast: bool = False,
400+
):
401+
"""Populates LLVM values from an input Python list into an LLVM array.
394402
395-
def _allocate_kernel_arg_ty_array(self, num_kernel_args):
396-
"""Allocates an array to store the LLVM Value for the typenum for
397-
every kernel argument.
403+
Args:
404+
numba_type: type the values are cast to when force_cast is set,
405+
py_array: array of llvm ir values to populate.
406+
ll_array: llvm ir value that represents an array to populate,
407+
force_cast: whether to force cast values to the provided type.
408+
"""
409+
for idx, ll_value in enumerate(py_array):
410+
ll_array_dst = self.builder.gep(
411+
ll_array,
412+
[self.context.get_constant(types.int32, idx)],
413+
)
414+
# bitcast may be extra, but won't hurt,
415+
if force_cast:
416+
ll_value = self.builder.bitcast(
417+
ll_value,
418+
self.context.get_value_type(numba_type),
419+
)
420+
self.builder.store(ll_value, ll_array_dst)
421+
422+
def _create_ll_from_py_list(
423+
self,
424+
numba_type: types.Type,
425+
list_of_ll_values: list[llvmir.Instruction],
426+
force_cast: bool = False,
427+
) -> llvmir.Instruction:
428+
"""Allocates an LLVM IR array of the same size as the input python list
429+
of LLVM IR Values and populates the array with the LLVM Values in the
430+
list.
398431
399432
Args:
400-
num_kernel_args (int): The number of kernel arguments that
401-
determines the size of args array to allocate.
433+
numba_type: type of the array to allocate,
434+
list_of_ll_values: list of LLVM IR values to populate,
435+
force_cast: whether to force cast values to the provided type.
402436
403-
Returns: An LLVM IR value pointing to an array to store the kernel
404-
arguments typenums as defined in dpctl.
437+
Returns: An LLVM IR value pointing to the array.
405438
"""
406-
args_ty_list = cgutils.alloca_once(
407-
self.builder,
408-
utils.LLVMTypes.int32_t,
409-
size=self.context.get_constant(types.uintp, num_kernel_args),
439+
ll_array = self._allocate_array(numba_type, len(list_of_ll_values))
440+
self._populate_array_from_python_list(
441+
numba_type, list_of_ll_values, ll_array, force_cast
410442
)
411443

412-
return args_ty_list
444+
return ll_array
413445

414446
def _create_sycl_range(self, idx_range):
415-
"""Allocate a size_t[3] array to store the extents of a sycl::range.
447+
"""Allocate an array to store the extents of a sycl::range.
416448
417449
Sycl supports up to 3-dimensional ranges and as such the array is
418450
statically sized to length three. Only the elements that store an actual
419451
range value are populated based on the size of the idx_range argument.
420452
421453
"""
422-
intp_t = utils.get_llvm_type(context=self.context, type=types.intp)
423-
intp_ptr_t = utils.get_llvm_ptr_type(intp_t)
424-
num_dim = len(idx_range)
454+
int64_range = [
455+
self.builder.sext(rext, utils.LLVMTypes.int64_t)
456+
if rext.type != utils.LLVMTypes.int64_t
457+
else rext
458+
for rext in idx_range
459+
]
425460

426-
# form the global range
427-
range_list = cgutils.alloca_once(
428-
self.builder,
429-
utils.get_llvm_type(context=self.context, type=types.uintp),
430-
size=self.context.get_constant(types.uintp, MAX_SIZE_OF_SYCL_RANGE),
431-
)
432-
433-
for i in range(num_dim):
434-
rext = idx_range[i]
435-
if rext.type != utils.LLVMTypes.int64_t:
436-
rext = self.builder.sext(rext, utils.LLVMTypes.int64_t)
437-
438-
# we reverse the global range to account for how sycl and opencl
439-
# range differs
440-
self.builder.store(
441-
rext,
442-
self.builder.gep(
443-
range_list,
444-
[self.context.get_constant(types.uintp, (num_dim - 1) - i)],
445-
),
446-
)
461+
# we reverse the global range to account for how sycl and opencl
462+
# range differs
463+
int64_range.reverse()
447464

448-
return self.builder.bitcast(range_list, intp_ptr_t)
465+
return self._create_ll_from_py_list(types.uintp, int64_range)
449466

450467
def set_kernel(self, sycl_kernel_ref: llvmir.Instruction):
451468
"""Sets kernel to the argument list."""
@@ -597,10 +614,14 @@ def set_arguments(
597614
)
598615

599616
# Create LLVM values for the kernel args list and kernel arg types list
600-
args_list = self._allocate_kernel_arg_array(num_flattened_kernel_args)
617+
args_list = self._allocate_array(
618+
types.voidptr,
619+
num_flattened_kernel_args,
620+
)
601621

602-
args_ty_list = self._allocate_kernel_arg_ty_array(
603-
num_flattened_kernel_args
622+
args_ty_list = self._allocate_array(
623+
types.int32,
624+
num_flattened_kernel_args,
604625
)
605626

606627
kernel_args_ptrs = []
@@ -624,20 +645,17 @@ def set_arguments(
624645
types.uintp, num_flattened_kernel_args
625646
)
626647

627-
def _extract_arguments_from_tuple(
648+
def _extract_llvm_values_from_tuple(
628649
self,
629-
ty_kernel_args_tuple: UniTuple,
630-
ll_kernel_args_tuple: llvmir.Instruction,
650+
ll_tuple: llvmir.Instruction,
631651
) -> list[llvmir.Instruction]:
632652
"""Extracts LLVM IR values from llvm tuple into python array."""
633653

634-
kernel_args = []
635-
for pos in range(len(ty_kernel_args_tuple)):
636-
kernel_args.append(
637-
self.builder.extract_value(ll_kernel_args_tuple, pos)
638-
)
654+
llvm_values = []
655+
for pos in range(len(ll_tuple.type)):
656+
llvm_values.append(self.builder.extract_value(ll_tuple, pos))
639657

640-
return kernel_args
658+
return llvm_values
641659

642660
def set_arguments_form_tuple(
643661
self,
@@ -647,27 +665,45 @@ def set_arguments_form_tuple(
647665
"""Sets flattened kernel args, kernel arg types and number of those
648666
arguments to the argument list based on the arguments stored in tuple.
649667
"""
650-
kernel_args = self._extract_arguments_from_tuple(
651-
ty_kernel_args_tuple, ll_kernel_args_tuple
652-
)
668+
kernel_args = self._extract_llvm_values_from_tuple(ll_kernel_args_tuple)
653669
self.set_arguments(ty_kernel_args_tuple, kernel_args)
654670

655-
def set_dependant_event_list(self, dep_events: list[llvmir.Instruction]):
656-
"""Sets dependant events to the argument list."""
657-
if self.arguments.dep_events is not None:
658-
return
671+
def set_dependent_events(self, dep_events: list[llvmir.Instruction]):
672+
"""Sets dependent events to the argument list."""
673+
ll_dep_events = self._create_ll_from_py_list(types.voidptr, dep_events)
674+
self.arguments.dep_events = ll_dep_events
675+
self.arguments.dep_events_len = self.context.get_constant(
676+
types.uintp, len(dep_events)
677+
)
659678

660-
if len(dep_events) > 0:
661-
# TODO: implement for non zero input
662-
raise NotImplementedError
679+
def set_dependent_events_from_tuple(
680+
self,
681+
ty_dependent_events: UniTuple,
682+
ll_dependent_events: llvmir.Instruction,
683+
):
684+
"""Sets dependent events from a tuple represented by LLVM IR.
663685
664-
self.arguments.dep_events = self.builder.bitcast(
665-
utils.create_null_ptr(builder=self.builder, context=self.context),
666-
utils.get_llvm_type(context=self.context, type=types.voidptr),
667-
)
668-
self.arguments.dep_events_len = self.context.get_constant(
669-
types.uintp, 0
686+
Args:
687+
ll_dependent_events: tuple of numba's data models.
688+
"""
689+
if len(ty_dependent_events) == 0:
690+
self.set_dependent_events([])
691+
return
692+
693+
ty_event = ty_dependent_events[0]
694+
dm_dependent_events = self._extract_llvm_values_from_tuple(
695+
ll_dependent_events
670696
)
697+
dependent_events = []
698+
for dm_dependent_event in dm_dependent_events:
699+
event_struct_proxy = cgutils.create_struct_proxy(ty_event)(
700+
self.context,
701+
self.builder,
702+
value=dm_dependent_event,
703+
)
704+
dependent_events.append(event_struct_proxy.event_ref)
705+
706+
self.set_dependent_events(dependent_events)
671707

672708
def submit(self) -> llvmir.Instruction:
673709
"""Submits kernel by calling sycl.dpctl_queue_submit_range or
@@ -708,22 +744,7 @@ def _allocate_meminfo_array(
708744
)
709745
]
710746

711-
meminfo_list = cgutils.alloca_once(
712-
self.builder,
713-
utils.get_llvm_type(context=self.context, type=types.voidptr),
714-
size=self.context.get_constant(types.uintp, len(meminfos)),
715-
)
716-
717-
for meminfo_num, meminfo in enumerate(meminfos):
718-
meminfo_arg_dst = self.builder.gep(
719-
meminfo_list,
720-
[self.context.get_constant(types.int32, meminfo_num)],
721-
)
722-
meminfo_ptr = self.builder.bitcast(
723-
meminfo,
724-
utils.get_llvm_type(context=self.context, type=types.voidptr),
725-
)
726-
self.builder.store(meminfo_ptr, meminfo_arg_dst)
747+
meminfo_list = self._create_ll_from_py_list(types.voidptr, meminfos)
727748

728749
return len(meminfos), meminfo_list
729750

numba_dpex/dpctl_iface/libsyclinterface_bindings.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ def dpctl_queue_submit_range(builder: llvmir.IRBuilder, *args):
154154
llvmir.IntType(64),
155155
llvmir.IntType(64).as_pointer(),
156156
llvmir.IntType(64),
157-
cgutils.voidptr_t,
157+
cgutils.voidptr_t.as_pointer(),
158158
llvmir.IntType(64),
159159
],
160160
func_name="DPCTLQueue_SubmitRange",
@@ -195,7 +195,7 @@ def dpctl_queue_submit_ndrange(builder: llvmir.IRBuilder, *args):
195195
llvmir.IntType(64).as_pointer(),
196196
llvmir.IntType(64).as_pointer(),
197197
llvmir.IntType(64),
198-
cgutils.voidptr_t,
198+
cgutils.voidptr_t.as_pointer(),
199199
llvmir.IntType(64),
200200
],
201201
func_name="DPCTLQueue_SubmitNDRange",

0 commit comments

Comments
 (0)