5
5
"""Module that contains numba style wrapper around sycl kernel submit."""
6
6
7
7
from dataclasses import dataclass
8
+ from functools import cached_property
9
+ from typing import NamedTuple , Union
8
10
9
11
import dpctl
10
12
from llvmlite import ir as llvmir
11
13
from llvmlite .ir .builder import IRBuilder
12
14
from numba .core import cgutils , types
13
15
from numba .core .cpu import CPUContext
14
16
from numba .core .datamodel import DataModelManager
17
+ from numba .core .types .containers import UniTuple
15
18
16
19
from numba_dpex import config , utils
20
+ from numba_dpex .core .exceptions import UnreachableError
17
21
from numba_dpex .core .runtime .context import DpexRTContext
18
22
from numba_dpex .core .types import DpnpNdArray
23
+ from numba_dpex .core .types .range_types import NdRangeType , RangeType
19
24
from numba_dpex .dpctl_iface import libsyclinterface_bindings as sycl
20
25
from numba_dpex .dpctl_iface ._helpers import numba_type_to_dpctl_typenum
26
+ from numba_dpex .utils import create_null_ptr
21
27
22
28
MAX_SIZE_OF_SYCL_RANGE = 3
23
29
24
30
31
# TODO: This is probably not the best place for this class. It should move to
# kernel_dispatcher once experimental is merged; right now that would cause a
# cyclic import.
33
class SPIRVKernelModule(NamedTuple):
    """A SPIR-V binary paired with the name of the kernel function that
    lives inside that binary."""

    # Name of the kernel function inside the SPIR-V binary.
    kernel_name: str
    # Raw bytes of the SPIR-V binary.
    kernel_bitcode: bytes
