Skip to content

Commit 69cb43e

Browse files
committed
Remove parent,meminfo,sycl_queue from kernel data model
1 parent 8c3a7d5 commit 69cb43e

File tree

3 files changed

+56
-88
lines changed

3 files changed

+56
-88
lines changed

numba_dpex/core/datamodel/models.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -80,18 +80,12 @@ class USMArrayDeviceModel(StructModel):
8080
def __init__(self, dmm, fe_type):
8181
ndim = fe_type.ndim
8282
members = [
83-
# meminfo never used in kernel, so we don'te care about addrspace
84-
("meminfo", types.MemInfoPointer(fe_type.dtype)),
85-
# parent never used in kernel, so we don'te care about addrspace
86-
("parent", types.pyobject),
8783
("nitems", types.intp),
8884
("itemsize", types.intp),
8985
(
9086
"data",
9187
types.CPointer(fe_type.dtype, addrspace=fe_type.addrspace),
9288
),
93-
# sycl_queue never used in kernel, so we don'te care about addrspace
94-
("sycl_queue", types.voidptr),
9589
("shape", types.UniTuple(types.intp, ndim)),
9690
("strides", types.UniTuple(types.intp, ndim)),
9791
]

numba_dpex/core/kernel_interface/arg_pack_unpacker.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,15 +47,9 @@ def _unpack_usm_array(self, val):
4747
strides = suai_attrs.strides
4848
ndim = suai_attrs.dimensions
4949

50-
# meminfo
51-
unpacked_array_attrs.append(ctypes.c_size_t(0))
52-
# parent
53-
unpacked_array_attrs.append(ctypes.c_size_t(0))
5450
unpacked_array_attrs.append(ctypes.c_longlong(size))
5551
unpacked_array_attrs.append(ctypes.c_longlong(itemsize))
5652
unpacked_array_attrs.append(buf)
57-
# queue: unused and passed as void*
58-
unpacked_array_attrs.append(ctypes.c_size_t(0))
5953
for ax in range(ndim):
6054
unpacked_array_attrs.append(ctypes.c_longlong(shape[ax]))
6155
for ax in range(ndim):

numba_dpex/core/utils/kernel_launcher.py

