Skip to content

Commit 9a17afc

Browse files
Merge pull request #1789 from IntelPython/enhance-from_dlpack
Enhance from_dlpack to support imported kDLCPU data to kDLOneAPI
2 parents dd4c0c0 + 2661f51 commit 9a17afc

File tree

3 files changed

+281
-28
lines changed

3 files changed

+281
-28
lines changed

dpctl/tensor/_dlpack.pyx

Lines changed: 151 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ cdef void _managed_tensor_versioned_deleter(DLManagedTensorVersioned *dlmv_tenso
168168
stdlib.free(dlmv_tensor)
169169

170170

171-
cdef object _get_default_context(c_dpctl.SyclDevice dev) except *:
171+
cdef object _get_default_context(c_dpctl.SyclDevice dev):
172172
try:
173173
default_context = dev.sycl_platform.default_context
174174
except RuntimeError:
@@ -178,7 +178,7 @@ cdef object _get_default_context(c_dpctl.SyclDevice dev) except *:
178178
return default_context
179179

180180

181-
cdef int get_parent_device_ordinal_id(c_dpctl.SyclDevice dev) except *:
181+
cdef int get_parent_device_ordinal_id(c_dpctl.SyclDevice dev) except -1:
182182
cdef DPCTLSyclDeviceRef pDRef = NULL
183183
cdef DPCTLSyclDeviceRef tDRef = NULL
184184
cdef c_dpctl.SyclDevice p_dev
@@ -201,7 +201,7 @@ cdef int get_parent_device_ordinal_id(c_dpctl.SyclDevice dev) except *:
201201

202202
cdef int get_array_dlpack_device_id(
203203
usm_ndarray usm_ary
204-
) except *:
204+
) except -1:
205205
"""Finds ordinal number of the parent of device where array
206206
was allocated.
207207
"""
@@ -935,6 +935,32 @@ cpdef object from_dlpack_capsule(object py_caps):
935935
"The DLPack tensor resides on unsupported device."
936936
)
937937

938+
cdef usm_ndarray _to_usm_ary_from_host_blob(object host_blob, dev : Device):
    """Copy host data into a freshly allocated USM-device array.

    Args:
        host_blob (object):
            Host-resident object that ``numpy.asarray`` can convert
            to an array.
        dev (Device):
            Target device; memory is allocated using ``dev.sycl_queue``.

    Returns:
        usm_ndarray:
            Array with the shape of ``host_blob`` backed by USM-device
            memory. The data type is that of ``host_blob``, except that
            double precision real/complex data are converted to
            ``float32``/``complex64`` when the target device reports
            ``has_aspect_fp64`` as ``False``.
    """
    q = dev.sycl_queue
    np_ary = np.asarray(host_blob)
    dt = np_ary.dtype
    if dt.char in "dD" and q.sycl_device.has_aspect_fp64 is False:
        Xusm_dtype = (
            "float32" if dt.char == "d" else "complex64"
        )
    else:
        Xusm_dtype = dt
    # Convert on the host *before* sizing the allocation and copying bytes.
    # Otherwise raw float64/complex128 bytes would be reinterpreted as
    # float32/complex64 values on the device. Requesting C order also
    # guarantees that the flat byte sequence copied below is a C-order
    # traversal of the data.
    # NOTE(review): assumes usm_ndarray lays out a new array in C order by
    # default -- confirm against the usm_ndarray constructor.
    np_ary = np.asarray(np_ary, dtype=Xusm_dtype, order="C")
    usm_mem = dpmem.MemoryUSMDevice(np_ary.nbytes, queue=q)
    usm_ary = usm_ndarray(np_ary.shape, dtype=Xusm_dtype, buffer=usm_mem)
    # Flatten first, then view as bytes: changing the item size of a 0-d
    # array via ``view`` is not permitted, while a 1-d contiguous array
    # can always be viewed as ``u1``.
    usm_mem.copy_from_host(np_ary.reshape(-1).view(dtype="u1"))
    return usm_ary
