Skip to content

Commit 1bb30da

Browse files
author
Diptorup Deb
authored
Merge pull request #1092 from IntelPython/fix/kernel_resolve_dpnp_args
Simplify DpnpNdArray to USMNdArray conversion during kernel argument resolution.
2 parents 94cea55 + 583dd5b commit 1bb30da

File tree

1 file changed

+22
-14
lines changed

1 file changed

+22
-14
lines changed

numba_dpex/core/targets/kernel_target.py

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,6 @@ def resolve_argument_type(self, val):
6464
"""
6565
try:
6666
numba_type = typeof(val)
67-
py_type = type(numba_type)
6867

6968
if isinstance(numba_type, NpArrayType) and not isinstance(
7069
numba_type, USMNdArray
@@ -73,19 +72,28 @@ def resolve_argument_type(self, val):
7372
type=str(type(val)), value=val
7473
)
7574

76-
# XXX A kernel function has the spir_kernel ABI and requires
77-
# pointers to have an address space attribute. For this reason, the
78-
# UsmNdArray type uses a custom data model where the pointers are
79-
# address space casted to have a SYCL-specific address space value.
80-
# The DpnpNdArray type on the other hand is meant to be used inside
81-
# host functions and has Numba's array model as its data model.
82-
# If the value is a DpnpNdArray then use the ``to_usm_ndarray``
83-
# function to convert it into a UsmNdArray type rather than passing
84-
# it to the kernel as a DpnpNdArray. Thus, from a Numba typing
85-
# perspective dpnp.ndarrays cannot be directly passed to a kernel.
86-
if py_type is DpnpNdArray:
87-
suai_attrs = get_info_from_suai(val)
88-
return to_usm_ndarray(suai_attrs)
75+
# A cast from DpnpNdArray type to USMNdArray is needed for all
76+
# arguments of DpnpNdArray type. Although, DpnpNdArray derives from
77+
# USMNdArray the two types use different data models. USMNdArray
78+
# uses the numba_dpex.core.datamodel.models.ArrayModel data model
79+
# that defines all CPointer type members in the GLOBAL address
80+
# space. The DpnpNdArray uses Numba's default ArrayModel that does
81+
# not define pointers in any specific address space. For OpenCL HD
82+
# Graphics devices, defining a kernel function (spir_kernel calling
83+
# convention) with pointer arguments that have no address space
84+
# qualifier causes a run time crash. By casting the argument type
85+
# for parfor arguments from DpnpNdArray type to the USMNdArray type
86+
# the generated kernel always has an address space qualifier,
87+
# avoiding the issue on OpenCL HD graphics devices.
88+
if isinstance(numba_type, DpnpNdArray):
89+
return USMNdArray(
90+
ndim=numba_type.ndim,
91+
layout=numba_type.layout,
92+
dtype=numba_type.dtype,
93+
usm_type=numba_type.usm_type,
94+
queue=numba_type.queue,
95+
)
96+
8997
except ValueError:
9098
# When an array-like kernel argument is not recognized by
9199
# numba-dpex, this additional check sees if the array-like object

0 commit comments

Comments
 (0)