Lines changed: 56 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -253,10 +253,11 @@ def _build_complex_arg(
253253

254254
def _build_array_arg( # pylint: disable=too-many-arguments
255255
self,
256-
array_val,
257-
array_data_model,
258-
arg_list,
259-
args_ty_list,
256+
host_array_val, # llvm_val
257+
host_array_data_model,
258+
kernel_array_data_model,
259+
arg_list, # output: populated by this method
260+
args_ty_list, # output: populated by this method
260261
arg_num,
261262
):
262263
"""Creates a list of LLVM Values for an unpacked USMNdArray kernel
@@ -265,82 +266,59 @@ def _build_array_arg( # pylint: disable=too-many-arguments
265266
The steps performed here are the same as in
266267
numba_dpex.core.kernel_interface.arg_pack_unpacker._unpack_array_helper
267268
"""
268-
# Argument 1: Null pointer for the NRT_MemInfo attribute of the array
269-
nullptr = self._build_nullptr()
270-
self._build_arg(
271-
val=nullptr,
272-
typ=types.int64,
273-
arg_list=arg_list,
274-
args_ty_list=args_ty_list,
275-
arg_num=arg_num,
276-
)
277-
arg_num += 1
278-
# Argument 2: Null pointer for the Parent attribute of the array
279-
nullptr = self._build_nullptr()
280-
self._build_arg(
281-
val=nullptr,
282-
typ=types.int64,
283-
arg_list=arg_list,
284-
args_ty_list=args_ty_list,
285-
arg_num=arg_num,
286-
)
287-
arg_num += 1
269+
# This may look confusing: every field is read from the host array
270+
# struct, so its position is looked up in the host data model, while its
271+
# type is taken from the kernel data model, because the value is cast to
272+
# the device type before it is sent to the device.
273+
# TODO: refactor this code so it is easier to read and maintain.
288274
# Argument nitems
289275
self._build_array_attr_arg(
290-
array_val=array_val,
291-
array_attr_pos=array_data_model.get_field_position("nitems"),
292-
array_attr_ty=array_data_model.get_member_fe_type("nitems"),
276+
array_val=host_array_val,
277+
array_attr_pos=host_array_data_model.get_field_position("nitems"),
278+
array_attr_ty=kernel_array_data_model.get_member_fe_type("nitems"),
293279
arg_list=arg_list,
294280
args_ty_list=args_ty_list,
295281
arg_num=arg_num,
296282
)
297283
arg_num += 1
298284
# Argument itemsize
299285
self._build_array_attr_arg(
300-
array_val=array_val,
301-
array_attr_pos=array_data_model.get_field_position("itemsize"),
302-
array_attr_ty=array_data_model.get_member_fe_type("itemsize"),
286+
array_val=host_array_val,
287+
array_attr_pos=host_array_data_model.get_field_position("itemsize"),
288+
array_attr_ty=kernel_array_data_model.get_member_fe_type(
289+
"itemsize"
290+
),
303291
arg_list=arg_list,
304292
args_ty_list=args_ty_list,
305293
arg_num=arg_num,
306294
)
307295
arg_num += 1
308296
# Argument data
309297
self._build_array_attr_arg(
310-
array_val=array_val,
311-
array_attr_pos=array_data_model.get_field_position("data"),
312-
array_attr_ty=array_data_model.get_member_fe_type("data"),
313-
arg_list=arg_list,
314-
args_ty_list=args_ty_list,
315-
arg_num=arg_num,
316-
)
317-
arg_num += 1
318-
# Argument sycl_queue: as the queue pointer is not to be used in a
319-
# kernel we always pass in a nullptr
320-
self._build_arg(
321-
val=nullptr,
322-
typ=types.int64,
298+
array_val=host_array_val,
299+
array_attr_pos=host_array_data_model.get_field_position("data"),
300+
array_attr_ty=kernel_array_data_model.get_member_fe_type("data"),
323301
arg_list=arg_list,
324302
args_ty_list=args_ty_list,
325303
arg_num=arg_num,
326304
)
327305
arg_num += 1
328306
# Arguments for shape
329-
shape_member = array_data_model.get_member_fe_type("shape")
307+
shape_member = kernel_array_data_model.get_member_fe_type("shape")
330308
self._build_unituple_member_arg(
331-
array_val=array_val,
332-
array_attr_pos=array_data_model.get_field_position("shape"),
309+
array_val=host_array_val,
310+
array_attr_pos=host_array_data_model.get_field_position("shape"),
333311
ndims=shape_member.count,
334312
arg_list=arg_list,
335313
args_ty_list=args_ty_list,
336314
arg_num=arg_num,
337315
)
338316
arg_num += shape_member.count
339317
# Arguments for strides
340-
stride_member = array_data_model.get_member_fe_type("strides")
318+
stride_member = kernel_array_data_model.get_member_fe_type("strides")
341319
self._build_unituple_member_arg(
342-
array_val=array_val,
343-
array_attr_pos=array_data_model.get_field_position("strides"),
320+
array_val=host_array_val,
321+
array_attr_pos=host_array_data_model.get_field_position("strides"),
344322
ndims=stride_member.count,
345323
arg_list=arg_list,
346324
args_ty_list=args_ty_list,
@@ -647,10 +625,10 @@ def set_arguments(
647625

648626
# Populate the args_list and the args_ty_list LLVM arrays
649627
self._populate_kernel_args_and_args_ty_arrays(
650-
callargs_ptrs=kernel_args_ptrs,
651-
kernel_argtys=ty_kernel_args,
652-
args_list=args_list,
653-
args_ty_list=args_ty_list,
628+
host_callargs_ptrs=kernel_args_ptrs,
629+
host_kernel_argtys=ty_kernel_args,
630+
kernel_args_list=args_list,
631+
kernel_args_ty_list=args_ty_list,
654632
)
655633

656634
self.arguments.arg_list = args_list
@@ -796,8 +774,8 @@ def _get_num_flattened_kernel_args(
796774
self,
797775
kernel_argtys: tuple[types.Type, ...],
798776
) -> int:
799-
"""Returns number of flattened arguments based on the numba types.
800-
flattens dpnp arrays and complex values."""
777+
"""Returns number of flattened arguments of kernel data model based on
778+
the numba types. Flattens usm arrays and complex values."""
801779
num_flattened_kernel_args = 0
802780
for arg_type in kernel_argtys:
803781
if isinstance(arg_type, USMNdArray):
@@ -812,49 +790,51 @@ def _get_num_flattened_kernel_args(
812790

813791
def _populate_kernel_args_and_args_ty_arrays(
814792
self,
815-
kernel_argtys,
816-
callargs_ptrs,
817-
args_list,
818-
args_ty_list,
793+
host_kernel_argtys,
794+
host_callargs_ptrs,
795+
kernel_args_list,
796+
kernel_args_ty_list,
819797
):
820798
kernel_arg_num = 0
821-
for arg_num, argtype in enumerate(kernel_argtys):
822-
llvm_val = callargs_ptrs[arg_num]
799+
for arg_num, argtype in enumerate(host_kernel_argtys):
800+
host_llvm_val = host_callargs_ptrs[arg_num]
823801
if isinstance(argtype, USMNdArray):
824-
datamodel = self.kernel_dmm.lookup(argtype)
802+
kernel_datamodel = self.kernel_dmm.lookup(argtype)
803+
host_datamodel = self.context.data_model_manager.lookup(argtype)
825804
self._build_array_arg(
826-
array_val=llvm_val,
827-
array_data_model=datamodel,
828-
arg_list=args_list,
829-
args_ty_list=args_ty_list,
805+
host_array_val=host_llvm_val,
806+
host_array_data_model=host_datamodel,
807+
kernel_array_data_model=kernel_datamodel,
808+
arg_list=kernel_args_list,
809+
args_ty_list=kernel_args_ty_list,
830810
arg_num=kernel_arg_num,
831811
)
832-
kernel_arg_num += datamodel.flattened_field_count
812+
kernel_arg_num += kernel_datamodel.flattened_field_count
833813
else:
834814
if argtype == types.complex64:
835815
self._build_complex_arg(
836-
llvm_val,
816+
host_llvm_val,
837817
types.float32,
838-
args_list,
839-
args_ty_list,
818+
kernel_args_list,
819+
kernel_args_ty_list,
840820
kernel_arg_num,
841821
)
842822
kernel_arg_num += 2
843823
elif argtype == types.complex128:
844824
self._build_complex_arg(
845-
llvm_val,
825+
host_llvm_val,
846826
types.float64,
847-
args_list,
848-
args_ty_list,
827+
kernel_args_list,
828+
kernel_args_ty_list,
849829
kernel_arg_num,
850830
)
851831
kernel_arg_num += 2
852832
else:
853833
self._build_arg(
854-
llvm_val,
834+
host_llvm_val,
855835
argtype,
856-
args_list,
857-
args_ty_list,
836+
kernel_args_list,
837+
kernel_args_ty_list,
858838
kernel_arg_num,
859839
)
860840
kernel_arg_num += 1

0 commit comments

Comments
 (0)