@@ -374,78 +374,95 @@ def get_queue(self, exec_queue: dpctl.SyclQueue) -> llvmir.Instruction:
374
374
)
375
375
return self .builder .load (sycl_queue_val )
376
376
377
def _allocate_array(
    self, numba_type: types.Type, size: int
) -> llvmir.Instruction:
    """Allocates an LLVM array with ``size`` elements of ``numba_type``.

    The allocation is emitted through ``cgutils.alloca_once`` using this
    object's builder.

    Args:
        numba_type: Numba type of a single element; it is lowered to an
            LLVM type with ``self.context.get_value_type``.
        size: Number of elements to allocate.

    Returns: An LLVM IR value pointing to the array.
    """
    element_ll_type = self.context.get_value_type(numba_type)
    element_count = self.context.get_constant(types.uintp, size)

    return cgutils.alloca_once(
        self.builder,
        element_ll_type,
        size=element_count,
    )
394
def _populate_array_from_python_list(
    self,
    numba_type: types.Type,
    py_array: list[llvmir.Instruction],
    ll_array: llvmir.Instruction,
    force_cast: bool = False,
):
    """Stores every LLVM value of a Python list into an LLVM array.

    Element ``idx`` of ``py_array`` is written to position ``idx`` of
    ``ll_array``.

    Args:
        numba_type: Numba type the values are bitcast to before being
            stored. It is only consulted when ``force_cast`` is set.
        py_array: Python list of LLVM IR values to copy into the array.
        ll_array: LLVM IR value pointing to the destination array.
        force_cast: Whether to bitcast each value to the LLVM type of
            ``numba_type`` before storing it.
    """
    # The target type is the same for every element, so it is looked up
    # once instead of on each iteration.
    cast_type = (
        self.context.get_value_type(numba_type) if force_cast else None
    )

    for idx, ll_value in enumerate(py_array):
        ll_array_dst = self.builder.gep(
            ll_array,
            [self.context.get_constant(types.int32, idx)],
        )
        # bitcast may be extra, but won't hurt,
        if force_cast:
            ll_value = self.builder.bitcast(ll_value, cast_type)
        self.builder.store(ll_value, ll_array_dst)
421
+
422
def _create_ll_from_py_list(
    self,
    numba_type: types.Type,
    list_of_ll_values: list[llvmir.Instruction],
    force_cast: bool = False,
) -> llvmir.Instruction:
    """Builds an LLVM IR array holding the values of a Python list.

    The array is allocated with as many elements as the list has and is
    then filled with the list's LLVM IR values, in list order.

    Args:
        numba_type: Numba type of the array elements.
        list_of_ll_values: Python list of LLVM IR values to store.
        force_cast: Whether to bitcast each value to ``numba_type``
            before storing it.

    Returns: An LLVM IR value pointing to the array.
    """
    element_count = len(list_of_ll_values)
    ll_array = self._allocate_array(numba_type, element_count)

    self._populate_array_from_python_list(
        numba_type,
        list_of_ll_values,
        ll_array,
        force_cast,
    )

    return ll_array
413
445
414
446
def _create_sycl_range(self, idx_range):
    """Allocate an array to store the extents of a sycl::range.

    The array holds exactly one element per entry of ``idx_range``. Each
    extent whose LLVM type is not already a 64-bit integer is
    sign-extended to one, and the extents are stored in reverse order.

    Args:
        idx_range: Sequence of LLVM IR values, one per dimension of the
            range.

    Returns: An LLVM IR value pointing to the array of extents.
    """
    int64_t = utils.LLVMTypes.int64_t

    int64_range = []
    for rext in idx_range:
        if rext.type != int64_t:
            rext = self.builder.sext(rext, int64_t)
        int64_range.append(rext)

    # we reverse the global range to account for how sycl and opencl
    # range differs
    int64_range.reverse()

    # NOTE(review): the stored values are i64 while the array element type
    # comes from types.uintp; presumably both lower to the same LLVM type
    # on the supported targets -- confirm.
    return self._create_ll_from_py_list(types.uintp, int64_range)
