Skip to content

Commit 7a5a317

Browse files
author
Diptorup Deb
committed
Updates the dpnp array constructor overloads.
- Updates the *_like overloads to extract the sycl_queue for the input array. It was not possible previously as the sycl_queue attribute was not present. - Update unit tests. - Add new unit tests.
1 parent 3d394f5 commit 7a5a317

File tree

8 files changed

+243
-49
lines changed

8 files changed

+243
-49
lines changed

numba_dpex/dpnp_iface/_intrinsic.py

Lines changed: 73 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
populate_array,
2222
)
2323

24+
from numba_dpex.core.datamodel.models import dpex_data_model_manager as dpex_dmm
2425
from numba_dpex.core.runtime import context as dpexrt
2526
from numba_dpex.core.types import DpnpNdArray
2627
from numba_dpex.core.types.dpctl_types import DpctlSyclQueue
@@ -80,7 +81,9 @@ def make_queue(context, builder, py_dpctl_sycl_queue):
8081
return ret
8182

8283

83-
def _get_queue_ref(context, builder, sig, args):
84+
def _get_queue_ref(
85+
context, builder, sig, args, *, sycl_queue_arg_pos, array_arg_pos=None
86+
):
8487
"""Returns an LLVM IR Value pointer to a DpctlSyclQueueRef
8588
8689
The _get_queue_ref function is used by the intinsic functions that implement
@@ -118,25 +121,33 @@ def _get_queue_ref(context, builder, sig, args):
118121
119122
"""
120123

121-
queue_arg = args[-2]
122-
queue_arg_ty = sig.args[-2]
124+
queue_arg = args[sycl_queue_arg_pos]
125+
queue_arg_ty = sig.args[sycl_queue_arg_pos]
123126

124127
queue_ref = None
125128
py_dpctl_sycl_queue_addr = None
126129
pyapi = None
127130

