Skip to content

Commit 4877352

Browse files
author
Diptorup Deb
committed
Change the signature for _get_queue_ref.
- Addresses the review comment to pass required arguments to _get_queue_ref explicitly.
1 parent 8aff322 commit 4877352

File tree

1 file changed

+101
-61
lines changed

1 file changed

+101
-61
lines changed

numba_dpex/dpnp_iface/_intrinsic.py

Lines changed: 101 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@
3030
"QueueRefPayload", ["queue_ref", "py_dpctl_sycl_queue_addr", "pyapi"]
3131
)
3232

33+
_ArgTyAndValue = namedtuple("ArgTyAndValue", ["numba_ty", "llvmir_val"])
34+
3335

3436
# XXX: The function should be moved into DpexTargetContext
3537
def make_queue(context, builder, py_dpctl_sycl_queue):
@@ -82,37 +84,39 @@ def make_queue(context, builder, py_dpctl_sycl_queue):
8284

8385

8486
def _get_queue_ref(
85-
context, builder, sig, args, *, sycl_queue_arg_pos, array_arg_pos=None
87+
context,
88+
builder,
89+
returned_sycl_queue_ty,
90+
sycl_queue_arg: _ArgTyAndValue,
91+
array_arg: _ArgTyAndValue = None,
8692
):
8793
"""Returns an LLVM IR Value pointer to a DpctlSyclQueueRef
8894
8995
The _get_queue_ref function is used by the intrinsic functions that implement
9096
the overloads for dpnp array constructors: ``empty``, ``empty_like``,
9197
``zeros``, ``zeros_like``, ``ones``, ``ones_like``, ``full``, ``full_like``.
9298
93-
The args contains the list of LLVM IR values passed in to the dpnp
94-
overloads. The convention we follow is that the queue arg is always the
95-
penultimate arg passed to the intrinsic. For that reason, we can extract the
96-
queue argument as args[-2] and the type of the argument from the signature
97-
as sig.args[-2].
98-
99-
Depending on whether the ``sycl_queue`` argument was explicitly specified,
100-
or was omitted, the queue_arg will be either a DpctlSyclQueue type or a
101-
numba NoneType/Omitted type. If a DpctlSyclQueue, then we directly extract
102-
the queue_ref from the unboxed native struct representation of a
103-
dpctl.SyclQueue. If a queue was not explicitly provided and the type is
104-
NoneType/Omitted, we get a cached dpctl.SyclQueue from dpctl and unbox it
105-
on the fly and return the queue_ref.
99+
The function returns an LLVM IR Value corresponding to a dpctl.SyclQueue
100+
Python object's underlying ``_queue_ref`` pointer. If a non-None
101+
``sycl_queue_arg`` is provided, then the ``_queue_ref`` attribute is
102+
extracted from the ``sycl_queue_arg``. If the ``sycl_queue_arg`` is
103+
None or omitted and an ``array_arg`` is provided, then the ``_queue_ref``
104+
is extracted from the unboxed representation of the ``array_arg``. If
105+
neither a non-None ``sycl_queue_arg`` nor an ``array_arg`` is provided,
106+
then a cached dpctl.SyclQueue is retrieved from dpctl and unboxed on the fly
107+
and the ``_queue_ref`` from that unboxed queue is returned to the caller.
106108
107109
Args:
108110
context (numba.core.base.BaseContext): Any of the context
109111
derived from Numba's BaseContext
110112
(e.g. `numba.core.cpu.CPUContext`).
111113
builder (llvmlite.ir.builder.IRBuilder): The IR builder
112114
from `llvmlite` for code generation.
113-
sig: Signature of the overload function
114-
args (list): LLVM IR values corresponding to the args passed to the LLVM
115-
function created for a dpnp overload.
115+
returned_sycl_queue_ty: An instance of numba_dpex.types.DpctlSyclQueue
116+
sycl_queue_arg: A 2-tuple storing the numba inferred type and the
117+
corresponding LLVM IR value for a dpctl.SyclQueue Python object.
118+
array_arg: A 2-tuple storing the numba inferred type and the
119+
corresponding LLVM IR value for a dpnp.ndarray Python object.
116120
117121
Return:
118122
A namedtuple wrapping the queue_ref pointer, an optional address to
@@ -121,39 +125,39 @@ def _get_queue_ref(
121125
122126
"""
123127

124-
queue_arg = args[sycl_queue_arg_pos]
125-
queue_arg_ty = sig.args[sycl_queue_arg_pos]
126-
127128
queue_ref = None
128129
py_dpctl_sycl_queue_addr = None
129130
pyapi = None
130131

