@@ -253,10 +253,11 @@ def _build_complex_arg(
253253
254254 def _build_array_arg ( # pylint: disable=too-many-arguments
255255 self ,
256- array_val ,
257- array_data_model ,
258- arg_list ,
259- args_ty_list ,
256+ host_array_val , # llvm_val
257+ host_array_data_model ,
258+ kernel_array_data_model ,
259+ arg_list , # filling this val
260+ args_ty_list , # filling this val
260261 arg_num ,
261262 ):
262263 """Creates a list of LLVM Values for an unpacked USMNdArray kernel
@@ -265,82 +266,59 @@ def _build_array_arg( # pylint: disable=too-many-arguments
265266 The steps performed here are the same as in
266267 numba_dpex.core.kernel_interface.arg_pack_unpacker._unpack_array_helper
267268 """
268- # Argument 1: Null pointer for the NRT_MemInfo attribute of the array
269- nullptr = self ._build_nullptr ()
270- self ._build_arg (
271- val = nullptr ,
272- typ = types .int64 ,
273- arg_list = arg_list ,
274- args_ty_list = args_ty_list ,
275- arg_num = arg_num ,
276- )
277- arg_num += 1
278- # Argument 2: Null pointer for the Parent attribute of the array
279- nullptr = self ._build_nullptr ()
280- self ._build_arg (
281- val = nullptr ,
282- typ = types .int64 ,
283- arg_list = arg_list ,
284- args_ty_list = args_ty_list ,
285- arg_num = arg_num ,
286- )
287- arg_num += 1
269+ # Note: each field is read from the host array struct, so its position
270+ # comes from the host data model, but it is then cast to the device type
271+ # and sent to the device, so its type comes from the kernel data model.
272+ # TODO: refactor the code structure to make this easier to read and
273+ # maintain.
288274 # Argument nitems
289275 self ._build_array_attr_arg (
290- array_val = array_val ,
291- array_attr_pos = array_data_model .get_field_position ("nitems" ),
292- array_attr_ty = array_data_model .get_member_fe_type ("nitems" ),
276+ array_val = host_array_val ,
277+ array_attr_pos = host_array_data_model .get_field_position ("nitems" ),
278+ array_attr_ty = kernel_array_data_model .get_member_fe_type ("nitems" ),
293279 arg_list = arg_list ,
294280 args_ty_list = args_ty_list ,
295281 arg_num = arg_num ,
296282 )
297283 arg_num += 1
298284 # Argument itemsize
299285 self ._build_array_attr_arg (
300- array_val = array_val ,
301- array_attr_pos = array_data_model .get_field_position ("itemsize" ),
302- array_attr_ty = array_data_model .get_member_fe_type ("itemsize" ),
286+ array_val = host_array_val ,
287+ array_attr_pos = host_array_data_model .get_field_position ("itemsize" ),
288+ array_attr_ty = kernel_array_data_model .get_member_fe_type (
289+ "itemsize"
290+ ),
303291 arg_list = arg_list ,
304292 args_ty_list = args_ty_list ,
305293 arg_num = arg_num ,
306294 )
307295 arg_num += 1
308296 # Argument data
309297 self ._build_array_attr_arg (
310- array_val = array_val ,
311- array_attr_pos = array_data_model .get_field_position ("data" ),
312- array_attr_ty = array_data_model .get_member_fe_type ("data" ),
313- arg_list = arg_list ,
314- args_ty_list = args_ty_list ,
315- arg_num = arg_num ,
316- )
317- arg_num += 1
318- # Argument sycl_queue: as the queue pointer is not to be used in a
319- # kernel we always pass in a nullptr
320- self ._build_arg (
321- val = nullptr ,
322- typ = types .int64 ,
298+ array_val = host_array_val ,
299+ array_attr_pos = host_array_data_model .get_field_position ("data" ),
300+ array_attr_ty = kernel_array_data_model .get_member_fe_type ("data" ),
323301 arg_list = arg_list ,
324302 args_ty_list = args_ty_list ,
325303 arg_num = arg_num ,
326304 )
327305 arg_num += 1
328306 # Arguments for shape
329- shape_member = array_data_model .get_member_fe_type ("shape" )
307+ shape_member = kernel_array_data_model .get_member_fe_type ("shape" )
330308 self ._build_unituple_member_arg (
331- array_val = array_val ,
332- array_attr_pos = array_data_model .get_field_position ("shape" ),
309+ array_val = host_array_val ,
310+ array_attr_pos = host_array_data_model .get_field_position ("shape" ),
333311 ndims = shape_member .count ,
334312 arg_list = arg_list ,
335313 args_ty_list = args_ty_list ,
336314 arg_num = arg_num ,
337315 )
338316 arg_num += shape_member .count
339317 # Arguments for strides
340- stride_member = array_data_model .get_member_fe_type ("strides" )
318+ stride_member = kernel_array_data_model .get_member_fe_type ("strides" )
341319 self ._build_unituple_member_arg (
342- array_val = array_val ,
343- array_attr_pos = array_data_model .get_field_position ("strides" ),
320+ array_val = host_array_val ,
321+ array_attr_pos = host_array_data_model .get_field_position ("strides" ),
344322 ndims = stride_member .count ,
345323 arg_list = arg_list ,
346324 args_ty_list = args_ty_list ,
@@ -647,10 +625,10 @@ def set_arguments(
647625
648626 # Populate the args_list and the args_ty_list LLVM arrays
649627 self ._populate_kernel_args_and_args_ty_arrays (
650- callargs_ptrs = kernel_args_ptrs ,
651- kernel_argtys = ty_kernel_args ,
652- args_list = args_list ,
653- args_ty_list = args_ty_list ,
628+ host_callargs_ptrs = kernel_args_ptrs ,
629+ host_kernel_argtys = ty_kernel_args ,
630+ kernel_args_list = args_list ,
631+ kernel_args_ty_list = args_ty_list ,
654632 )
655633
656634 self .arguments .arg_list = args_list
@@ -796,8 +774,8 @@ def _get_num_flattened_kernel_args(
796774 self ,
797775 kernel_argtys : tuple [types .Type , ...],
798776 ) -> int :
799- """Returns number of flattened arguments based on the numba types.
800- flattens dpnp arrays and complex values."""
777+ """Returns the number of flattened arguments of the kernel data model
778+ based on the numba types. Flattens USM arrays and complex values."""
801779 num_flattened_kernel_args = 0
802780 for arg_type in kernel_argtys :
803781 if isinstance (arg_type , USMNdArray ):
@@ -812,49 +790,51 @@ def _get_num_flattened_kernel_args(
812790
813791 def _populate_kernel_args_and_args_ty_arrays (
814792 self ,
815- kernel_argtys ,
816- callargs_ptrs ,
817- args_list ,
818- args_ty_list ,
793+ host_kernel_argtys ,
794+ host_callargs_ptrs ,
795+ kernel_args_list ,
796+ kernel_args_ty_list ,
819797 ):
820798 kernel_arg_num = 0
821- for arg_num , argtype in enumerate (kernel_argtys ):
822- llvm_val = callargs_ptrs [arg_num ]
799+ for arg_num , argtype in enumerate (host_kernel_argtys ):
800+ host_llvm_val = host_callargs_ptrs [arg_num ]
823801 if isinstance (argtype , USMNdArray ):
824- datamodel = self .kernel_dmm .lookup (argtype )
802+ kernel_datamodel = self .kernel_dmm .lookup (argtype )
803+ host_datamodel = self .context .data_model_manager .lookup (argtype )
825804 self ._build_array_arg (
826- array_val = llvm_val ,
827- array_data_model = datamodel ,
828- arg_list = args_list ,
829- args_ty_list = args_ty_list ,
805+ host_array_val = host_llvm_val ,
806+ host_array_data_model = host_datamodel ,
807+ kernel_array_data_model = kernel_datamodel ,
808+ arg_list = kernel_args_list ,
809+ args_ty_list = kernel_args_ty_list ,
830810 arg_num = kernel_arg_num ,
831811 )
832- kernel_arg_num += datamodel .flattened_field_count
812+ kernel_arg_num += kernel_datamodel .flattened_field_count
833813 else :
834814 if argtype == types .complex64 :
835815 self ._build_complex_arg (
836- llvm_val ,
816+ host_llvm_val ,
837817 types .float32 ,
838- args_list ,
839- args_ty_list ,
818+ kernel_args_list ,
819+ kernel_args_ty_list ,
840820 kernel_arg_num ,
841821 )
842822 kernel_arg_num += 2
843823 elif argtype == types .complex128 :
844824 self ._build_complex_arg (
845- llvm_val ,
825+ host_llvm_val ,
846826 types .float64 ,
847- args_list ,
848- args_ty_list ,
827+ kernel_args_list ,
828+ kernel_args_ty_list ,
849829 kernel_arg_num ,
850830 )
851831 kernel_arg_num += 2
852832 else :
853833 self ._build_arg (
854- llvm_val ,
834+ host_llvm_val ,
855835 argtype ,
856- args_list ,
857- args_ty_list ,
836+ kernel_args_list ,
837+ kernel_args_ty_list ,
858838 kernel_arg_num ,
859839 )
860840 kernel_arg_num += 1
0 commit comments