128-
if isinstance(queue_arg_ty, DpctlSyclQueue):
131+
if not isinstance(
132+
queue_arg_ty, (types.misc.NoneType, types.misc.Omitted)
133+
) and isinstance(queue_arg_ty, DpctlSyclQueue):
129134
if not isinstance(queue_arg.type, llvmir.LiteralStructType):
130135
raise AssertionError
131-
queue_ref = builder.extract_value(queue_arg, 1)
132-
133-
elif isinstance(queue_arg_ty, types.misc.NoneType) or isinstance(
134-
queue_arg_ty, types.misc.Omitted
135-
):
136+
sycl_queue_dm = dpex_dmm.lookup(queue_arg_ty)
137+
queue_ref = builder.extract_value(
138+
queue_arg, sycl_queue_dm.get_field_position("queue_ref")
139+
)
140+
elif array_arg_pos is not None:
141+
array_arg = args[array_arg_pos]
142+
array_arg_ty = sig.args[array_arg_pos]
143+
dpnp_ndarray_dm = dpex_dmm.lookup(array_arg_ty)
144+
queue_ref = builder.extract_value(
145+
array_arg, dpnp_ndarray_dm.get_field_position("sycl_queue")
146+
)
147+
else:
136148
if not isinstance(queue_arg.type, llvmir.PointerType):
137149
# TODO: check if the pointer is null
138150
raise AssertionError
139-
140151
ty_sycl_queue = sig.return_type.queue
141152
py_dpctl_sycl_queue = get_device_cached_queue(ty_sycl_queue.sycl_device)
142153
(queue_ref, py_dpctl_sycl_queue_addr, pyapi) = make_queue(
@@ -147,6 +158,14 @@ def _get_queue_ref(context, builder, sig, args):
147158
return ret
148159

149160

161+
def _update_queue_attr(array, queue):
162+
"""Sets the sycl_queue member of an ArrayStruct."""
163+
164+
attr = dict(sycl_queue=queue)
165+
for k, v in attr.items():
166+
setattr(array, k, v)
167+
168+
150169
def _empty_nd_impl(context, builder, arrtype, shapes, queue_ref):
151170
"""Utility function used for allocating a new array.
152171
@@ -252,6 +271,7 @@ def _empty_nd_impl(context, builder, arrtype, shapes, queue_ref):
252271
shape_array = cgutils.pack_array(builder, shapes, ty=intp_t)
253272
strides_array = cgutils.pack_array(builder, strides, ty=intp_t)
254273

274+
_update_queue_attr(ary, queue=queue_ref_copy)
255275
populate_array(
256276
ary,
257277
data=builder.bitcast(data, datatype.as_pointer()),
@@ -432,9 +452,11 @@ def impl_dpnp_empty(
432452
ty_retty_ref,
433453
)
434454

455+
sycl_queue_arg_pos = -2
456+
435457
def codegen(context, builder, sig, args):
436458
qref_payload: _QueueRefPayload = _get_queue_ref(
437-
context, builder, sig, args
459+
context, builder, sig, args, sycl_queue_arg_pos=sycl_queue_arg_pos
438460
)
439461

440462
ary = alloc_empty_arrayobj(
@@ -496,10 +518,11 @@ def impl_dpnp_zeros(
496518
ty_sycl_queue,
497519
ty_retty_ref,
498520
)
521+
sycl_queue_arg_pos = -2
499522

500523
def codegen(context, builder, sig, args):
501524
qref_payload: _QueueRefPayload = _get_queue_ref(
502-
context, builder, sig, args
525+
context, builder, sig, args, sycl_queue_arg_pos=sycl_queue_arg_pos
503526
)
504527
ary = alloc_empty_arrayobj(
505528
context, builder, sig, qref_payload.queue_ref, args
@@ -569,9 +592,11 @@ def impl_dpnp_ones(
569592
ty_retty_ref,
570593
)
571594

595+
sycl_queue_arg_pos = -2
596+
572597
def codegen(context, builder, sig, args):
573598
qref_payload: _QueueRefPayload = _get_queue_ref(
574-
context, builder, sig, args
599+
context, builder, sig, args, sycl_queue_arg_pos=sycl_queue_arg_pos
575600
)
576601
ary = alloc_empty_arrayobj(
577602
context, builder, sig, qref_payload.queue_ref, args
@@ -647,10 +672,11 @@ def impl_dpnp_full(
647672
ty_sycl_queue,
648673
ty_retty_ref,
649674
)
675+
sycl_queue_arg_pos = -2
650676

651677
def codegen(context, builder, sig, args):
652678
qref_payload: _QueueRefPayload = _get_queue_ref(
653-
context, builder, sig, args
679+
context, builder, sig, args, sycl_queue_arg_pos=sycl_queue_arg_pos
654680
)
655681
ary = alloc_empty_arrayobj(
656682
context, builder, sig, qref_payload.queue_ref, args
@@ -726,10 +752,17 @@ def impl_dpnp_empty_like(
726752
ty_sycl_queue,
727753
ty_retty_ref,
728754
)
755+
sycl_queue_arg_pos = -2
756+
array_arg_pos = 0
729757

730758
def codegen(context, builder, sig, args):
731759
qref_payload: _QueueRefPayload = _get_queue_ref(
732-
context, builder, sig, args
760+
context,
761+
builder,
762+
sig,
763+
args,
764+
sycl_queue_arg_pos=sycl_queue_arg_pos,
765+
array_arg_pos=array_arg_pos,
733766
)
734767

735768
ary = alloc_empty_arrayobj(
@@ -799,9 +832,17 @@ def impl_dpnp_zeros_like(
799832
ty_retty_ref,
800833
)
801834

835+
sycl_queue_arg_pos = -2
836+
array_arg_pos = 0
837+
802838
def codegen(context, builder, sig, args):
803839
qref_payload: _QueueRefPayload = _get_queue_ref(
804-
context, builder, sig, args
840+
context,
841+
builder,
842+
sig,
843+
args,
844+
sycl_queue_arg_pos=sycl_queue_arg_pos,
845+
array_arg_pos=array_arg_pos,
805846
)
806847
ary = alloc_empty_arrayobj(
807848
context, builder, sig, qref_payload.queue_ref, args, is_like=True
@@ -877,10 +918,17 @@ def impl_dpnp_ones_like(
877918
ty_sycl_queue,
878919
ty_retty_ref,
879920
)
921+
sycl_queue_arg_pos = -2
922+
array_arg_pos = 0
880923

881924
def codegen(context, builder, sig, args):
882925
qref_payload: _QueueRefPayload = _get_queue_ref(
883-
context, builder, sig, args
926+
context,
927+
builder,
928+
sig,
929+
args,
930+
sycl_queue_arg_pos=sycl_queue_arg_pos,
931+
array_arg_pos=array_arg_pos,
884932
)
885933
ary = alloc_empty_arrayobj(
886934
context, builder, sig, qref_payload.queue_ref, args, is_like=True
@@ -960,10 +1008,17 @@ def impl_dpnp_full_like(
9601008
ty_sycl_queue,
9611009
ty_retty_ref,
9621010
)
1011+
sycl_queue_arg_pos = -2
1012+
array_arg_pos = 0
9631013

9641014
def codegen(context, builder, sig, args):
9651015
qref_payload: _QueueRefPayload = _get_queue_ref(
966-
context, builder, sig, args
1016+
context,
1017+
builder,
1018+
sig,
1019+
args,
1020+
sycl_queue_arg_pos=sycl_queue_arg_pos,
1021+
array_arg_pos=array_arg_pos,
9671022
)
9681023
ary = alloc_empty_arrayobj(
9691024
context, builder, sig, qref_payload.queue_ref, args, is_like=True

0 commit comments

Comments
 (0)