@@ -253,10 +253,11 @@ def _build_complex_arg(
253
253
254
254
def _build_array_arg ( # pylint: disable=too-many-arguments
255
255
self ,
256
- array_val ,
257
- array_data_model ,
258
- arg_list ,
259
- args_ty_list ,
256
+ host_array_val , # llvm_val
257
+ host_array_data_model ,
258
+ kernel_array_data_model ,
259
+ arg_list , # filling this val
260
+ args_ty_list , # filling this val
260
261
arg_num ,
261
262
):
262
263
"""Creates a list of LLVM Values for an unpacked USMNdArray kernel
@@ -265,82 +266,59 @@ def _build_array_arg( # pylint: disable=too-many-arguments
265
266
The steps performed here are the same as in
266
267
numba_dpex.core.kernel_interface.arg_pack_unpacker._unpack_array_helper
267
268
"""
268
- # Argument 1: Null pointer for the NRT_MemInfo attribute of the array
269
- nullptr = self ._build_nullptr ()
270
- self ._build_arg (
271
- val = nullptr ,
272
- typ = types .int64 ,
273
- arg_list = arg_list ,
274
- args_ty_list = args_ty_list ,
275
- arg_num = arg_num ,
276
- )
277
- arg_num += 1
278
- # Argument 2: Null pointer for the Parent attribute of the array
279
- nullptr = self ._build_nullptr ()
280
- self ._build_arg (
281
- val = nullptr ,
282
- typ = types .int64 ,
283
- arg_list = arg_list ,
284
- args_ty_list = args_ty_list ,
285
- arg_num = arg_num ,
286
- )
287
- arg_num += 1
269
+ # This may look confusing: each field position is looked up in the host
270
+ # data model, while the value is cast to the kernel (device) type before
271
+ # it is sent to the device. Keeping that in mind tells you which data
272
+ # model to use where. TODO: refactor this code to make it easier to read
273
+ # and maintain.
288
274
# Argument nitems
289
275
self ._build_array_attr_arg (
290
- array_val = array_val ,
291
- array_attr_pos = array_data_model .get_field_position ("nitems" ),
292
- array_attr_ty = array_data_model .get_member_fe_type ("nitems" ),
276
+ array_val = host_array_val ,
277
+ array_attr_pos = host_array_data_model .get_field_position ("nitems" ),
278
+ array_attr_ty = kernel_array_data_model .get_member_fe_type ("nitems" ),
293
279
arg_list = arg_list ,
294
280
args_ty_list = args_ty_list ,
295
281
arg_num = arg_num ,
296
282
)
297
283
arg_num += 1
298
284
# Argument itemsize
299
285
self ._build_array_attr_arg (
300
- array_val = array_val ,
301
- array_attr_pos = array_data_model .get_field_position ("itemsize" ),
302
- array_attr_ty = array_data_model .get_member_fe_type ("itemsize" ),
286
+ array_val = host_array_val ,
287
+ array_attr_pos = host_array_data_model .get_field_position ("itemsize" ),
288
+ array_attr_ty = kernel_array_data_model .get_member_fe_type (
289
+ "itemsize"
290
+ ),
303
291
arg_list = arg_list ,
304
292
args_ty_list = args_ty_list ,
305
293
arg_num = arg_num ,
306
294
)
307
295
arg_num += 1
308
296
# Argument data
309
297
self ._build_array_attr_arg (
310
- array_val = array_val ,
311
- array_attr_pos = array_data_model .get_field_position ("data" ),
312
- array_attr_ty = array_data_model .get_member_fe_type ("data" ),
313
- arg_list = arg_list ,
314
- args_ty_list = args_ty_list ,
315
- arg_num = arg_num ,
316
- )
317
- arg_num += 1
318
- # Argument sycl_queue: as the queue pointer is not to be used in a
319
- # kernel we always pass in a nullptr
320
- self ._build_arg (
321
- val = nullptr ,
322
- typ = types .int64 ,
298
+ array_val = host_array_val ,
299
+ array_attr_pos = host_array_data_model .get_field_position ("data" ),
300
+ array_attr_ty = kernel_array_data_model .get_member_fe_type ("data" ),
323
301
arg_list = arg_list ,
324
302
args_ty_list = args_ty_list ,
325
303
arg_num = arg_num ,
326
304
)
327
305
arg_num += 1
328
306
# Arguments for shape
329
- shape_member = array_data_model .get_member_fe_type ("shape" )
307
+ shape_member = kernel_array_data_model .get_member_fe_type ("shape" )
330
308
self ._build_unituple_member_arg (
331
- array_val = array_val ,
332
- array_attr_pos = array_data_model .get_field_position ("shape" ),
309
+ array_val = host_array_val ,
310
+ array_attr_pos = host_array_data_model .get_field_position ("shape" ),
333
311
ndims = shape_member .count ,
334
312
arg_list = arg_list ,
335
313
args_ty_list = args_ty_list ,
336
314
arg_num = arg_num ,
337
315
)
338
316
arg_num += shape_member .count
339
317
# Arguments for strides
340
- stride_member = array_data_model .get_member_fe_type ("strides" )
318
+ stride_member = kernel_array_data_model .get_member_fe_type ("strides" )
341
319
self ._build_unituple_member_arg (
342
- array_val = array_val ,
343
- array_attr_pos = array_data_model .get_field_position ("strides" ),
320
+ array_val = host_array_val ,
321
+ array_attr_pos = host_array_data_model .get_field_position ("strides" ),
344
322
ndims = stride_member .count ,
345
323
arg_list = arg_list ,
346
324
args_ty_list = args_ty_list ,
@@ -647,10 +625,10 @@ def set_arguments(
647
625
648
626
# Populate the args_list and the args_ty_list LLVM arrays
649
627
self ._populate_kernel_args_and_args_ty_arrays (
650
- callargs_ptrs = kernel_args_ptrs ,
651
- kernel_argtys = ty_kernel_args ,
652
- args_list = args_list ,
653
- args_ty_list = args_ty_list ,
628
+ host_callargs_ptrs = kernel_args_ptrs ,
629
+ host_kernel_argtys = ty_kernel_args ,
630
+ kernel_args_list = args_list ,
631
+ kernel_args_ty_list = args_ty_list ,
654
632
)
655
633
656
634
self .arguments .arg_list = args_list
@@ -796,8 +774,8 @@ def _get_num_flattened_kernel_args(
796
774
self ,
797
775
kernel_argtys : tuple [types .Type , ...],
798
776
) -> int :
799
- """Returns number of flattened arguments based on the numba types.
800
- flattens dpnp arrays and complex values."""
777
+ """Returns number of flattened arguments of kernel data model based on
778
+ the numba types. Flattens usm arrays and complex values."""
801
779
num_flattened_kernel_args = 0
802
780
for arg_type in kernel_argtys :
803
781
if isinstance (arg_type , USMNdArray ):
@@ -812,49 +790,51 @@ def _get_num_flattened_kernel_args(
812
790
813
791
def _populate_kernel_args_and_args_ty_arrays(
    self,
    host_kernel_argtys,
    host_callargs_ptrs,
    kernel_args_list,
    kernel_args_ty_list,
):
    """Fill the kernel argument list and argument-type list.

    Walks the host-side argument types in order, picks the matching entry
    of ``host_callargs_ptrs`` by index, and delegates to the builder
    helper that fits the type:

    * ``USMNdArray``: ``_build_array_arg`` is called with both the host
      data model (from ``self.context.data_model_manager``) and the
      kernel data model (from ``self.kernel_dmm``); the kernel slot
      counter then advances by the kernel model's
      ``flattened_field_count``.
    * ``types.complex64`` / ``types.complex128``: ``_build_complex_arg``
      is called with ``types.float32`` / ``types.float64`` respectively
      and the slot counter advances by 2.
    * anything else: ``_build_arg`` is called with the type itself and
      the slot counter advances by 1.

    Args:
        host_kernel_argtys: Iterable of argument types as seen by the
            host.
        host_callargs_ptrs: Indexable collection holding one value per
            entry of ``host_kernel_argtys``.
        kernel_args_list: Destination handed to the builder helpers as
            the argument list to fill.
        kernel_args_ty_list: Destination handed to the builder helpers
            as the argument-type list to fill.
    """
    # Index of the next free slot in the flattened kernel argument lists.
    kernel_arg_num = 0
    for host_index, host_ty in enumerate(host_kernel_argtys):
        host_ptr = host_callargs_ptrs[host_index]
        if isinstance(host_ty, USMNdArray):
            # Arrays need both models: kernel-side and host-side.
            kernel_model = self.kernel_dmm.lookup(host_ty)
            host_model = self.context.data_model_manager.lookup(host_ty)
            self._build_array_arg(
                host_array_val=host_ptr,
                host_array_data_model=host_model,
                kernel_array_data_model=kernel_model,
                arg_list=kernel_args_list,
                args_ty_list=kernel_args_ty_list,
                arg_num=kernel_arg_num,
            )
            kernel_arg_num += kernel_model.flattened_field_count
        elif host_ty == types.complex64:
            # A complex value occupies two kernel argument slots.
            self._build_complex_arg(
                host_ptr,
                types.float32,
                kernel_args_list,
                kernel_args_ty_list,
                kernel_arg_num,
            )
            kernel_arg_num += 2
        elif host_ty == types.complex128:
            self._build_complex_arg(
                host_ptr,
                types.float64,
                kernel_args_list,
                kernel_args_ty_list,
                kernel_arg_num,
            )
            kernel_arg_num += 2
        else:
            # Every other type is passed through as a single argument.
            self._build_arg(
                host_ptr,
                host_ty,
                kernel_args_list,
                kernel_args_ty_list,
                kernel_arg_num,
            )
            kernel_arg_num += 1
0 commit comments