Skip to content

Commit 254696b

Browse files
authored
Merge pull request #33 from ngoldbaum/pickle-support
Add support for pickling ASCIIDtype and StringDType instances
2 parents ae73fb0 + 73c8f64 commit 254696b

File tree

7 files changed

+224
-17
lines changed

7 files changed

+224
-17
lines changed

asciidtype/asciidtype/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,7 @@
77
from .scalar import ASCIIScalar # isort: skip
88
from ._asciidtype_main import ASCIIDType
99

10-
__all__ = ["ASCIIDType", "ASCIIScalar"]
10+
__all__ = [
11+
"ASCIIDType",
12+
"ASCIIScalar",
13+
]

asciidtype/asciidtype/src/dtype.c

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,82 @@ static PyMemberDef ASCIIDType_members[] = {
226226
{NULL},
227227
};
228228

229+
static int PICKLE_VERSION = 1;
230+
231+
static PyObject *
232+
asciidtype__reduce__(ASCIIDTypeObject *self)
233+
{
234+
PyObject *ret, *mod, *obj, *state;
235+
236+
ret = PyTuple_New(3);
237+
if (ret == NULL) {
238+
return NULL;
239+
}
240+
241+
mod = PyImport_ImportModule("asciidtype");
242+
if (mod == NULL) {
243+
Py_DECREF(ret);
244+
return NULL;
245+
}
246+
247+
obj = PyObject_GetAttrString(mod, "ASCIIDType");
248+
Py_DECREF(mod);
249+
if (obj == NULL) {
250+
Py_DECREF(ret);
251+
return NULL;
252+
}
253+
254+
PyTuple_SET_ITEM(ret, 0, obj);
255+
256+
PyTuple_SET_ITEM(ret, 1, Py_BuildValue("(l)", self->size));
257+
258+
state = PyTuple_New(1);
259+
260+
PyTuple_SET_ITEM(state, 0, PyLong_FromLong(PICKLE_VERSION));
261+
262+
PyTuple_SET_ITEM(ret, 2, state);
263+
264+
return ret;
265+
}
266+
267+
static PyObject *
268+
asciidtype__setstate__(ASCIIDTypeObject *NPY_UNUSED(self), PyObject *args)
269+
{
270+
if (PyTuple_GET_SIZE(args) != 1 ||
271+
!(PyLong_Check(PyTuple_GET_ITEM(args, 0)))) {
272+
PyErr_BadInternalCall();
273+
return NULL;
274+
}
275+
276+
long version = PyLong_AsLong(PyTuple_GET_ITEM(args, 0));
277+
278+
if (version != PICKLE_VERSION) {
279+
PyErr_Format(PyExc_ValueError,
280+
"Pickle version mismatch. Got version %d but expected "
281+
"version %d.",
282+
version, PICKLE_VERSION);
283+
return NULL;
284+
}
285+
286+
Py_RETURN_NONE;
287+
}
288+
289+
static PyMethodDef ASCIIDType_methods[] = {
290+
{
291+
"__reduce__",
292+
(PyCFunction)asciidtype__reduce__,
293+
METH_NOARGS,
294+
"Reduction method for an ASCIIDType object",
295+
},
296+
{
297+
"__setstate__",
298+
(PyCFunction)asciidtype__setstate__,
299+
METH_O,
300+
"Unpickle an ASCIIDType object",
301+
},
302+
{NULL},
303+
};
304+
229305
/*
230306
 * These are the basic things that you need to create a Python Type/Class in C.
231307
* However, there is a slight difference here because we create a
@@ -242,6 +318,7 @@ PyArray_DTypeMeta ASCIIDType = {
242318
.tp_repr = (reprfunc)asciidtype_repr,
243319
.tp_str = (reprfunc)asciidtype_repr,
244320
.tp_members = ASCIIDType_members,
321+
.tp_methods = ASCIIDType_methods,
245322
}},
246323
/* rest, filled in during DTypeMeta initialization */
247324
};

asciidtype/tests/test_asciidtype.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1+
import os
2+
import pickle
13
import re
4+
import tempfile
25