952+
953+
954+
# only cdef to make it private
955+
cdef object _create_device(object device, object dl_device):
    """Return a ``Device`` object for the requested target.

    Args:
        device (object):
            A ``Device`` instance (returned unchanged), a
            ``dpctl.SyclDevice`` (wrapped with ``Device.create_device``),
            or any other object, in which case ``dl_device`` is used.
        dl_device (object):
            Indexable object whose element ``1`` is convertible to ``int``;
            its string form is passed to the ``dpctl.SyclDevice``
            constructor. Only consulted when ``device`` is neither a
            ``Device`` nor a ``dpctl.SyclDevice``.

    Returns:
        Device: the device object corresponding to the request.
    """
    if isinstance(device, Device):
        return device
    if isinstance(device, dpctl.SyclDevice):
        sycl_dev = device
    else:
        # select the device by the ordinal stored in the DLPack device tuple
        sycl_dev = dpctl.SyclDevice(str(<int>dl_device[1]))
    return Device.create_device(sycl_dev)
963+
938964

939965
def from_dlpack(x, /, *, device=None, copy=None):
940966
""" from_dlpack(x, /, *, device=None, copy=None)
@@ -943,7 +969,7 @@ def from_dlpack(x, /, *, device=None, copy=None):
943969
object ``x`` that implements ``__dlpack__`` protocol.
944970
945971
Args:
946-
x (Python object):
972+
x (object):
947973
A Python object representing an array that supports
948974
``__dlpack__`` protocol.
949975
device (Optional[str,
@@ -959,7 +985,8 @@ def from_dlpack(x, /, *, device=None, copy=None):
959985
returned by :attr:`dpctl.tensor.usm_ndarray.device`, or a
960986
2-tuple matching the format of the output of the ``__dlpack_device__``
961987
method, an integer enumerator representing the device type followed by
962-
an integer representing the index of the device.
988+
an integer representing the index of the device. The only supported
989+
:enum:`dpctl.tensor.DLDeviceType` types are "kDLCPU" and "kDLOneAPI".
963990
Default: ``None``.
964991
copy (bool, optional)
965992
Boolean indicating whether or not to copy the input.
@@ -1008,33 +1035,130 @@ def from_dlpack(x, /, *, device=None, copy=None):
10081035
10091036
C = Container(dpt.linspace(0, 100, num=20, dtype="int16"))
10101037
X = dpt.from_dlpack(C)
1038+
Y = dpt.from_dlpack(C, device=(dpt.DLDeviceType.kDLCPU, 0))
10111039
10121040
"""
1013-
if not hasattr(x, "__dlpack__"):
1014-
raise TypeError(
1015-
f"The argument of type {type(x)} does not implement "
1016-
"`__dlpack__` method."
1017-
)
1018-
dlpack_attr = getattr(x, "__dlpack__")
1019-
if not callable(dlpack_attr):
1041+
dlpack_attr = getattr(x, "__dlpack__", None)
1042+
dlpack_dev_attr = getattr(x, "__dlpack_device__", None)
1043+
if not callable(dlpack_attr) or not callable(dlpack_dev_attr):
10201044
raise TypeError(
10211045
f"The argument of type {type(x)} does not implement "
1022-
"`__dlpack__` method."
1046+
"`__dlpack__` and `__dlpack_device__` methods."
10231047
)
1024-
try:
1025-
# device is converted to a dlpack_device if necessary
1026-
dl_device = None
1027-
if device:
1028-
if isinstance(device, tuple):
1029-
dl_device = device
1048+
# device is converted to a dlpack_device if necessary
1049+
dl_device = None
1050+
if device:
1051+
if isinstance(device, tuple):
1052+
dl_device = device
1053+
if len(dl_device) != 2:
1054+
raise ValueError(
1055+
"Argument `device` specified as a tuple must have length 2"
1056+
)
1057+
else:
1058+
if not isinstance(device, dpctl.SyclDevice):
1059+
device = Device.create_device(device)
1060+
d = device.sycl_device
10301061
else:
1031-
if not isinstance(device, dpctl.SyclDevice):
1032-
d = Device.create_device(device).sycl_device
1033-
dl_device = (device_OneAPI, get_parent_device_ordinal_id(<c_dpctl.SyclDevice>d))
1034-
else:
1035-
dl_device = (device_OneAPI, get_parent_device_ordinal_id(<c_dpctl.SyclDevice>device))
1036-
dlpack_capsule = dlpack_attr(max_version=get_build_dlpack_version(), dl_device=dl_device, copy=copy)
1037-
return from_dlpack_capsule(dlpack_capsule)
1062+
d = device
1063+
dl_device = (device_OneAPI, get_parent_device_ordinal_id(<c_dpctl.SyclDevice>d))
1064+
if dl_device is not None:
1065+
if (dl_device[0] not in [device_OneAPI, device_CPU]):
1066+
raise ValueError(
1067+
f"Argument `device`={device} is not supported."
1068+
)
1069+
got_type_error = False
1070+
got_buffer_error = False
1071+
got_other_error = False
1072+
saved_exception = None
1073+
# First DLPack version supporting dl_device, and copy
1074+
requested_ver = (1, 0)
1075+
cpu_dev = (device_CPU, 0)
1076+
try:
1077+
# setting max_version to minimal version that supports dl_device/copy keywords
1078+
dlpack_capsule = dlpack_attr(
1079+
max_version=requested_ver,
1080+
dl_device=dl_device,
1081+
copy=copy
1082+
)
10381083
except TypeError:
1039-
dlpack_capsule = dlpack_attr()
1084+
# exporter does not support max_version keyword
1085+
got_type_error = True
1086+
except (BufferError, NotImplementedError):
1087+
# Either dl_device, or copy cannot be satisfied
1088+
got_buffer_error = True
1089+
except Exception as e:
1090+
got_other_error = True
1091+
saved_exception = e
1092+
else:
1093+
# execution did not raise exceptions
10401094
return from_dlpack_capsule(dlpack_capsule)
1095+
finally:
1096+
if got_type_error:
1097+
# max_version/dl_device, copy keywords are not supported by __dlpack__
1098+
x_dldev = dlpack_dev_attr()
1099+
if (dl_device is None) or (dl_device == x_dldev):
1100+
dlpack_capsule = dlpack_attr()
1101+
return from_dlpack_capsule(dlpack_capsule)
1102+
# must copy via host
1103+
if copy is False:
1104+
raise BufferError(
1105+
"Importing data via DLPack requires copying, but copy=False was provided"
1106+
)
1107+
# when max_version/dl_device/copy are not supported
1108+
# we can only support importing to OneAPI devices
1109+
# from host, or from another oneAPI device
1110+
is_supported_x_dldev = (
1111+
x_dldev == cpu_dev or
1112+
(x_dldev[0] == device_OneAPI)
1113+
)
1114+
is_supported_dl_device = (
1115+
dl_device == cpu_dev or
1116+
dl_device[0] == device_OneAPI
1117+
)
1118+
if is_supported_x_dldev and is_supported_dl_device:
1119+
dlpack_capsule = dlpack_attr()
1120+
blob = from_dlpack_capsule(dlpack_capsule)
1121+
else:
1122+
raise BufferError(f"Can not import to requested device {dl_device}")
1123+
dev = _create_device(device, dl_device)
1124+
if x_dldev == cpu_dev and dl_device == cpu_dev:
1125+
# both source and destination are CPU
1126+
return blob
1127+
elif x_dldev == cpu_dev:
1128+
# source is CPU, destination is oneAPI
1129+
return _to_usm_ary_from_host_blob(blob, dev)
1130+
elif dl_device == cpu_dev:
1131+
# source is oneAPI, destination is CPU
1132+
cpu_caps = blob.__dlpack__(
1133+
max_version=get_build_dlpack_version(),
1134+
dl_device=cpu_dev
1135+
)
1136+
return from_dlpack_capsule(cpu_caps)
1137+
else:
1138+
import dpctl.tensor as dpt
1139+
return dpt.asarray(blob, device=dev)
1140+
elif got_buffer_error:
1141+
# we are here, because dlpack_attr could not deal with requested dl_device,
1142+
# or copying was required
1143+
if copy is False:
1144+
raise BufferError(
1145+
"Importing data via DLPack requires copying, but copy=False was provided"
1146+
)
1147+
# must copy via host
1148+
if dl_device[0] != device_OneAPI:
1149+
raise BufferError(f"Can not import to requested device {dl_device}")
1150+
x_dldev = dlpack_dev_attr()
1151+
if x_dldev == cpu_dev:
1152+
dlpack_capsule = dlpack_attr()
1153+
host_blob = from_dlpack_capsule(dlpack_capsule)
1154+
else:
1155+
dlpack_capsule = dlpack_attr(
1156+
max_version=requested_ver,
1157+
dl_device=cpu_dev,
1158+
copy=copy
1159+
)
1160+
host_blob = from_dlpack_capsule(dlpack_capsule)
1161+
dev = _create_device(device, dl_device)
1162+
return _to_usm_ary_from_host_blob(host_blob, dev)
1163+
elif got_other_error:
1164+
raise saved_exception

