@@ -168,7 +168,7 @@ cdef void _managed_tensor_versioned_deleter(DLManagedTensorVersioned *dlmv_tenso
168
168
stdlib.free(dlmv_tensor)
169
169
170
170
171
- cdef object _get_default_context(c_dpctl.SyclDevice dev) except * :
171
+ cdef object _get_default_context(c_dpctl.SyclDevice dev):
172
172
try :
173
173
default_context = dev.sycl_platform.default_context
174
174
except RuntimeError :
@@ -178,7 +178,7 @@ cdef object _get_default_context(c_dpctl.SyclDevice dev) except *:
178
178
return default_context
179
179
180
180
181
- cdef int get_parent_device_ordinal_id(c_dpctl.SyclDevice dev) except * :
181
+ cdef int get_parent_device_ordinal_id(c_dpctl.SyclDevice dev) except - 1 :
182
182
cdef DPCTLSyclDeviceRef pDRef = NULL
183
183
cdef DPCTLSyclDeviceRef tDRef = NULL
184
184
cdef c_dpctl.SyclDevice p_dev
@@ -201,7 +201,7 @@ cdef int get_parent_device_ordinal_id(c_dpctl.SyclDevice dev) except *:
201
201
202
202
cdef int get_array_dlpack_device_id(
203
203
usm_ndarray usm_ary
204
- ) except * :
204
+ ) except - 1 :
205
205
""" Finds ordinal number of the parent of device where array
206
206
was allocated.
207
207
"""
@@ -935,6 +935,32 @@ cpdef object from_dlpack_capsule(object py_caps):
935
935
" The DLPack tensor resides on unsupported device."
936
936
)
937
937
938
+ cdef usm_ndarray _to_usm_ary_from_host_blob(object host_blob, dev : Device):
939
+ q = dev.sycl_queue
940
+ np_ary = np.asarray(host_blob)
941
+ dt = np_ary.dtype
942
+ if dt.char in " dD" and q.sycl_device.has_aspect_fp64 is False :
943
+ Xusm_dtype = (
944
+ " float32" if dt.char == " d" else " complex64"
945
+ )
946
+ else :
947
+ Xusm_dtype = dt
948
+ usm_mem = dpmem.MemoryUSMDevice(np_ary.nbytes, queue = q)
949
+ usm_ary = usm_ndarray(np_ary.shape, dtype = Xusm_dtype, buffer = usm_mem)
950
+ usm_mem.copy_from_host(np.reshape(np_ary.view(dtype = " u1" ), - 1 ))
951
+ return usm_ary
952
+
953
+
954
+ # only cdef to make it private
955
+ cdef object _create_device(object device, object dl_device):
956
+ if isinstance (device, Device):
957
+ return device
958
+ elif isinstance (device, dpctl.SyclDevice):
959
+ return Device.create_device(device)
960
+ else :
961
+ root_device = dpctl.SyclDevice(str (< int > dl_device[1 ]))
962
+ return Device.create_device(root_device)
963
+
938
964
939
965
def from_dlpack (x , /, *, device = None , copy = None ):
940
966
""" from_dlpack(x, /, *, device=None, copy=None)
@@ -943,7 +969,7 @@ def from_dlpack(x, /, *, device=None, copy=None):
943
969
object ``x`` that implements ``__dlpack__`` protocol.
944
970
945
971
Args:
946
- x (Python object):
972
+ x (object):
947
973
A Python object representing an array that supports
948
974
``__dlpack__`` protocol.
949
975
device (Optional[str,
@@ -959,7 +985,8 @@ def from_dlpack(x, /, *, device=None, copy=None):
959
985
returned by :attr:`dpctl.tensor.usm_ndarray.device`, or a
960
986
2-tuple matching the format of the output of the ``__dlpack_device__``
961
987
method, an integer enumerator representing the device type followed by
962
- an integer representing the index of the device.
988
+ an integer representing the index of the device. The only supported
989
+ :enum:`dpctl.tensor.DLDeviceType` types are "kDLCPU" and "kDLOneAPI".
963
990
Default: ``None``.
964
991
copy (bool, optional)
965
992
Boolean indicating whether or not to copy the input.
@@ -1008,33 +1035,130 @@ def from_dlpack(x, /, *, device=None, copy=None):
1008
1035
1009
1036
C = Container(dpt.linspace(0, 100, num=20, dtype="int16"))
1010
1037
X = dpt.from_dlpack(C)
1038
+ Y = dpt.from_dlpack(C, device=(dpt.DLDeviceType.kDLCPU, 0))
1011
1039
1012
1040
"""
1013
- if not hasattr (x, " __dlpack__" ):
1014
- raise TypeError (
1015
- f" The argument of type {type(x)} does not implement "
1016
- " `__dlpack__` method."
1017
- )
1018
- dlpack_attr = getattr (x, " __dlpack__" )
1019
- if not callable (dlpack_attr):
1041
+ dlpack_attr = getattr (x, " __dlpack__" , None )
1042
+ dlpack_dev_attr = getattr (x, " __dlpack_device__" , None )
1043
+ if not callable (dlpack_attr) or not callable (dlpack_dev_attr):
1020
1044
raise TypeError (
1021
1045
f" The argument of type {type(x)} does not implement "
1022
- " `__dlpack__` method ."
1046
+ " `__dlpack__` and `__dlpack_device__` methods ."
1023
1047
)
1024
- try :
1025
- # device is converted to a dlpack_device if necessary
1026
- dl_device = None
1027
- if device:
1028
- if isinstance (device, tuple ):
1029
- dl_device = device
1048
+ # device is converted to a dlpack_device if necessary
1049
+ dl_device = None
1050
+ if device:
1051
+ if isinstance (device, tuple ):
1052
+ dl_device = device
1053
+ if len (dl_device) != 2 :
1054
+ raise ValueError (
1055
+ " Argument `device` specified as a tuple must have length 2"
1056
+ )
1057
+ else :
1058
+ if not isinstance (device, dpctl.SyclDevice):
1059
+ device = Device.create_device(device)
1060
+ d = device.sycl_device
1030
1061
else :
1031
- if not isinstance (device, dpctl.SyclDevice):
1032
- d = Device.create_device(device).sycl_device
1033
- dl_device = (device_OneAPI, get_parent_device_ordinal_id(< c_dpctl.SyclDevice> d))
1034
- else :
1035
- dl_device = (device_OneAPI, get_parent_device_ordinal_id(< c_dpctl.SyclDevice> device))
1036
- dlpack_capsule = dlpack_attr(max_version = get_build_dlpack_version(), dl_device = dl_device, copy = copy)
1037
- return from_dlpack_capsule(dlpack_capsule)
1062
+ d = device
1063
+ dl_device = (device_OneAPI, get_parent_device_ordinal_id(< c_dpctl.SyclDevice> d))
1064
+ if dl_device is not None :
1065
+ if (dl_device[0 ] not in [device_OneAPI, device_CPU]):
1066
+ raise ValueError (
1067
+ f" Argument `device`={device} is not supported."
1068
+ )
1069
+ got_type_error = False
1070
+ got_buffer_error = False
1071
+ got_other_error = False
1072
+ saved_exception = None
1073
+ # First DLPack version supporting dl_device, and copy
1074
+ requested_ver = (1 , 0 )
1075
+ cpu_dev = (device_CPU, 0 )
1076
+ try :
1077
+ # setting max_version to minimal version that supports dl_device/copy keywords
1078
+ dlpack_capsule = dlpack_attr(
1079
+ max_version = requested_ver,
1080
+ dl_device = dl_device,
1081
+ copy = copy
1082
+ )
1038
1083
except TypeError :
1039
- dlpack_capsule = dlpack_attr()
1084
+ # exporter does not support max_version keyword
1085
+ got_type_error = True
1086
+ except (BufferError, NotImplementedError ):
1087
+ # Either dl_device, or copy can be satisfied
1088
+ got_buffer_error = True
1089
+ except Exception as e:
1090
+ got_other_error = True
1091
+ saved_exception = e
1092
+ else :
1093
+ # execution did not raise exceptions
1040
1094
return from_dlpack_capsule(dlpack_capsule)
1095
+ finally :
1096
+ if got_type_error:
1097
+ # max_version/dl_device, copy keywords are not supported by __dlpack__
1098
+ x_dldev = dlpack_dev_attr()
1099
+ if (dl_device is None ) or (dl_device == x_dldev):
1100
+ dlpack_capsule = dlpack_attr()
1101
+ return from_dlpack_capsule(dlpack_capsule)
1102
+ # must copy via host
1103
+ if copy is False :
1104
+ raise BufferError(
1105
+ " Importing data via DLPack requires copying, but copy=False was provided"
1106
+ )
1107
+ # when max_version/dl_device/copy are not supported
1108
+ # we can only support importing to OneAPI devices
1109
+ # from host, or from another oneAPI device
1110
+ is_supported_x_dldev = (
1111
+ x_dldev == cpu_dev or
1112
+ (x_dldev[0 ] == device_OneAPI)
1113
+ )
1114
+ is_supported_dl_device = (
1115
+ dl_device == cpu_dev or
1116
+ dl_device[0 ] == device_OneAPI
1117
+ )
1118
+ if is_supported_x_dldev and is_supported_dl_device:
1119
+ dlpack_capsule = dlpack_attr()
1120
+ blob = from_dlpack_capsule(dlpack_capsule)
1121
+ else :
1122
+ raise BufferError(f" Can not import to requested device {dl_device}" )
1123
+ dev = _create_device(device, dl_device)
1124
+ if x_dldev == cpu_dev and dl_device == cpu_dev:
1125
+ # both source and destination are CPU
1126
+ return blob
1127
+ elif x_dldev == cpu_dev:
1128
+ # source is CPU, destination is oneAPI
1129
+ return _to_usm_ary_from_host_blob(blob, dev)
1130
+ elif dl_device == cpu_dev:
1131
+ # source is oneAPI, destination is CPU
1132
+ cpu_caps = blob.__dlpack__(
1133
+ max_version = get_build_dlpack_version(),
1134
+ dl_device = cpu_dev
1135
+ )
1136
+ return from_dlpack_capsule(cpu_caps)
1137
+ else :
1138
+ import dpctl.tensor as dpt
1139
+ return dpt.asarray(blob, device = dev)
1140
+ elif got_buffer_error:
1141
+ # we are here, because dlpack_attr could not deal with requested dl_device,
1142
+ # or copying was required
1143
+ if copy is False :
1144
+ raise BufferError(
1145
+ " Importing data via DLPack requires copying, but copy=False was provided"
1146
+ )
1147
+ # must copy via host
1148
+ if dl_device[0 ] != device_OneAPI:
1149
+ raise BufferError(f" Can not import to requested device {dl_device}" )
1150
+ x_dldev = dlpack_dev_attr()
1151
+ if x_dldev == cpu_dev:
1152
+ dlpack_capsule = dlpack_attr()
1153
+ host_blob = from_dlpack_capsule(dlpack_capsule)
1154
+ else :
1155
+ dlpack_capsule = dlpack_attr(
1156
+ max_version = requested_ver,
1157
+ dl_device = cpu_dev,
1158
+ copy = copy
1159
+ )
1160
+ host_blob = from_dlpack_capsule(dlpack_capsule)
1161
+ dev = _create_device(device, dl_device)
1162
+ return _to_usm_ary_from_host_blob(host_blob, dev)
1163
+ elif got_other_error:
1164
+ raise saved_exception
0 commit comments