449
466
450
467
def set_kernel (self , sycl_kernel_ref : llvmir .Instruction ):
451
468
"""Sets kernel to the argument list."""
@@ -597,10 +614,14 @@ def set_arguments(
597
614
)
598
615
599
616
# Create LLVM values for the kernel args list and kernel arg types list
600
- args_list = self ._allocate_kernel_arg_array (num_flattened_kernel_args )
617
+ args_list = self ._allocate_array (
618
+ types .voidptr ,
619
+ num_flattened_kernel_args ,
620
+ )
601
621
602
- args_ty_list = self ._allocate_kernel_arg_ty_array (
603
- num_flattened_kernel_args
622
+ args_ty_list = self ._allocate_array (
623
+ types .int32 ,
624
+ num_flattened_kernel_args ,
604
625
)
605
626
606
627
kernel_args_ptrs = []
@@ -624,20 +645,17 @@ def set_arguments(
624
645
types .uintp , num_flattened_kernel_args
625
646
)
626
647
627
def _extract_llvm_values_from_tuple(
    self,
    ll_tuple: llvmir.Instruction,
) -> list[llvmir.Instruction]:
    """Unpacks an LLVM IR tuple into a Python list of its elements.

    Args:
        ll_tuple: LLVM IR value of the tuple; the number of elements is
            taken from the length of its LLVM type.

    Returns: A Python list with one extracted LLVM IR value per tuple
        position, in positional order.
    """
    element_count = len(ll_tuple.type)

    return [
        self.builder.extract_value(ll_tuple, pos)
        for pos in range(element_count)
    ]
641
659
642
660
def set_arguments_form_tuple (
643
661
self ,
@@ -647,27 +665,45 @@ def set_arguments_form_tuple(
647
665
"""Sets flattened kernel args, kernel arg types and number of those
648
666
arguments to the argument list based on the arguments stored in tuple.
649
667
"""
650
- kernel_args = self ._extract_arguments_from_tuple (
651
- ty_kernel_args_tuple , ll_kernel_args_tuple
652
- )
668
+ kernel_args = self ._extract_llvm_values_from_tuple (ll_kernel_args_tuple )
653
669
self .set_arguments (ty_kernel_args_tuple , kernel_args )
654
670
655
def set_dependent_events(self, dep_events: list[llvmir.Instruction]):
    """Stores the dependent events and their count in the argument list.

    Args:
        dep_events: Python list of LLVM IR values, one per dependent
            event. They are copied into a newly created LLVM array of
            ``types.voidptr`` elements.
    """
    self.arguments.dep_events = self._create_ll_from_py_list(
        types.voidptr,
        dep_events,
    )

    num_dep_events = len(dep_events)
    self.arguments.dep_events_len = self.context.get_constant(
        types.uintp,
        num_dep_events,
    )
679
def set_dependent_events_from_tuple(
    self,
    ty_dependent_events: UniTuple,
    ll_dependent_events: llvmir.Instruction,
):
    """Sets dependent events from a tuple represented by LLVM IR.

    Args:
        ty_dependent_events: Numba type of the tuple. Its length decides
            whether there is anything to unpack, and its first element
            is used as the type of every event.
        ll_dependent_events: tuple of numba's data models.
    """
    if len(ty_dependent_events) == 0:
        self.set_dependent_events([])
        return

    # Every element is proxied with the type of the first one, so the
    # proxy class is built once rather than once per event.
    event_proxy_cls = cgutils.create_struct_proxy(ty_dependent_events[0])

    dm_dependent_events = self._extract_llvm_values_from_tuple(
        ll_dependent_events
    )
    dependent_events = [
        event_proxy_cls(
            self.context,
            self.builder,
            value=dm_dependent_event,
        ).event_ref
        for dm_dependent_event in dm_dependent_events
    ]

    self.set_dependent_events(dependent_events)
671
707
672
708
def submit (self ) -> llvmir .Instruction :
673
709
"""Submits kernel by calling sycl.dpctl_queue_submit_range or
@@ -708,22 +744,7 @@ def _allocate_meminfo_array(
708
744
)
709
745
]
710
746
711
- meminfo_list = cgutils .alloca_once (
712
- self .builder ,
713
- utils .get_llvm_type (context = self .context , type = types .voidptr ),
714
- size = self .context .get_constant (types .uintp , len (meminfos )),
715
- )
716
-
717
- for meminfo_num , meminfo in enumerate (meminfos ):
718
- meminfo_arg_dst = self .builder .gep (
719
- meminfo_list ,
720
- [self .context .get_constant (types .int32 , meminfo_num )],
721
- )
722
- meminfo_ptr = self .builder .bitcast (
723
- meminfo ,
724
- utils .get_llvm_type (context = self .context , type = types .voidptr ),
725
- )
726
- self .builder .store (meminfo_ptr , meminfo_arg_dst )
747
+ meminfo_list = self ._create_ll_from_py_list (types .voidptr , meminfos )
727
748
728
749
return len (meminfos ), meminfo_list
729
750
0 commit comments