21
21
populate_array ,
22
22
)
23
23
24
+ from numba_dpex .core .datamodel .models import dpex_data_model_manager as dpex_dmm
24
25
from numba_dpex .core .runtime import context as dpexrt
25
26
from numba_dpex .core .types import DpnpNdArray
26
27
from numba_dpex .core .types .dpctl_types import DpctlSyclQueue
@@ -80,7 +81,9 @@ def make_queue(context, builder, py_dpctl_sycl_queue):
80
81
return ret
81
82
82
83
83
- def _get_queue_ref (context , builder , sig , args ):
84
+ def _get_queue_ref (
85
+ context , builder , sig , args , * , sycl_queue_arg_pos , array_arg_pos = None
86
+ ):
84
87
"""Returns an LLVM IR Value pointer to a DpctlSyclQueueRef
85
88
86
89
The _get_queue_ref function is used by the intinsic functions that implement
@@ -118,25 +121,33 @@ def _get_queue_ref(context, builder, sig, args):
118
121
119
122
"""
120
123
121
- queue_arg = args [- 2 ]
122
- queue_arg_ty = sig .args [- 2 ]
124
+ queue_arg = args [sycl_queue_arg_pos ]
125
+ queue_arg_ty = sig .args [sycl_queue_arg_pos ]
123
126
124
127
queue_ref = None
125
128
py_dpctl_sycl_queue_addr = None
126
129
pyapi = None
127
130
128
- if isinstance (queue_arg_ty , DpctlSyclQueue ):
131
+ if not isinstance (
132
+ queue_arg_ty , (types .misc .NoneType , types .misc .Omitted )
133
+ ) and isinstance (queue_arg_ty , DpctlSyclQueue ):
129
134
if not isinstance (queue_arg .type , llvmir .LiteralStructType ):
130
135
raise AssertionError
131
- queue_ref = builder .extract_value (queue_arg , 1 )
132
-
133
- elif isinstance (queue_arg_ty , types .misc .NoneType ) or isinstance (
134
- queue_arg_ty , types .misc .Omitted
135
- ):
136
+ sycl_queue_dm = dpex_dmm .lookup (queue_arg_ty )
137
+ queue_ref = builder .extract_value (
138
+ queue_arg , sycl_queue_dm .get_field_position ("queue_ref" )
139
+ )
140
+ elif array_arg_pos is not None :
141
+ array_arg = args [array_arg_pos ]
142
+ array_arg_ty = sig .args [array_arg_pos ]
143
+ dpnp_ndarray_dm = dpex_dmm .lookup (array_arg_ty )
144
+ queue_ref = builder .extract_value (
145
+ array_arg , dpnp_ndarray_dm .get_field_position ("sycl_queue" )
146
+ )
147
+ else :
136
148
if not isinstance (queue_arg .type , llvmir .PointerType ):
137
149
# TODO: check if the pointer is null
138
150
raise AssertionError
139
-
140
151
ty_sycl_queue = sig .return_type .queue
141
152
py_dpctl_sycl_queue = get_device_cached_queue (ty_sycl_queue .sycl_device )
142
153
(queue_ref , py_dpctl_sycl_queue_addr , pyapi ) = make_queue (
@@ -147,6 +158,14 @@ def _get_queue_ref(context, builder, sig, args):
147
158
return ret
148
159
149
160
161
+ def _update_queue_attr (array , queue ):
162
+ """Sets the sycl_queue member of an ArrayStruct."""
163
+
164
+ attr = dict (sycl_queue = queue )
165
+ for k , v in attr .items ():
166
+ setattr (array , k , v )
167
+
168
+
150
169
def _empty_nd_impl (context , builder , arrtype , shapes , queue_ref ):
151
170
"""Utility function used for allocating a new array.
152
171
@@ -252,6 +271,7 @@ def _empty_nd_impl(context, builder, arrtype, shapes, queue_ref):
252
271
shape_array = cgutils .pack_array (builder , shapes , ty = intp_t )
253
272
strides_array = cgutils .pack_array (builder , strides , ty = intp_t )
254
273
274
+ _update_queue_attr (ary , queue = queue_ref_copy )
255
275
populate_array (
256
276
ary ,
257
277
data = builder .bitcast (data , datatype .as_pointer ()),
@@ -432,9 +452,11 @@ def impl_dpnp_empty(
432
452
ty_retty_ref ,
433
453
)
434
454
455
+ sycl_queue_arg_pos = - 2
456
+
435
457
def codegen (context , builder , sig , args ):
436
458
qref_payload : _QueueRefPayload = _get_queue_ref (
437
- context , builder , sig , args
459
+ context , builder , sig , args , sycl_queue_arg_pos = sycl_queue_arg_pos
438
460
)
439
461
440
462
ary = alloc_empty_arrayobj (
@@ -496,10 +518,11 @@ def impl_dpnp_zeros(
496
518
ty_sycl_queue ,
497
519
ty_retty_ref ,
498
520
)
521
+ sycl_queue_arg_pos = - 2
499
522
500
523
def codegen (context , builder , sig , args ):
501
524
qref_payload : _QueueRefPayload = _get_queue_ref (
502
- context , builder , sig , args
525
+ context , builder , sig , args , sycl_queue_arg_pos = sycl_queue_arg_pos
503
526
)
504
527
ary = alloc_empty_arrayobj (
505
528
context , builder , sig , qref_payload .queue_ref , args
@@ -569,9 +592,11 @@ def impl_dpnp_ones(
569
592
ty_retty_ref ,
570
593
)
571
594
595
+ sycl_queue_arg_pos = - 2
596
+
572
597
def codegen (context , builder , sig , args ):
573
598
qref_payload : _QueueRefPayload = _get_queue_ref (
574
- context , builder , sig , args
599
+ context , builder , sig , args , sycl_queue_arg_pos = sycl_queue_arg_pos
575
600
)
576
601
ary = alloc_empty_arrayobj (
577
602
context , builder , sig , qref_payload .queue_ref , args
@@ -647,10 +672,11 @@ def impl_dpnp_full(
647
672
ty_sycl_queue ,
648
673
ty_retty_ref ,
649
674
)
675
+ sycl_queue_arg_pos = - 2
650
676
651
677
def codegen (context , builder , sig , args ):
652
678
qref_payload : _QueueRefPayload = _get_queue_ref (
653
- context , builder , sig , args
679
+ context , builder , sig , args , sycl_queue_arg_pos = sycl_queue_arg_pos
654
680
)
655
681
ary = alloc_empty_arrayobj (
656
682
context , builder , sig , qref_payload .queue_ref , args
@@ -726,10 +752,17 @@ def impl_dpnp_empty_like(
726
752
ty_sycl_queue ,
727
753
ty_retty_ref ,
728
754
)
755
+ sycl_queue_arg_pos = - 2
756
+ array_arg_pos = 0
729
757
730
758
def codegen (context , builder , sig , args ):
731
759
qref_payload : _QueueRefPayload = _get_queue_ref (
732
- context , builder , sig , args
760
+ context ,
761
+ builder ,
762
+ sig ,
763
+ args ,
764
+ sycl_queue_arg_pos = sycl_queue_arg_pos ,
765
+ array_arg_pos = array_arg_pos ,
733
766
)
734
767
735
768
ary = alloc_empty_arrayobj (
@@ -799,9 +832,17 @@ def impl_dpnp_zeros_like(
799
832
ty_retty_ref ,
800
833
)
801
834
835
+ sycl_queue_arg_pos = - 2
836
+ array_arg_pos = 0
837
+
802
838
def codegen (context , builder , sig , args ):
803
839
qref_payload : _QueueRefPayload = _get_queue_ref (
804
- context , builder , sig , args
840
+ context ,
841
+ builder ,
842
+ sig ,
843
+ args ,
844
+ sycl_queue_arg_pos = sycl_queue_arg_pos ,
845
+ array_arg_pos = array_arg_pos ,
805
846
)
806
847
ary = alloc_empty_arrayobj (
807
848
context , builder , sig , qref_payload .queue_ref , args , is_like = True
@@ -877,10 +918,17 @@ def impl_dpnp_ones_like(
877
918
ty_sycl_queue ,
878
919
ty_retty_ref ,
879
920
)
921
+ sycl_queue_arg_pos = - 2
922
+ array_arg_pos = 0
880
923
881
924
def codegen (context , builder , sig , args ):
882
925
qref_payload : _QueueRefPayload = _get_queue_ref (
883
- context , builder , sig , args
926
+ context ,
927
+ builder ,
928
+ sig ,
929
+ args ,
930
+ sycl_queue_arg_pos = sycl_queue_arg_pos ,
931
+ array_arg_pos = array_arg_pos ,
884
932
)
885
933
ary = alloc_empty_arrayobj (
886
934
context , builder , sig , qref_payload .queue_ref , args , is_like = True
@@ -960,10 +1008,17 @@ def impl_dpnp_full_like(
960
1008
ty_sycl_queue ,
961
1009
ty_retty_ref ,
962
1010
)
1011
+ sycl_queue_arg_pos = - 2
1012
+ array_arg_pos = 0
963
1013
964
1014
def codegen (context , builder , sig , args ):
965
1015
qref_payload : _QueueRefPayload = _get_queue_ref (
966
- context , builder , sig , args
1016
+ context ,
1017
+ builder ,
1018
+ sig ,
1019
+ args ,
1020
+ sycl_queue_arg_pos = sycl_queue_arg_pos ,
1021
+ array_arg_pos = array_arg_pos ,
967
1022
)
968
1023
ary = alloc_empty_arrayobj (
969
1024
context , builder , sig , qref_payload .queue_ref , args , is_like = True
0 commit comments