38
+
39
+
25
40
@dataclass
26
41
class _KernelLaunchIRArguments : # pylint: disable=too-many-instance-attributes
27
42
"""List of kernel launch arguments used in sycl.dpctl_queue_submit_range and
@@ -62,6 +77,22 @@ def to_list(self):
62
77
return res
63
78
64
79
80
+ @dataclass
81
+ class _KernelLaunchIRCachedArguments :
82
+ """Arguments that are being used in KernelLaunchIRBuilder that are either
83
+ intermediate structure of the KernelLaunchIRBuilder like llvm IR array
84
+ stored as a python array of llvm IR values or llvm IR values that may be
85
+ used as an input for builder functions.
86
+
87
+ Main goal is to prevent passing same argument during build process several
88
+ times and to avoid passing output of the builder as an argument for another
89
+ build method."""
90
+
91
+ arg_list : list [llvmir .Instruction ] = None
92
+ arg_ty_list : list [types .Type ] = None
93
+ device_event_ref : llvmir .Instruction = None
94
+
95
+
65
96
class KernelLaunchIRBuilder :
66
97
"""
67
98
KernelLaunchIRBuilder(lowerer, cres)
@@ -86,9 +117,16 @@ def __init__(
86
117
"""
87
118
self .context = context
88
119
self .builder = builder
89
- self .rtctx = DpexRTContext (self .context )
90
120
self .arguments = _KernelLaunchIRArguments ()
121
+ self .cached_arguments = _KernelLaunchIRCachedArguments ()
91
122
self .kernel_dmm = kernel_dmm
123
+ self ._cleanups = []
124
+
125
@cached_property
def dpexrt(self):
    """Dpex runtime context for this builder's target context.

    Because this is a ``cached_property``, the ``DpexRTContext`` is created
    on first access and the same object is returned on later accesses.
    """
    runtime_context = DpexRTContext(self.context)
    return runtime_context
92
130
93
131
def _build_nullptr (self ):
94
132
"""Builds the LLVM IR to represent a null pointer.
@@ -329,7 +367,7 @@ def get_queue(self, exec_queue: dpctl.SyclQueue) -> llvmir.Instruction:
329
367
# Store the queue returned by DPEXRTQueue_CreateFromFilterString in a
330
368
# local variable
331
369
self .builder .store (
332
- self .rtctx .get_queue_from_filter_string (
370
+ self .dpexrt .get_queue_from_filter_string (
333
371
builder = self .builder , device = device
334
372
),
335
373
sycl_queue_val ,
@@ -413,10 +451,76 @@ def set_kernel(self, sycl_kernel_ref: llvmir.Instruction):
413
451
"""Sets kernel to the argument list."""
414
452
self .arguments .sycl_kernel_ref = sycl_kernel_ref
415
453
454
def set_kernel_from_spirv(self, kernel_module: SPIRVKernelModule):
    """Sets the kernel launch argument from a SPIR-V module.

    The SPIR-V binary and the kernel name are embedded in the current LLVM
    module as constants. They are then passed, together with the context and
    device of the previously set queue, to the dpex runtime's
    ``build_or_get_kernel``. The kernel is cached, so it is not sent to the
    device a second time.

    A cleanup that deletes the kernel reference is appended to
    ``self._cleanups``.

    Args:
        kernel_module: SPIR-V binary and the name of the kernel inside it.
    """
    bitcode = kernel_module.kernel_bitcode
    i64 = llvmir.IntType(64)
    queue_ref = self.arguments.sycl_queue_ref

    # Global constant byte string in the current LLVM module that stores the
    # passed in SPIR-V binary blob.
    kernel_bc_byte_str = self.context.insert_const_bytes(
        self.builder.module,
        bytes=bitcode,
    )
    kernel_name = self.context.insert_const_string(
        self.builder.module, kernel_module.kernel_name
    )

    context_ref = sycl.dpctl_queue_get_context(self.builder, queue_ref)
    device_ref = sycl.dpctl_queue_get_device(self.builder, queue_ref)

    bitcode_hash = llvmir.Constant(i64, hash(bitcode))
    bitcode_size = llvmir.Constant(i64, len(bitcode))
    null_ptr = self.builder.load(create_null_ptr(self.builder, self.context))

    # build_or_get_kernel steals the references to the context and the
    # device because it needs to keep them alive as cache keys.
    build_args = [
        context_ref,
        device_ref,
        bitcode_hash,
        kernel_bc_byte_str,
        bitcode_size,
        null_ptr,
        kernel_name,
    ]
    kernel_ref = self.dpexrt.build_or_get_kernel(self.builder, build_args)

    self._cleanups.append(self._clean_kernel_ref)
    self.set_kernel(kernel_ref)
498
+
499
def _clean_kernel_ref(self):
    """Emits the call that deletes the kernel reference held in the launch
    arguments and then clears that reference."""
    launch_args = self.arguments
    sycl.dpctl_kernel_delete(self.builder, launch_args.sycl_kernel_ref)
    launch_args.sycl_kernel_ref = None
502
+
416
503
def set_queue(self, sycl_queue_ref: llvmir.Instruction):
    """Stores the sycl queue reference in the kernel launch argument list.

    Args:
        sycl_queue_ref: LLVM IR value referring to the sycl queue.
    """
    launch_args = self.arguments
    launch_args.sycl_queue_ref = sycl_queue_ref
419
506
507
def set_queue_from_arguments(self):
    """Sets the sycl queue taken from the first DpnpNdArray among the kernel
    arguments that were cached earlier.

    Raises:
        ValueError: if no queue reference could be obtained from the cached
            arguments.
    """
    cached = self.cached_arguments
    queue_ref = get_queue_from_llvm_values(
        self.context,
        self.builder,
        cached.arg_ty_list,
        cached.arg_list,
    )

    if queue_ref is not None:
        self.set_queue(queue_ref)
        return

    raise ValueError("There are no arguments that contain queue")
523
+
420
524
def set_range (
421
525
self ,
422
526
global_range : list ,
@@ -430,10 +534,52 @@ def set_range(
430
534
types .uintp , len (global_range )
431
535
)
432
536
537
def set_range_from_indexer(
    self,
    ty_indexer_arg: Union[RangeType, NdRangeType],
    ll_index_arg: llvmir.BaseStructType,
):
    """Unboxes the extents of a Range or NdRange value and sets them as the
    kernel launch range.

    For a ``RangeType`` only the global extents (fields ``dim<N>``) are
    extracted and the local extent list stays empty. For an ``NdRangeType``
    both the global (``gdim<N>``) and the local (``ldim<N>``) extents are
    extracted. The two lists are then passed to ``set_range``.

    Args:
        ty_indexer_arg: Numba type of the indexer.
        ll_index_arg: LLVM IR value of the indexer.

    Raises:
        UnreachableError: if the indexer type is neither ``RangeType`` nor
            ``NdRangeType``.
    """
    ndim = ty_indexer_arg.ndim
    indexer_datamodel = self.context.data_model_manager.lookup(ty_indexer_arg)

    def extent(field_prefix, dim_num):
        # Look up the struct field by name, then pull its value out of the
        # indexer struct.
        field_pos = indexer_datamodel.get_field_position(
            f"{field_prefix}{dim_num}"
        )
        return self.builder.extract_value(ll_index_arg, field_pos)

    global_range_extents = []
    local_range_extents = []

    if isinstance(ty_indexer_arg, RangeType):
        global_range_extents = [extent("dim", dim) for dim in range(ndim)]
    elif isinstance(ty_indexer_arg, NdRangeType):
        # Global and local extents are extracted pairwise per dimension.
        for dim in range(ndim):
            global_range_extents.append(extent("gdim", dim))
            local_range_extents.append(extent("ldim", dim))
    else:
        raise UnreachableError

    self.set_range(global_range_extents, local_range_extents)
578
+
433
579
def set_arguments (
434
580
self ,
435
- ty_kernel_args : list ,
436
- kernel_args : list ,
581
+ ty_kernel_args : list [ types . Type ] ,
582
+ kernel_args : list [ llvmir . Instruction ] ,
437
583
):
438
584
"""Sets flattened kernel args, kernel arg types and number of those
439
585
arguments to the argument list."""
@@ -443,6 +589,9 @@ def set_arguments(
443
589
"DPEX-DEBUG: Populating kernel args and arg type arrays.\n " ,
444
590
)
445
591
592
+ self .cached_arguments .arg_ty_list = ty_kernel_args
593
+ self .cached_arguments .arg_list = kernel_args
594
+
446
595
num_flattened_kernel_args = self ._get_num_flattened_kernel_args (
447
596
kernel_argtys = ty_kernel_args ,
448
597
)
@@ -475,6 +624,34 @@ def set_arguments(
475
624
types .uintp , num_flattened_kernel_args
476
625
)
477
626
627
+ def _extract_arguments_from_tuple (
628
+ self ,
629
+ ty_kernel_args_tuple : UniTuple ,
630
+ ll_kernel_args_tuple : llvmir .Instruction ,
631
+ ) -> list [llvmir .Instruction ]:
632
+ """Extracts LLVM IR values from llvm tuple into python array."""
633
+
634
+ kernel_args = []
635
+ for pos in range (len (ty_kernel_args_tuple )):
636
+ kernel_args .append (
637
+ self .builder .extract_value (ll_kernel_args_tuple , pos )
638
+ )
639
+
640
+ return kernel_args
641
+
642
def set_arguments_form_tuple(
    self,
    ty_kernel_args_tuple: UniTuple,
    ll_kernel_args_tuple: llvmir.Instruction,
):
    """Tuple-based variant of ``set_arguments``.

    The kernel argument values are first extracted from the LLVM tuple and
    then passed on to ``set_arguments`` together with the tuple type.
    """
    self.set_arguments(
        ty_kernel_args_tuple,
        self._extract_arguments_from_tuple(
            ty_kernel_args_tuple, ll_kernel_args_tuple
        ),
    )
654
+
478
655
def set_dependant_event_list (self , dep_events : list [llvmir .Instruction ]):
479
656
"""Sets dependant events to the argument list."""
480
657
if self .arguments .dep_events is not None :
@@ -499,11 +676,86 @@ def submit(self) -> llvmir.Instruction:
499
676
args = self .arguments .to_list ()
500
677
501
678
if self .arguments .local_range is None :
502
- eref = sycl .dpctl_queue_submit_range (self .builder , * args )
679
+ event_ref = sycl .dpctl_queue_submit_range (self .builder , * args )
503
680
else :
504
- eref = sycl .dpctl_queue_submit_ndrange (self .builder , * args )
681
+ event_ref = sycl .dpctl_queue_submit_ndrange (self .builder , * args )
682
+
683
+ self .cached_arguments .device_event_ref = event_ref
505
684
506
- return eref
685
+ for cleanup in self ._cleanups :
686
+ cleanup ()
687
+
688
+ return event_ref
689
+
690
def _allocate_meminfo_array(
    self,
) -> tuple[int, llvmir.Instruction]:
    """Allocates an LLVM array that holds the meminfo of every cached kernel
    argument and populates it.

    The meminfos are collected from the cached kernel arguments through
    ``self.context.nrt.get_meminfos``. Each one is bitcast to a void pointer
    and stored into the array.

    Returns:
        A pair ``(count, array)``: the number of meminfos collected and the
        LLVM IR pointer to the allocated array of void pointers.
    """
    kernel_args = self.cached_arguments.arg_list
    kernel_argtys = self.cached_arguments.arg_ty_list

    meminfos = []
    for arg_num, argtype in enumerate(kernel_argtys):
        # get_meminfos yields (type, meminfo) pairs; only the meminfo value
        # is needed here.
        meminfos.extend(
            meminfo
            for _, meminfo in self.context.nrt.get_meminfos(
                self.builder, argtype, kernel_args[arg_num]
            )
        )

    voidptr_ty = utils.get_llvm_type(context=self.context, type=types.voidptr)

    meminfo_list = cgutils.alloca_once(
        self.builder,
        voidptr_ty,
        size=self.context.get_constant(types.uintp, len(meminfos)),
    )

    for meminfo_num, meminfo in enumerate(meminfos):
        meminfo_arg_dst = self.builder.gep(
            meminfo_list,
            [self.context.get_constant(types.int32, meminfo_num)],
        )
        meminfo_ptr = self.builder.bitcast(meminfo, voidptr_ty)
        self.builder.store(meminfo_ptr, meminfo_arg_dst)

    return len(meminfos), meminfo_list
729
+
730
def acquire_meminfo_and_submit_release(
    self,
) -> llvmir.Instruction:
    """Schedules a sycl host task that releases the nrt meminfo of the
    arguments used to run the job.

    Use it to keep the arguments alive while the kernel executes. It reads
    the queue from the launch arguments and the device event cached by the
    kernel submission.

    Returns:
        The value returned by the runtime's
        ``acquire_meminfo_and_schedule_release`` call.
    """
    queue_ref = self.arguments.sycl_queue_ref
    event_ref = self.cached_arguments.device_event_ref

    total_meminfos, meminfo_list = self._allocate_meminfo_array()

    # The runtime call takes the event through a pointer, so spill it to a
    # stack slot first.
    event_ref_ptr = self.builder.alloca(event_ref.type)
    self.builder.store(event_ref, event_ref_ptr)

    status_ptr = cgutils.alloca_once(
        self.builder, self.context.get_value_type(types.uint64)
    )

    schedule_release = self.dpexrt.acquire_meminfo_and_schedule_release
    release_args = [
        self.context.nrt.get_nrt_api(self.builder),
        queue_ref,
        meminfo_list,
        self.context.get_constant(types.uintp, total_meminfos),
        event_ref_ptr,
        # presumably the number of events behind event_ref_ptr -- TODO confirm
        self.context.get_constant(types.uintp, 1),
        status_ptr,
    ]
    return schedule_release(self.builder, release_args)
507
759
508
760
def _get_num_flattened_kernel_args (
509
761
self ,
@@ -571,3 +823,26 @@ def _populate_kernel_args_and_args_ty_arrays(
571
823
kernel_arg_num ,
572
824
)
573
825
kernel_arg_num += 1
826
+
827
+
828
def get_queue_from_llvm_values(
    ctx: CPUContext,
    builder: IRBuilder,
    ty_kernel_args: list[types.Type],
    ll_kernel_args: list[llvmir.Instruction],
) -> Union[llvmir.Instruction, None]:
    """
    Get the sycl queue from the first DpnpNdArray argument. Prior passes
    before lowering make sure that compute-follows-data is enforceable
    for a specific call to a kernel. As such, at the stage of lowering
    the queue from the first DpnpNdArray argument can be extracted.

    Args:
        ctx: Target context whose data model manager is used to look up the
            array data model.
        builder: IR builder used to emit the ``extract_value``.
        ty_kernel_args: Numba types of the kernel arguments.
        ll_kernel_args: LLVM IR values of the kernel arguments, parallel to
            ``ty_kernel_args``.

    Returns:
        The LLVM IR value of the ``sycl_queue`` field of the first
        DpnpNdArray argument, or ``None`` when no argument is a DpnpNdArray.
    """
    for arg_num, argty in enumerate(ty_kernel_args):
        if not isinstance(argty, DpnpNdArray):
            continue

        llvm_val = ll_kernel_args[arg_num]
        datamodel = ctx.data_model_manager.lookup(argty)
        sycl_queue_attr_pos = datamodel.get_field_position("sycl_queue")
        return builder.extract_value(llvm_val, sycl_queue_attr_pos)

    # No array argument carries a queue. Return None explicitly so callers
    # can report the problem instead of hitting an UnboundLocalError here.
    return None
0 commit comments