Skip to content

Commit 8aff322

Browse files
author
Diptorup Deb
committed
Address review comments.
1 parent e17a55b commit 8aff322

File tree

2 files changed

+24
-9
lines changed

2 files changed

+24
-9
lines changed

numba_dpex/core/datamodel/models.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -61,11 +61,15 @@ def __init__(self, dmm, fe_type):
6161
class DpnpNdArrayModel(StructModel):
6262
"""Data model for the DpnpNdArray type.
6363
64-
The data model for DpnpNdArray is similar to numb's ArrayModel used for
65-
the numba.types.Array type, with the additional field ``sycl_queue`. The
66-
`sycl_queue` attribute stores the pointer to the C++ sycl::queue object
67-
that was used to allocate memory for numba-dpex's native representation
68-
for an Python object inferred as a DpnpNdArray.
64+
DpnpNdArrayModel is used by the numba_dpex.types.DpnpNdArray type and
65+
abstracts the usmarystruct_t C type defined in
66+
numba_dpex.core.runtime._usmarraystruct.h.
67+
68+
The DpnpNdArrayModel differs from numba's ArrayModel by including an extra
69+
member sycl_queue that maps to _usmarraystruct.sycl_queue pointer. The
70+
_usmarraystruct.sycl_queue pointer stores the C++ sycl::queue pointer that
71+
was used to allocate the data for the dpnp.ndarray represented by an
72+
instance of _usmarraystruct.
6973
"""
7074

7175
def __init__(self, dmm, fe_type):
@@ -102,7 +106,6 @@ def flattened_field_count(self):
102106
):
103107
flattened_member_count += 1
104108
else:
105-
print(member, type(member))
106109
raise UnreachableError
107110

108111
return flattened_member_count

numba_dpex/dpnp_iface/_intrinsic.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,9 @@ def _get_queue_ref(
132132
queue_arg_ty, (types.misc.NoneType, types.misc.Omitted)
133133
) and isinstance(queue_arg_ty, DpctlSyclQueue):
134134
if not isinstance(queue_arg.type, llvmir.LiteralStructType):
135-
raise AssertionError
135+
raise AssertionError(
136+
"Expected the queue_arg to be an llvmir.LiteralStructType"
137+
)
136138
sycl_queue_dm = dpex_dmm.lookup(queue_arg_ty)
137139
queue_ref = builder.extract_value(
138140
queue_arg, sycl_queue_dm.get_field_position("queue_ref")
@@ -147,7 +149,9 @@ def _get_queue_ref(
147149
else:
148150
if not isinstance(queue_arg.type, llvmir.PointerType):
149151
# TODO: check if the pointer is null
150-
raise AssertionError
152+
raise AssertionError(
153+
"Expected the queue_arg to be an llvmir.PointerType"
154+
)
151155
ty_sycl_queue = sig.return_type.queue
152156
py_dpctl_sycl_queue = get_device_cached_queue(ty_sycl_queue.sycl_device)
153157
(queue_ref, py_dpctl_sycl_queue_addr, pyapi) = make_queue(
@@ -159,7 +163,15 @@ def _get_queue_ref(
159163

160164

161165
def _update_queue_attr(array, queue):
162-
"""Sets the sycl_queue member of an ArrayStruct."""
166+
"""Assigns the sycl_queue member of an usmarystruct_t instance.
167+
168+
After creating a new usmarystruct_t struct (e.g. in _empty_nd_impl) the
169+
members of the struct are populated by calling
170+
numba.np.arrayobj.populate_array. The populate_array function does not
171+
update the sycl_queue member as populate_array is written specifically for
172+
numba's arystruct_t type that does not have a sycl_queue member. The
173+
_update_queue_attr is a helper function to update the sycl_queue field.
174+
"""
163175

164176
attr = dict(sycl_queue=queue)
165177
for k, v in attr.items():

0 commit comments

Comments
 (0)