@@ -51,6 +51,7 @@ cdef void copy_via_host(void *dest_ptr, SyclQueue dest_queue,
51
51
This is useful when `src_ptr` and `dest_ptr` are bound to incompatible
52
52
SYCL contexts.
53
53
"""
54
+ # could also use numpy.empty((nbytes,), dtype="|u1")
54
55
cdef unsigned char [::1 ] host_buf = bytearray(nbytes)
55
56
56
57
DPPLQueue_Memcpy(
@@ -69,6 +70,10 @@ cdef void copy_via_host(void *dest_ptr, SyclQueue dest_queue,
69
70
70
71
71
72
cdef class _BufferData:
73
+ """
74
+ Internal data struct populated from parsing
75
+ `__sycl_usm_array_interface__` dictionary
76
+ """
72
77
cdef DPPLSyclUSMRef p
73
78
cdef int writeable
74
79
cdef object dt
@@ -122,12 +127,24 @@ cdef class _BufferData:
122
127
return buf
123
128
124
129
125
- def _to_memory (unsigned char [::1] b ):
126
- """ Constructs Memory of the same size as the argument and
127
- copies data into it"""
128
- cdef Memory res = MemoryUSMShared(len (b))
130
+ def _to_memory (unsigned char [::1] b , str usm_kind ):
131
+ """
132
+ Constructs Memory of the same size as the argument
133
+ and copies data into it"""
134
+ cdef Memory res
135
+
136
+ if (usm_kind == " shared" ):
137
+ res = MemoryUSMShared(len (b))
138
+ elif (usm_kind == " device" ):
139
+ res = MemoryUSMDevice(len (b))
140
+ elif (usm_kind == " host" ):
141
+ res = MemoryUSMHost(len (b))
142
+ else :
143
+ raise ValueError (
144
+ " Unrecognized usm_kind={} stored in the "
145
+ " pickle" .format(usm_kind))
129
146
res.copy_from_host(b)
130
-
147
+
131
148
return res
132
149
133
150
@@ -245,7 +262,7 @@ cdef class Memory:
245
262
return self .tobytes()
246
263
247
264
def __reduce__ (self ):
248
- return _to_memory, (self .copy_to_host(), )
265
+ return _to_memory, (self .copy_to_host(), self .get_usm_type() )
249
266
250
267
property __sycl_usm_array_interface__ :
251
268
def __get__ (self ):
0 commit comments