dpctl/tensor/_usmarray.pyx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1242,7 +1242,7 @@ cdef class usm_ndarray:
12421242
_arr.flags["W"] = self.flags["W"]
12431243
return c_dlpack.numpy_to_dlpack_versioned_capsule(_arr, True)
12441244
else:
1245-
raise NotImplementedError(
1245+
raise BufferError(
12461246
f"targeting `dl_device` {dl_device} with `__dlpack__` is not "
12471247
"yet implemented"
12481248
)

dpctl/tests/test_usm_ndarray_dlpack.py

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -696,3 +696,132 @@ def test_dlpack_size_0_on_kdlcpu():
696696
cap = x_np.__dlpack__()
697697
y = _dlp.from_dlpack_capsule(cap)
698698
assert y.ctypes.data == x_np.ctypes.data
699+
700+
701+
def test_copy_via_host():
    """Check ``from_dlpack`` with an explicitly requested ``device``."""
    get_queue_or_skip()
    x = dpt.ones(1, dtype="i4")
    x_np = np.ones(1, dtype="i4")
    x_dl_dev = x.__dlpack_device__()

    # host (NumPy) data imported onto the device where ``x`` resides
    y = dpt.from_dlpack(x_np, device=x_dl_dev)
    assert isinstance(y, dpt.usm_ndarray)
    assert y.sycl_device == x.sycl_device
    assert y.usm_type == "device"

    with pytest.raises(ValueError):
        # incorrect length of tuple
        dpt.from_dlpack(x_np, device=(1, 0, 0))
    with pytest.raises(ValueError):
        # only kDLCPU and kDLOneAPI are supported
        dpt.from_dlpack(x, device=(2, 0))

    num_devs = dpctl.get_num_devices()
    if num_devs <= 1:
        # the remaining check needs a second device to import to
        return
    # first device ordinal that differs from the one ``x`` is on
    other_id = next(i for i in range(num_devs) if i != x_dl_dev[1])
    z = dpt.from_dlpack(x, device=(x_dl_dev[0], other_id))
    assert isinstance(z, dpt.usm_ndarray)
    assert z.usm_type == "device"
724+
725+
726+
def test_copy_via_host_gh_1789():
    """Test based on review example from gh-1789.

    An array whose strides are not a multiple of the item size must be
    rejected with ``BufferError``, both with and without an explicit
    ``device`` argument.
    """
    get_queue_or_skip()
    x_np = np.ones((10, 10), dtype="i4")
    # strides are no longer multiple of itemsize
    stride0, stride1 = x_np.strides
    x_np.strides = (stride0 - 1, stride1)
    for kwargs in ({}, {"device": (14, 0)}):
        with pytest.raises(BufferError):
            dpt.from_dlpack(x_np, **kwargs)
736+
737+
738+
class LegacyContainer:
    """Helper class implementing legacy `__dlpack__` protocol.

    Wraps an array and forwards only the ``stream`` keyword to the wrapped
    array's ``__dlpack__``; ``max_version``, ``dl_device`` and ``copy``
    keywords are not accepted.
    """

    def __init__(self, array):
        self._array = array

    def __dlpack__(self, stream=None):
        # delegate to the wrapped array, passing the stream through
        export = self._array.__dlpack__
        return export(stream=stream)

    def __dlpack_device__(self):
        # report the device of the wrapped array
        query_device = self._array.__dlpack_device__
        return query_device()
749+
750+
751+
class Container:
    """Helper class implementing `__dlpack__` protocol with keyword support.

    Wraps an array and forwards the ``max_version``, ``dl_device``, ``copy``
    and ``stream`` keywords to the wrapped array's ``__dlpack__``.
    """

    def __init__(self, array):
        self._array = array

    def __dlpack__(
        self, max_version=None, dl_device=None, copy=None, stream=None
    ):
        # pass every keyword through to the wrapped array unchanged
        forwarded = {
            "max_version": max_version,
            "dl_device": dl_device,
            "copy": copy,
            "stream": stream,
        }
        return self._array.__dlpack__(**forwarded)

    def __dlpack_device__(self):
        # report the device of the wrapped array
        query_device = self._array.__dlpack_device__
        return query_device()
769+
770+
771+
def test_generic_container_legacy():
    """Import from a legacy-protocol exporter that wraps a usm_ndarray."""
    get_queue_or_skip()
    C = LegacyContainer(dpt.linspace(0, 100, num=20, dtype="int16"))
    src = C._array

    # no device requested: result refers to the same allocation
    X = dpt.from_dlpack(C)
    assert isinstance(X, dpt.usm_ndarray)
    assert X._pointer == src._pointer
    assert X.sycl_device == src.sycl_device
    assert X.dtype == src.dtype

    # host requested: result is a NumPy array
    cpu_dl_dev = (dpt.DLDeviceType.kDLCPU, 0)
    Y = dpt.from_dlpack(C, device=cpu_dl_dev)
    assert isinstance(Y, np.ndarray)
    assert Y.dtype == X.dtype

    # same device requested explicitly: result refers to the same allocation
    Z = dpt.from_dlpack(C, device=X.device)
    assert isinstance(Z, dpt.usm_ndarray)
    assert Z._pointer == X._pointer
    assert Z.device == X.device
789+
790+
791+
def test_generic_container_legacy_np():
    """Import from a legacy-protocol exporter that wraps a NumPy array."""
    get_queue_or_skip()
    C = LegacyContainer(np.linspace(0, 100, num=20, dtype="int16"))
    src = C._array

    # no device requested: result is a NumPy array over the same memory
    X = dpt.from_dlpack(C)
    assert isinstance(X, np.ndarray)
    assert X.ctypes.data == src.ctypes.data
    assert X.dtype == src.dtype

    # host requested explicitly: result is still a NumPy array
    cpu_dl_dev = (dpt.DLDeviceType.kDLCPU, 0)
    Y = dpt.from_dlpack(C, device=cpu_dl_dev)
    assert isinstance(Y, np.ndarray)
    assert Y.dtype == X.dtype

    # a created Device requested: result is a usm_ndarray on that device
    dev = dpt.Device.create_device()
    Z = dpt.from_dlpack(C, device=dev)
    assert isinstance(Z, dpt.usm_ndarray)
    assert Z.device == dev
808+
809+
810+
def test_generic_container():
    """Import from an exporter that forwards all ``__dlpack__`` keywords."""
    get_queue_or_skip()
    C = Container(dpt.linspace(0, 100, num=20, dtype="int16"))
    src = C._array

    # no device requested: result refers to the same allocation
    X = dpt.from_dlpack(C)
    assert isinstance(X, dpt.usm_ndarray)
    assert X._pointer == src._pointer
    assert X.sycl_device == src.sycl_device
    assert X.dtype == src.dtype

    # host requested: result is a NumPy array
    cpu_dl_dev = (dpt.DLDeviceType.kDLCPU, 0)
    Y = dpt.from_dlpack(C, device=cpu_dl_dev)
    assert isinstance(Y, np.ndarray)
    assert Y.dtype == X.dtype

    # same device requested explicitly: result refers to the same allocation
    Z = dpt.from_dlpack(C, device=X.device)
    assert isinstance(Z, dpt.usm_ndarray)
    assert Z._pointer == X._pointer
    assert Z.device == X.device

0 commit comments

Comments
 (0)