30
30
"QueueRefPayload" , ["queue_ref" , "py_dpctl_sycl_queue_addr" , "pyapi" ]
31
31
)
32
32
33
+ _ArgTyAndValue = namedtuple ("ArgTyAndValue" , ["numba_ty" , "llvmir_val" ])
34
+
33
35
34
36
# XXX: The function should be moved into DpexTargetContext
35
37
def make_queue (context , builder , py_dpctl_sycl_queue ):
@@ -82,37 +84,39 @@ def make_queue(context, builder, py_dpctl_sycl_queue):
82
84
83
85
84
86
def _get_queue_ref (
85
- context , builder , sig , args , * , sycl_queue_arg_pos , array_arg_pos = None
87
+ context ,
88
+ builder ,
89
+ returned_sycl_queue_ty ,
90
+ sycl_queue_arg : _ArgTyAndValue ,
91
+ array_arg : _ArgTyAndValue = None ,
86
92
):
87
93
"""Returns an LLVM IR Value pointer to a DpctlSyclQueueRef
88
94
89
95
The _get_queue_ref function is used by the intrinsic functions that implement
90
96
the overloads for dpnp array constructors: ``empty``, ``empty_like``,
91
97
``zeros``, ``zeros_like``, ``ones``, ``ones_like``, ``full``, ``full_like``.
92
98
93
- The args contains the list of LLVM IR values passed in to the dpnp
94
- overloads. The convention we follow is that the queue arg is always the
95
- penultimate arg passed to the intrinsic. For that reason, we can extract the
96
- queue argument as args[-2] and the type of the argument from the signature
97
- as sig.args[-2].
98
-
99
- Depending on whether the ``sycl_queue`` argument was explicitly specified,
100
- or was omitted, the queue_arg will be either a DpctlSyclQueue type or a
101
- numba NoneType/Omitted type. If a DpctlSyclQueue, then we directly extract
102
- the queue_ref from the unboxed native struct representation of a
103
- dpctl.SyclQueue. If a queue was not explicitly provided and the type is
104
- NoneType/Omitted, we get a cached dpctl.SyclQueue from dpctl and unbox it
105
- on the fly and return the queue_ref.
99
+ The function returns an LLVM IR Value corresponding to a dpctl.SyclQueue
100
+ Python object's underlying ``_queue_ref`` pointer. If a non-None
101
+ ``sycl_queue_arg`` is provided, then the ``_queue_ref`` attribute is
102
+ extracted from the ``sycl_queue_arg``. If the ``sycl_queue_arg`` is
103
+ None or omitted and an ``array_arg`` is provided, then the ``_queue_ref``
104
+ is extracted from the unboxed representation of the ``array_arg``. If
105
+ neither a non-None ``sycl_queue_arg`` nor an ``array_arg`` is provided,
106
+ then a cached dpctl.SyclQueue is retrieved from dpctl and unboxed on the fly
107
+ and the ``_queue_ref`` from that unboxed queue is returned to caller.
106
108
107
109
Args:
108
110
context (numba.core.base.BaseContext): Any of the context
109
111
derived from Numba's BaseContext
110
112
(e.g. `numba.core.cpu.CPUContext`).
111
113
builder (llvmlite.ir.builder.IRBuilder): The IR builder
112
114
from `llvmlite` for code generation.
113
- sig: Signature of the overload function
114
- args (list): LLVM IR values corresponding to the args passed to the LLVM
115
- function created for a dpnp overload.
115
+ returned_sycl_queue_ty: An instance of numba_dpex.types.DpctlSyclQueue
116
+ sycl_queue_arg: A 2-tuple storing the numba inferred type and the
117
+ corresponding LLVM IR value for a dpctl.SyclQueue Python object.
118
+ array_arg: A 2-tuple storing the numba inferred type and the
119
+ corresponding LLVM IR value for a dpnp.ndarray Python object.
116
120
117
121
Return:
118
122
A namedtuple wrapping the queue_ref pointer, an optional address to
@@ -121,39 +125,39 @@ def _get_queue_ref(
121
125
122
126
"""
123
127
124
- queue_arg = args [sycl_queue_arg_pos ]
125
- queue_arg_ty = sig .args [sycl_queue_arg_pos ]
126
-
127
128
queue_ref = None
128
129
py_dpctl_sycl_queue_addr = None
129
130
pyapi = None
130
131
131
132
if not isinstance (
132
- queue_arg_ty , (types .misc .NoneType , types .misc .Omitted )
133
- ) and isinstance (queue_arg_ty , DpctlSyclQueue ):
134
- if not isinstance (queue_arg .type , llvmir .LiteralStructType ):
133
+ sycl_queue_arg .numba_ty , (types .misc .NoneType , types .misc .Omitted )
134
+ ) and isinstance (sycl_queue_arg .numba_ty , DpctlSyclQueue ):
135
+ if not isinstance (
136
+ sycl_queue_arg .llvmir_val .type , llvmir .LiteralStructType
137
+ ):
135
138
raise AssertionError (
136
139
"Expected the queue_arg to be an llvmir.LiteralStructType"
137
140
)
138
- sycl_queue_dm = dpex_dmm .lookup (queue_arg_ty )
141
+ sycl_queue_dm = dpex_dmm .lookup (sycl_queue_arg . numba_ty )
139
142
queue_ref = builder .extract_value (
140
- queue_arg , sycl_queue_dm .get_field_position ("queue_ref" )
143
+ sycl_queue_arg .llvmir_val ,
144
+ sycl_queue_dm .get_field_position ("queue_ref" ),
141
145
)
142
- elif array_arg_pos is not None :
143
- array_arg = args [array_arg_pos ]
144
- array_arg_ty = sig .args [array_arg_pos ]
145
- dpnp_ndarray_dm = dpex_dmm .lookup (array_arg_ty )
146
+ elif array_arg is not None :
147
+ dpnp_ndarray_dm = dpex_dmm .lookup (array_arg .numba_ty )
146
148
queue_ref = builder .extract_value (
147
- array_arg , dpnp_ndarray_dm .get_field_position ("sycl_queue" )
149
+ array_arg .llvmir_val ,
150
+ dpnp_ndarray_dm .get_field_position ("sycl_queue" ),
148
151
)
149
152
else :
150
- if not isinstance (queue_arg .type , llvmir .PointerType ):
153
+ if not isinstance (sycl_queue_arg . llvmir_val .type , llvmir .PointerType ):
151
154
# TODO: check if the pointer is null
152
155
raise AssertionError (
153
156
"Expected the queue_arg to be an llvmir.PointerType"
154
157
)
155
- ty_sycl_queue = sig .return_type .queue
156
- py_dpctl_sycl_queue = get_device_cached_queue (ty_sycl_queue .sycl_device )
158
+ py_dpctl_sycl_queue = get_device_cached_queue (
159
+ returned_sycl_queue_ty .sycl_device
160
+ )
157
161
(queue_ref , py_dpctl_sycl_queue_addr , pyapi ) = make_queue (
158
162
context , builder , py_dpctl_sycl_queue
159
163
)
@@ -467,8 +471,14 @@ def impl_dpnp_empty(
467
471
sycl_queue_arg_pos = - 2
468
472
469
473
def codegen (context , builder , sig , args ):
474
+ sycl_queue_arg = _ArgTyAndValue (
475
+ sig .args [sycl_queue_arg_pos ], args [sycl_queue_arg_pos ]
476
+ )
470
477
qref_payload : _QueueRefPayload = _get_queue_ref (
471
- context , builder , sig , args , sycl_queue_arg_pos = sycl_queue_arg_pos
478
+ context = context ,
479
+ builder = builder ,
480
+ returned_sycl_queue_ty = sig .return_type .queue ,
481
+ sycl_queue_arg = sycl_queue_arg ,
472
482
)
473
483
474
484
ary = alloc_empty_arrayobj (
@@ -533,8 +543,14 @@ def impl_dpnp_zeros(
533
543
sycl_queue_arg_pos = - 2
534
544
535
545
def codegen (context , builder , sig , args ):
546
+ sycl_queue_arg = _ArgTyAndValue (
547
+ sig .args [sycl_queue_arg_pos ], args [sycl_queue_arg_pos ]
548
+ )
536
549
qref_payload : _QueueRefPayload = _get_queue_ref (
537
- context , builder , sig , args , sycl_queue_arg_pos = sycl_queue_arg_pos
550
+ context = context ,
551
+ builder = builder ,
552
+ returned_sycl_queue_ty = sig .return_type .queue ,
553
+ sycl_queue_arg = sycl_queue_arg ,
538
554
)
539
555
ary = alloc_empty_arrayobj (
540
556
context , builder , sig , qref_payload .queue_ref , args
@@ -607,8 +623,14 @@ def impl_dpnp_ones(
607
623
sycl_queue_arg_pos = - 2
608
624
609
625
def codegen (context , builder , sig , args ):
626
+ sycl_queue_arg = _ArgTyAndValue (
627
+ sig .args [sycl_queue_arg_pos ], args [sycl_queue_arg_pos ]
628
+ )
610
629
qref_payload : _QueueRefPayload = _get_queue_ref (
611
- context , builder , sig , args , sycl_queue_arg_pos = sycl_queue_arg_pos
630
+ context = context ,
631
+ builder = builder ,
632
+ returned_sycl_queue_ty = sig .return_type .queue ,
633
+ sycl_queue_arg = sycl_queue_arg ,
612
634
)
613
635
ary = alloc_empty_arrayobj (
614
636
context , builder , sig , qref_payload .queue_ref , args
@@ -687,8 +709,14 @@ def impl_dpnp_full(
687
709
sycl_queue_arg_pos = - 2
688
710
689
711
def codegen (context , builder , sig , args ):
712
+ sycl_queue_arg = _ArgTyAndValue (
713
+ sig .args [sycl_queue_arg_pos ], args [sycl_queue_arg_pos ]
714
+ )
690
715
qref_payload : _QueueRefPayload = _get_queue_ref (
691
- context , builder , sig , args , sycl_queue_arg_pos = sycl_queue_arg_pos
716
+ context = context ,
717
+ builder = builder ,
718
+ returned_sycl_queue_ty = sig .return_type .queue ,
719
+ sycl_queue_arg = sycl_queue_arg ,
692
720
)
693
721
ary = alloc_empty_arrayobj (
694
722
context , builder , sig , qref_payload .queue_ref , args
@@ -768,13 +796,16 @@ def impl_dpnp_empty_like(
768
796
array_arg_pos = 0
769
797
770
798
def codegen (context , builder , sig , args ):
799
+ sycl_queue_arg = _ArgTyAndValue (
800
+ sig .args [sycl_queue_arg_pos ], args [sycl_queue_arg_pos ]
801
+ )
802
+ array_arg = _ArgTyAndValue (sig .args [array_arg_pos ], args [array_arg_pos ])
771
803
qref_payload : _QueueRefPayload = _get_queue_ref (
772
- context ,
773
- builder ,
774
- sig ,
775
- args ,
776
- sycl_queue_arg_pos = sycl_queue_arg_pos ,
777
- array_arg_pos = array_arg_pos ,
804
+ context = context ,
805
+ builder = builder ,
806
+ returned_sycl_queue_ty = sig .return_type .queue ,
807
+ sycl_queue_arg = sycl_queue_arg ,
808
+ array_arg = array_arg ,
778
809
)
779
810
780
811
ary = alloc_empty_arrayobj (
@@ -848,13 +879,16 @@ def impl_dpnp_zeros_like(
848
879
array_arg_pos = 0
849
880
850
881
def codegen (context , builder , sig , args ):
882
+ sycl_queue_arg = _ArgTyAndValue (
883
+ sig .args [sycl_queue_arg_pos ], args [sycl_queue_arg_pos ]
884
+ )
885
+ array_arg = _ArgTyAndValue (sig .args [array_arg_pos ], args [array_arg_pos ])
851
886
qref_payload : _QueueRefPayload = _get_queue_ref (
852
- context ,
853
- builder ,
854
- sig ,
855
- args ,
856
- sycl_queue_arg_pos = sycl_queue_arg_pos ,
857
- array_arg_pos = array_arg_pos ,
887
+ context = context ,
888
+ builder = builder ,
889
+ returned_sycl_queue_ty = sig .return_type .queue ,
890
+ sycl_queue_arg = sycl_queue_arg ,
891
+ array_arg = array_arg ,
858
892
)
859
893
ary = alloc_empty_arrayobj (
860
894
context , builder , sig , qref_payload .queue_ref , args , is_like = True
@@ -934,13 +968,16 @@ def impl_dpnp_ones_like(
934
968
array_arg_pos = 0
935
969
936
970
def codegen (context , builder , sig , args ):
971
+ sycl_queue_arg = _ArgTyAndValue (
972
+ sig .args [sycl_queue_arg_pos ], args [sycl_queue_arg_pos ]
973
+ )
974
+ array_arg = _ArgTyAndValue (sig .args [array_arg_pos ], args [array_arg_pos ])
937
975
qref_payload : _QueueRefPayload = _get_queue_ref (
938
- context ,
939
- builder ,
940
- sig ,
941
- args ,
942
- sycl_queue_arg_pos = sycl_queue_arg_pos ,
943
- array_arg_pos = array_arg_pos ,
976
+ context = context ,
977
+ builder = builder ,
978
+ returned_sycl_queue_ty = sig .return_type .queue ,
979
+ sycl_queue_arg = sycl_queue_arg ,
980
+ array_arg = array_arg ,
944
981
)
945
982
ary = alloc_empty_arrayobj (
946
983
context , builder , sig , qref_payload .queue_ref , args , is_like = True
@@ -1024,13 +1061,16 @@ def impl_dpnp_full_like(
1024
1061
array_arg_pos = 0
1025
1062
1026
1063
def codegen (context , builder , sig , args ):
1064
+ sycl_queue_arg = _ArgTyAndValue (
1065
+ sig .args [sycl_queue_arg_pos ], args [sycl_queue_arg_pos ]
1066
+ )
1067
+ array_arg = _ArgTyAndValue (sig .args [array_arg_pos ], args [array_arg_pos ])
1027
1068
qref_payload : _QueueRefPayload = _get_queue_ref (
1028
- context ,
1029
- builder ,
1030
- sig ,
1031
- args ,
1032
- sycl_queue_arg_pos = sycl_queue_arg_pos ,
1033
- array_arg_pos = array_arg_pos ,
1069
+ context = context ,
1070
+ builder = builder ,
1071
+ returned_sycl_queue_ty = sig .return_type .queue ,
1072
+ sycl_queue_arg = sycl_queue_arg ,
1073
+ array_arg = array_arg ,
1034
1074
)
1035
1075
ary = alloc_empty_arrayobj (
1036
1076
context , builder , sig , qref_payload .queue_ref , args , is_like = True
0 commit comments