36
import numpy as np
47
import pytest
@@ -230,3 +233,18 @@ def test_insert_scalar_directly():
230233
val = arr[0]
231234
arr[1] = val
232235
np.testing.assert_array_equal(arr, np.array(["some", "some"], dtype=dtype))
236+
237+
238+
def test_pickle():
    """Round-trip an ASCIIDType array and its dtype through a pickle file."""
    dtype = ASCIIDType(6)
    arr = np.array(["this", "is", "an", "array"], dtype=dtype)

    # A temporary directory is removed on exit from the ``with`` block, so
    # the pickle file is cleaned up even if dumping or loading raises.
    with tempfile.TemporaryDirectory() as tmpdir:
        path = os.path.join(tmpdir, "asciidtype.pkl")

        with open(path, "wb") as f:
            pickle.dump([arr, dtype], f)

        with open(path, "rb") as f:
            res = pickle.load(f)

    np.testing.assert_array_equal(arr, res[0])
    assert res[1] == dtype

stringdtype/stringdtype/__init__.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
1-
"""A dtype for working with string data
1+
"""A dtype for working with variable-length string data
22
3-
This is an example usage of the experimental new dtype API
4-
in Numpy and is not intended for any real purpose.
53
"""
64

75
from .scalar import StringScalar # isort: skip
86
from ._main import StringDType, _memory_usage
97

8+
109
__all__ = [
1110
"StringDType",
1211
"StringScalar",

stringdtype/stringdtype/scalar.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,6 @@
22

33

44
class StringScalar(str):
5-
def __new__(cls, value, dtype):
6-
instance = super().__new__(cls, value)
7-
instance.dtype = dtype
8-
return instance
9-
105
def partition(self, sep):
116
ret = super().partition(sep)
127
return (str(ret[0]), str(ret[1]), str(ret[2]))

stringdtype/stringdtype/src/dtype.c

Lines changed: 80 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ new_stringdtype_instance(void)
1919
new->base.elsize = sizeof(ss *);
2020
new->base.alignment = _Alignof(ss *);
2121
new->base.flags |= NPY_NEEDS_INIT;
22+
new->base.flags |= NPY_LIST_PICKLE;
2223

2324
return new;
2425
}
@@ -68,7 +69,7 @@ string_discover_descriptor_from_pyobject(PyArray_DTypeMeta *NPY_UNUSED(cls),
6869
return NULL;
6970
}
7071

71-
PyArray_Descr *ret = (PyArray_Descr *)PyObject_GetAttrString(obj, "dtype");
72+
PyArray_Descr *ret = (PyArray_Descr *)new_stringdtype_instance();
7273
if (ret == NULL) {
7374
return NULL;
7475
}
@@ -143,7 +144,7 @@ stringdtype_getitem(StringDTypeObject *descr, char **dataptr)
143144
}
144145

145146
PyObject *res = PyObject_CallFunctionObjArgs((PyObject *)StringScalar_Type,
146-
val_obj, descr, NULL);
147+
val_obj, NULL);
147148

148149
if (res == NULL) {
149150
return NULL;
@@ -200,6 +201,82 @@ stringdtype_repr(StringDTypeObject *NPY_UNUSED(self))
200201
return PyUnicode_FromString("StringDType()");
201202
}
202203

204+
static int PICKLE_VERSION = 1;
205+
206+
static PyObject *
207+
stringdtype__reduce__(StringDTypeObject *NPY_UNUSED(self))
208+
{
209+
PyObject *ret, *mod, *obj, *state;
210+
211+
ret = PyTuple_New(3);
212+
if (ret == NULL) {
213+
return NULL;
214+
}
215+
216+
mod = PyImport_ImportModule("stringdtype");
217+
if (mod == NULL) {
218+
Py_DECREF(ret);
219+
return NULL;
220+
}
221+
222+
obj = PyObject_GetAttrString(mod, "StringDType");
223+
Py_DECREF(mod);
224+
if (obj == NULL) {
225+
Py_DECREF(ret);
226+
return NULL;
227+
}
228+
229+
PyTuple_SET_ITEM(ret, 0, obj);
230+
231+
PyTuple_SET_ITEM(ret, 1, PyTuple_New(0));
232+
233+
state = PyTuple_New(1);
234+
235+
PyTuple_SET_ITEM(state, 0, PyLong_FromLong(PICKLE_VERSION));
236+
237+
PyTuple_SET_ITEM(ret, 2, state);
238+
239+
return ret;
240+
}
241+
242+
static PyObject *
243+
stringdtype__setstate__(StringDTypeObject *NPY_UNUSED(self), PyObject *args)
244+
{
245+
if (PyTuple_GET_SIZE(args) != 1 ||
246+
!(PyLong_Check(PyTuple_GET_ITEM(args, 0)))) {
247+
PyErr_BadInternalCall();
248+
return NULL;
249+
}
250+
251+
long version = PyLong_AsLong(PyTuple_GET_ITEM(args, 0));
252+
253+
if (version != PICKLE_VERSION) {
254+
PyErr_Format(PyExc_ValueError,
255+
"Pickle version mismatch. Got version %d but expected "
256+
"version %d.",
257+
version, PICKLE_VERSION);
258+
return NULL;
259+
}
260+
261+
Py_RETURN_NONE;
262+
}
263+
264+
static PyMethodDef StringDType_methods[] = {
265+
{
266+
"__reduce__",
267+
(PyCFunction)stringdtype__reduce__,
268+
METH_NOARGS,
269+
"Reduction method for an StringDType object",
270+
},
271+
{
272+
"__setstate__",
273+
(PyCFunction)stringdtype__setstate__,
274+
METH_O,
275+
"Unpickle an StringDType object",
276+
},
277+
{NULL},
278+
};
279+
203280
/*
204281
 * These are the basic things that you need to create a Python Type/Class in C.
205282
* However, there is a slight difference here because we create a
@@ -215,6 +292,7 @@ PyArray_DTypeMeta StringDType = {
215292
.tp_dealloc = (destructor)stringdtype_dealloc,
216293
.tp_repr = (reprfunc)stringdtype_repr,
217294
.tp_str = (reprfunc)stringdtype_repr,
295+
.tp_methods = StringDType_methods,
218296
}},
219297
/* rest, filled in during DTypeMeta initialization */
220298
};

stringdtype/tests/test_stringdtype.py

Lines changed: 43 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
import concurrent.futures
2+
import os
3+
import pickle
4+
import tempfile
5+
16
import numpy as np
27
import pytest
38

@@ -10,7 +15,7 @@ def string_list():
1015

1116

1217
def test_scalar_creation():
13-
assert str(StringScalar("abc", StringDType())) == "abc"
18+
assert str(StringScalar("abc")) == "abc"
1419

1520

1621
def test_dtype_creation():
@@ -38,12 +43,11 @@ def test_array_creation_utf8(data):
3843

3944

4045
def test_array_creation_scalars(string_list):
41-
dtype = StringDType()
4246
arr = np.array(
4347
[
44-
StringScalar("abc", dtype=dtype),
45-
StringScalar("def", dtype=dtype),
46-
StringScalar("ghi", dtype=dtype),
48+
StringScalar("abc"),
49+
StringScalar("def"),
50+
StringScalar("ghi"),
4751
]
4852
)
4953
assert repr(arr) == repr(np.array(string_list, dtype=StringDType()))
@@ -94,7 +98,7 @@ def test_unicode_casts(string_list):
9498
def test_insert_scalar(string_list):
9599
dtype = StringDType()
96100
arr = np.array(string_list, dtype=dtype)
97-
arr[1] = StringScalar("what", dtype=dtype)
101+
arr[1] = StringScalar("what")
98102
assert repr(arr) == repr(np.array(["abc", "what", "ghi"], dtype=dtype))
99103

100104

@@ -124,3 +128,36 @@ def test_memory_usage(string_list):
124128
_memory_usage("hello")
125129
with pytest.raises(TypeError):
126130
_memory_usage(np.array([1, 2, 3]))
131+
132+
133+
def _pickle_load(filename):
    """Open *filename* in binary mode and return the object unpickled from it."""
    with open(filename, "rb") as handle:
        return pickle.load(handle)
138+
139+
140+
def test_pickle(string_list):
    """Round-trip a StringDType array and its dtype through a pickle file.

    The pickle is loaded both in this process and in a subprocess.
    """
    dtype = StringDType()
    arr = np.array(string_list, dtype=dtype)

    # A temporary directory is removed on exit from the ``with`` block, so
    # the pickle file is cleaned up even if an assertion below fails.
    with tempfile.TemporaryDirectory() as tmpdir:
        path = os.path.join(tmpdir, "stringdtype.pkl")

        with open(path, "wb") as f:
            pickle.dump([arr, dtype], f)

        with open(path, "rb") as f:
            res = pickle.load(f)

        np.testing.assert_array_equal(res[0], arr)
        assert res[1] == dtype

        # load the pickle in a subprocess to ensure the string data are
        # actually stored in the pickle file
        with concurrent.futures.ProcessPoolExecutor() as executor:
            res = executor.submit(_pickle_load, path).result()

        np.testing.assert_array_equal(res[0], arr)
        assert res[1] == dtype

0 commit comments

Comments
 (0)