Skip to content

Commit dfc0ac8

Browse files
Pickling should preserve type of Python object
Previously it would always produced shared memory on unpickling.
1 parent 906d77b commit dfc0ac8

File tree

2 files changed

+36
-8
lines changed

2 files changed

+36
-8
lines changed

dpctl/_memory.pyx

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ cdef void copy_via_host(void *dest_ptr, SyclQueue dest_queue,
5151
This is useful when `src_ptr` and `dest_ptr` are bound to incompatible
5252
SYCL contexts.
5353
"""
54+
# could also use numpy.empty((nbytes,), dtype="|u1")
5455
cdef unsigned char[::1] host_buf = bytearray(nbytes)
5556

5657
DPPLQueue_Memcpy(
@@ -69,6 +70,10 @@ cdef void copy_via_host(void *dest_ptr, SyclQueue dest_queue,
6970

7071

7172
cdef class _BufferData:
73+
"""
74+
Internal data struct populated from parsing
75+
`__sycl_usm_array_interface__` dictionary
76+
"""
7277
cdef DPPLSyclUSMRef p
7378
cdef int writeable
7479
cdef object dt
@@ -122,12 +127,24 @@ cdef class _BufferData:
122127
return buf
123128

124129

125-
def _to_memory(unsigned char [::1] b):
126-
"""Constructs Memory of the same size as the argument and
127-
copies data into it"""
128-
cdef Memory res = MemoryUSMShared(len(b))
130+
def _to_memory(unsigned char [::1] b, str usm_kind):
131+
"""
132+
Constructs Memory of the same size as the argument
133+
and copies data into it"""
134+
cdef Memory res
135+
136+
if (usm_kind == "shared"):
137+
res = MemoryUSMShared(len(b))
138+
elif (usm_kind == "device"):
139+
res = MemoryUSMDevice(len(b))
140+
elif (usm_kind == "host"):
141+
res = MemoryUSMHost(len(b))
142+
else:
143+
raise ValueError(
144+
"Unrecognized usm_kind={} stored in the "
145+
"pickle".format(usm_kind))
129146
res.copy_from_host(b)
130-
147+
131148
return res
132149

133150

@@ -245,7 +262,7 @@ cdef class Memory:
245262
return self.tobytes()
246263

247264
def __reduce__(self):
248-
return _to_memory, (self.copy_to_host(), )
265+
return _to_memory, (self.copy_to_host(), self.get_usm_type())
249266

250267
property __sycl_usm_array_interface__:
251268
def __get__ (self):

dpctl/tests/test_sycl_usm.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -132,8 +132,19 @@ def test_pickling(self):
132132
mobj.copy_from_host(host_src_obj)
133133

134134
mobj_reconstructed = pickle.loads(pickle.dumps(mobj))
135-
self.assertEqual(mobj.tobytes(), mobj_reconstructed.tobytes())
136-
self.assertNotEqual(mobj._pointer, mobj_reconstructed._pointer)
135+
self.assertEqual(
136+
type(mobj), type(mobj_reconstructed), "Pickling should preserve type"
137+
)
138+
self.assertEqual(
139+
mobj.tobytes(),
140+
mobj_reconstructed.tobytes(),
141+
"Pickling should preserve buffer content"
142+
)
143+
self.assertNotEqual(
144+
mobj._pointer,
145+
mobj_reconstructed._pointer,
146+
"Pickling/unpickling changes pointer"
147+
)
137148

138149

139150
class TestMemoryUSMBase:

0 commit comments

Comments
 (0)