Skip to content

Commit 070df7d

Browse files
committed
Remove array allocation duplicates in Kernel Launcher
1 parent 61c6429 commit 070df7d

File tree

1 file changed

+80
-79
lines changed

1 file changed

+80
-79
lines changed

numba_dpex/core/utils/kernel_launcher.py

Lines changed: 80 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -374,78 +374,95 @@ def get_queue(self, exec_queue: dpctl.SyclQueue) -> llvmir.Instruction:
374374
)
375375
return self.builder.load(sycl_queue_val)
376376

377-
def _allocate_kernel_arg_array(self, num_kernel_args):
378-
"""Allocates an array to store the LLVM Value for every kernel argument.
377+
def _allocate_array(
378+
self, numba_type: types.Type, size: int
379+
) -> llvmir.Instruction:
380+
"""Allocates an LLVM array of given type and size.
379381
380382
Args:
381-
num_kernel_args (int): The number of kernel arguments that
382-
determines the size of args array to allocate.
383+
numba_type: type of the array to allocate,
384+
size: The size of the array to allocate.
383385
384-
Returns: An LLVM IR value pointing to an array to store the kernel
385-
arguments.
386+
Returns: An LLVM IR value pointing to the array.
386387
"""
387-
args_list = cgutils.alloca_once(
388+
return cgutils.alloca_once(
388389
self.builder,
389-
utils.LLVMTypes.byte_ptr_t,
390-
size=self.context.get_constant(types.uintp, num_kernel_args),
390+
self.context.get_value_type(numba_type),
391+
size=self.context.get_constant(types.uintp, size),
391392
)
392393

393-
return args_list
394+
def _populate_array_from_python_list(
395+
self,
396+
numba_type: types.Type,
397+
py_array: list[llvmir.Instruction],
398+
ll_array: llvmir.Instruction,
399+
force_cast: bool = False,
400+
):
401+
"""Populates LLVM values from an input Python list into an LLVM array.
402+
403+
Args:
404+
numba_type: type the values are cast to when force_cast is set,
405+
py_array: Python list of LLVM IR values to store into the array,
406+
ll_array: llvm ir value that represents an array to populate,
407+
force_cast: whether to force cast values to the provided type.
408+
"""
409+
for idx, ll_value in enumerate(py_array):
410+
ll_array_dst = self.builder.gep(
411+
ll_array,
412+
[self.context.get_constant(types.int32, idx)],
413+
)
414+
# bitcast may be redundant, but won't hurt.
415+
if force_cast:
416+
ll_value = self.builder.bitcast(
417+
ll_value,
418+
self.context.get_value_type(numba_type),
419+
)
420+
self.builder.store(ll_value, ll_array_dst)
394421

395-
def _allocate_kernel_arg_ty_array(self, num_kernel_args):
396-
"""Allocates an array to store the LLVM Value for the typenum for
397-
every kernel argument.
422+
def _create_ll_from_py_list(
423+
self,
424+
numba_type: types.Type,
425+
list_of_ll_values: list[llvmir.Instruction],
426+
force_cast: bool = False,
427+
) -> llvmir.Instruction:
428+
"""Allocates an LLVM IR array of the same size as the input python list
429+
of LLVM IR Values and populates the array with the LLVM Values in the
430+
list.
398431
399432
Args:
400-
num_kernel_args (int): The number of kernel arguments that
401-
determines the size of args array to allocate.
433+
numba_type: type of the array to allocate,
434+
list_of_ll_values: list of LLVM IR values to populate,
435+
force_cast: whether to force cast values to the provided type.
402436
403-
Returns: An LLVM IR value pointing to an array to store the kernel
404-
arguments typenums as defined in dpctl.
437+
Returns: An LLVM IR value pointing to the array.
405438
"""
406-
args_ty_list = cgutils.alloca_once(
407-
self.builder,
408-
utils.LLVMTypes.int32_t,
409-
size=self.context.get_constant(types.uintp, num_kernel_args),
439+
ll_array = self._allocate_array(numba_type, len(list_of_ll_values))
440+
self._populate_array_from_python_list(
441+
numba_type, list_of_ll_values, ll_array, force_cast
410442
)
411443

412-
return args_ty_list
444+
return ll_array
413445

414446
def _create_sycl_range(self, idx_range):
415-
"""Allocate a size_t[3] array to store the extents of a sycl::range.
447+
"""Allocate an array to store the extents of a sycl::range.
416448
417449
Sycl supports up to 3-dimensional ranges and as such the array is
418450
statically sized to length three. Only the elements that store an actual
419451
range value are populated based on the size of the idx_range argument.
420452
421453
"""
422-
intp_t = utils.get_llvm_type(context=self.context, type=types.intp)
423-
intp_ptr_t = utils.get_llvm_ptr_type(intp_t)
424-
num_dim = len(idx_range)
454+
int64_range = [
455+
self.builder.sext(rext, utils.LLVMTypes.int64_t)
456+
if rext.type != utils.LLVMTypes.int64_t
457+
else rext
458+
for rext in idx_range
459+
]
425460

426-
# form the global range
427-
range_list = cgutils.alloca_once(
428-
self.builder,
429-
utils.get_llvm_type(context=self.context, type=types.uintp),
430-
size=self.context.get_constant(types.uintp, MAX_SIZE_OF_SYCL_RANGE),
431-
)
432-
433-
for i in range(num_dim):
434-
rext = idx_range[i]
435-
if rext.type != utils.LLVMTypes.int64_t:
436-
rext = self.builder.sext(rext, utils.LLVMTypes.int64_t)
437-
438-
# we reverse the global range to account for how sycl and opencl
439-
# range differs
440-
self.builder.store(
441-
rext,
442-
self.builder.gep(
443-
range_list,
444-
[self.context.get_constant(types.uintp, (num_dim - 1) - i)],
445-
),
446-
)
461+
# we reverse the global range to account for how sycl and opencl
462+
# ranges differ
463+
int64_range.reverse()
447464

448-
return self.builder.bitcast(range_list, intp_ptr_t)
465+
return self._create_ll_from_py_list(types.uintp, int64_range)
449466

450467
def set_kernel(self, sycl_kernel_ref: llvmir.Instruction):
451468
"""Sets kernel to the argument list."""
@@ -597,10 +614,14 @@ def set_arguments(
597614
)
598615

599616
# Create LLVM values for the kernel args list and kernel arg types list
600-
args_list = self._allocate_kernel_arg_array(num_flattened_kernel_args)
617+
args_list = self._allocate_array(
618+
types.voidptr,
619+
num_flattened_kernel_args,
620+
)
601621

602-
args_ty_list = self._allocate_kernel_arg_ty_array(
603-
num_flattened_kernel_args
622+
args_ty_list = self._allocate_array(
623+
types.int32,
624+
num_flattened_kernel_args,
604625
)
605626

606627
kernel_args_ptrs = []
@@ -624,20 +645,17 @@ def set_arguments(
624645
types.uintp, num_flattened_kernel_args
625646
)
626647

627-
def _extract_arguments_from_tuple(
648+
def _extract_llvm_values_from_tuple(
628649
self,
629-
ty_kernel_args_tuple: UniTuple,
630-
ll_kernel_args_tuple: llvmir.Instruction,
650+
ll_tuple: llvmir.Instruction,
631651
) -> list[llvmir.Instruction]:
632652
"""Extracts LLVM IR values from llvm tuple into python array."""
633653

634-
kernel_args = []
635-
for pos in range(len(ty_kernel_args_tuple)):
636-
kernel_args.append(
637-
self.builder.extract_value(ll_kernel_args_tuple, pos)
638-
)
654+
llvm_values = []
655+
for pos in range(len(ll_tuple.type)):
656+
llvm_values.append(self.builder.extract_value(ll_tuple, pos))
639657

640-
return kernel_args
658+
return llvm_values
641659

642660
def set_arguments_form_tuple(
643661
self,
@@ -647,9 +665,7 @@ def set_arguments_form_tuple(
647665
"""Sets flattened kernel args, kernel arg types and number of those
648666
arguments to the argument list based on the arguments stored in tuple.
649667
"""
650-
kernel_args = self._extract_arguments_from_tuple(
651-
ty_kernel_args_tuple, ll_kernel_args_tuple
652-
)
668+
kernel_args = self._extract_llvm_values_from_tuple(ll_kernel_args_tuple)
653669
self.set_arguments(ty_kernel_args_tuple, kernel_args)
654670

655671
def set_dependant_event_list(self, dep_events: list[llvmir.Instruction]):
@@ -708,22 +724,7 @@ def _allocate_meminfo_array(
708724
)
709725
]
710726

711-
meminfo_list = cgutils.alloca_once(
712-
self.builder,
713-
utils.get_llvm_type(context=self.context, type=types.voidptr),
714-
size=self.context.get_constant(types.uintp, len(meminfos)),
715-
)
716-
717-
for meminfo_num, meminfo in enumerate(meminfos):
718-
meminfo_arg_dst = self.builder.gep(
719-
meminfo_list,
720-
[self.context.get_constant(types.int32, meminfo_num)],
721-
)
722-
meminfo_ptr = self.builder.bitcast(
723-
meminfo,
724-
utils.get_llvm_type(context=self.context, type=types.voidptr),
725-
)
726-
self.builder.store(meminfo_ptr, meminfo_arg_dst)
727+
meminfo_list = self._create_ll_from_py_list(types.voidptr, meminfos)
727728

728729
return len(meminfos), meminfo_list
729730

0 commit comments

Comments
 (0)