131132
if not isinstance(
132-
queue_arg_ty, (types.misc.NoneType, types.misc.Omitted)
133-
) and isinstance(queue_arg_ty, DpctlSyclQueue):
134-
if not isinstance(queue_arg.type, llvmir.LiteralStructType):
133+
sycl_queue_arg.numba_ty, (types.misc.NoneType, types.misc.Omitted)
134+
) and isinstance(sycl_queue_arg.numba_ty, DpctlSyclQueue):
135+
if not isinstance(
136+
sycl_queue_arg.llvmir_val.type, llvmir.LiteralStructType
137+
):
135138
raise AssertionError(
136139
"Expected the queue_arg to be an llvmir.LiteralStructType"
137140
)
138-
sycl_queue_dm = dpex_dmm.lookup(queue_arg_ty)
141+
sycl_queue_dm = dpex_dmm.lookup(sycl_queue_arg.numba_ty)
139142
queue_ref = builder.extract_value(
140-
queue_arg, sycl_queue_dm.get_field_position("queue_ref")
143+
sycl_queue_arg.llvmir_val,
144+
sycl_queue_dm.get_field_position("queue_ref"),
141145
)
142-
elif array_arg_pos is not None:
143-
array_arg = args[array_arg_pos]
144-
array_arg_ty = sig.args[array_arg_pos]
145-
dpnp_ndarray_dm = dpex_dmm.lookup(array_arg_ty)
146+
elif array_arg is not None:
147+
dpnp_ndarray_dm = dpex_dmm.lookup(array_arg.numba_ty)
146148
queue_ref = builder.extract_value(
147-
array_arg, dpnp_ndarray_dm.get_field_position("sycl_queue")
149+
array_arg.llvmir_val,
150+
dpnp_ndarray_dm.get_field_position("sycl_queue"),
148151
)
149152
else:
150-
if not isinstance(queue_arg.type, llvmir.PointerType):
153+
if not isinstance(sycl_queue_arg.llvmir_val.type, llvmir.PointerType):
151154
# TODO: check if the pointer is null
152155
raise AssertionError(
153156
"Expected the queue_arg to be an llvmir.PointerType"
154157
)
155-
ty_sycl_queue = sig.return_type.queue
156-
py_dpctl_sycl_queue = get_device_cached_queue(ty_sycl_queue.sycl_device)
158+
py_dpctl_sycl_queue = get_device_cached_queue(
159+
returned_sycl_queue_ty.sycl_device
160+
)
157161
(queue_ref, py_dpctl_sycl_queue_addr, pyapi) = make_queue(
158162
context, builder, py_dpctl_sycl_queue
159163
)
@@ -467,8 +471,14 @@ def impl_dpnp_empty(
467471
sycl_queue_arg_pos = -2
468472

469473
def codegen(context, builder, sig, args):
474+
sycl_queue_arg = _ArgTyAndValue(
475+
sig.args[sycl_queue_arg_pos], args[sycl_queue_arg_pos]
476+
)
470477
qref_payload: _QueueRefPayload = _get_queue_ref(
471-
context, builder, sig, args, sycl_queue_arg_pos=sycl_queue_arg_pos
478+
context=context,
479+
builder=builder,
480+
returned_sycl_queue_ty=sig.return_type.queue,
481+
sycl_queue_arg=sycl_queue_arg,
472482
)
473483

474484
ary = alloc_empty_arrayobj(
@@ -533,8 +543,14 @@ def impl_dpnp_zeros(
533543
sycl_queue_arg_pos = -2
534544

535545
def codegen(context, builder, sig, args):
546+
sycl_queue_arg = _ArgTyAndValue(
547+
sig.args[sycl_queue_arg_pos], args[sycl_queue_arg_pos]
548+
)
536549
qref_payload: _QueueRefPayload = _get_queue_ref(
537-
context, builder, sig, args, sycl_queue_arg_pos=sycl_queue_arg_pos
550+
context=context,
551+
builder=builder,
552+
returned_sycl_queue_ty=sig.return_type.queue,
553+
sycl_queue_arg=sycl_queue_arg,
538554
)
539555
ary = alloc_empty_arrayobj(
540556
context, builder, sig, qref_payload.queue_ref, args
@@ -607,8 +623,14 @@ def impl_dpnp_ones(
607623
sycl_queue_arg_pos = -2
608624

609625
def codegen(context, builder, sig, args):
626+
sycl_queue_arg = _ArgTyAndValue(
627+
sig.args[sycl_queue_arg_pos], args[sycl_queue_arg_pos]
628+
)
610629
qref_payload: _QueueRefPayload = _get_queue_ref(
611-
context, builder, sig, args, sycl_queue_arg_pos=sycl_queue_arg_pos
630+
context=context,
631+
builder=builder,
632+
returned_sycl_queue_ty=sig.return_type.queue,
633+
sycl_queue_arg=sycl_queue_arg,
612634
)
613635
ary = alloc_empty_arrayobj(
614636
context, builder, sig, qref_payload.queue_ref, args
@@ -687,8 +709,14 @@ def impl_dpnp_full(
687709
sycl_queue_arg_pos = -2
688710

689711
def codegen(context, builder, sig, args):
712+
sycl_queue_arg = _ArgTyAndValue(
713+
sig.args[sycl_queue_arg_pos], args[sycl_queue_arg_pos]
714+
)
690715
qref_payload: _QueueRefPayload = _get_queue_ref(
691-
context, builder, sig, args, sycl_queue_arg_pos=sycl_queue_arg_pos
716+
context=context,
717+
builder=builder,
718+
returned_sycl_queue_ty=sig.return_type.queue,
719+
sycl_queue_arg=sycl_queue_arg,
692720
)
693721
ary = alloc_empty_arrayobj(
694722
context, builder, sig, qref_payload.queue_ref, args
@@ -768,13 +796,16 @@ def impl_dpnp_empty_like(
768796
array_arg_pos = 0
769797

770798
def codegen(context, builder, sig, args):
799+
sycl_queue_arg = _ArgTyAndValue(
800+
sig.args[sycl_queue_arg_pos], args[sycl_queue_arg_pos]
801+
)
802+
array_arg = _ArgTyAndValue(sig.args[array_arg_pos], args[array_arg_pos])
771803
qref_payload: _QueueRefPayload = _get_queue_ref(
772-
context,
773-
builder,
774-
sig,
775-
args,
776-
sycl_queue_arg_pos=sycl_queue_arg_pos,
777-
array_arg_pos=array_arg_pos,
804+
context=context,
805+
builder=builder,
806+
returned_sycl_queue_ty=sig.return_type.queue,
807+
sycl_queue_arg=sycl_queue_arg,
808+
array_arg=array_arg,
778809
)
779810

780811
ary = alloc_empty_arrayobj(
@@ -848,13 +879,16 @@ def impl_dpnp_zeros_like(
848879
array_arg_pos = 0
849880

850881
def codegen(context, builder, sig, args):
882+
sycl_queue_arg = _ArgTyAndValue(
883+
sig.args[sycl_queue_arg_pos], args[sycl_queue_arg_pos]
884+
)
885+
array_arg = _ArgTyAndValue(sig.args[array_arg_pos], args[array_arg_pos])
851886
qref_payload: _QueueRefPayload = _get_queue_ref(
852-
context,
853-
builder,
854-
sig,
855-
args,
856-
sycl_queue_arg_pos=sycl_queue_arg_pos,
857-
array_arg_pos=array_arg_pos,
887+
context=context,
888+
builder=builder,
889+
returned_sycl_queue_ty=sig.return_type.queue,
890+
sycl_queue_arg=sycl_queue_arg,
891+
array_arg=array_arg,
858892
)
859893
ary = alloc_empty_arrayobj(
860894
context, builder, sig, qref_payload.queue_ref, args, is_like=True
@@ -934,13 +968,16 @@ def impl_dpnp_ones_like(
934968
array_arg_pos = 0
935969

936970
def codegen(context, builder, sig, args):
971+
sycl_queue_arg = _ArgTyAndValue(
972+
sig.args[sycl_queue_arg_pos], args[sycl_queue_arg_pos]
973+
)
974+
array_arg = _ArgTyAndValue(sig.args[array_arg_pos], args[array_arg_pos])
937975
qref_payload: _QueueRefPayload = _get_queue_ref(
938-
context,
939-
builder,
940-
sig,
941-
args,
942-
sycl_queue_arg_pos=sycl_queue_arg_pos,
943-
array_arg_pos=array_arg_pos,
976+
context=context,
977+
builder=builder,
978+
returned_sycl_queue_ty=sig.return_type.queue,
979+
sycl_queue_arg=sycl_queue_arg,
980+
array_arg=array_arg,
944981
)
945982
ary = alloc_empty_arrayobj(
946983
context, builder, sig, qref_payload.queue_ref, args, is_like=True
@@ -1024,13 +1061,16 @@ def impl_dpnp_full_like(
10241061
array_arg_pos = 0
10251062

10261063
def codegen(context, builder, sig, args):
1064+
sycl_queue_arg = _ArgTyAndValue(
1065+
sig.args[sycl_queue_arg_pos], args[sycl_queue_arg_pos]
1066+
)
1067+
array_arg = _ArgTyAndValue(sig.args[array_arg_pos], args[array_arg_pos])
10271068
qref_payload: _QueueRefPayload = _get_queue_ref(
1028-
context,
1029-
builder,
1030-
sig,
1031-
args,
1032-
sycl_queue_arg_pos=sycl_queue_arg_pos,
1033-
array_arg_pos=array_arg_pos,
1069+
context=context,
1070+
builder=builder,
1071+
returned_sycl_queue_ty=sig.return_type.queue,
1072+
sycl_queue_arg=sycl_queue_arg,
1073+
array_arg=array_arg,
10341074
)
10351075
ary = alloc_empty_arrayobj(
10361076
context, builder, sig, qref_payload.queue_ref, args, is_like=True

0 commit comments

Comments
 (0)