Skip to content

Commit 98873db

Browse files
committed
add support for pickling StringDType instances
pickling arrays is currently broken and needs support in numpy to work
1 parent 5c96050 commit 98873db

File tree

3 files changed

+96
-3
lines changed

3 files changed

+96
-3
lines changed

stringdtype/stringdtype/__init__.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
1-
"""A dtype for working with string data
1+
"""A dtype for working with variable-length string data
22
3-
This is an example usage of the experimental new dtype API
4-
in Numpy and is not intended for any real purpose.
53
"""
64

75
from .scalar import StringScalar # isort: skip

stringdtype/stringdtype/src/dtype.c

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,82 @@ stringdtype_repr(StringDTypeObject *NPY_UNUSED(self))
200200
return PyUnicode_FromString("StringDType()");
201201
}
202202

203+
/* Version stamp written into the pickle state by __reduce__ and checked by __setstate__. */
static int PICKLE_VERSION = 1;
204+
205+
static PyObject *
206+
stringdtype__reduce__(StringDTypeObject *self)
207+
{
208+
PyObject *ret, *mod, *obj, *state;
209+
210+
ret = PyTuple_New(3);
211+
if (ret == NULL) {
212+
return NULL;
213+
}
214+
215+
mod = PyImport_ImportModule("stringdtype");
216+
if (mod == NULL) {
217+
Py_DECREF(ret);
218+
return NULL;
219+
}
220+
221+
obj = PyObject_GetAttrString(mod, "_reconstruct_StringDType");
222+
Py_DECREF(mod);
223+
if (obj == NULL) {
224+
Py_DECREF(ret);
225+
return NULL;
226+
}
227+
228+
PyTuple_SET_ITEM(ret, 0, obj);
229+
230+
PyTuple_SET_ITEM(ret, 1, PyTuple_New(0));
231+
232+
state = PyTuple_New(1);
233+
234+
PyTuple_SET_ITEM(state, 0, PyLong_FromLong(PICKLE_VERSION));
235+
236+
PyTuple_SET_ITEM(ret, 2, state);
237+
238+
return ret;
239+
}
240+
241+
static PyObject *
242+
stringdtype__setstate__(StringDTypeObject *NPY_UNUSED(self), PyObject *args)
243+
{
244+
if (PyTuple_GET_SIZE(args) != 1 ||
245+
!(PyLong_Check(PyTuple_GET_ITEM(args, 0)))) {
246+
PyErr_BadInternalCall();
247+
return NULL;
248+
}
249+
250+
long version = PyLong_AsLong(PyTuple_GET_ITEM(args, 0));
251+
252+
if (version != PICKLE_VERSION) {
253+
PyErr_Format(PyExc_ValueError,
254+
"Pickle version mismatch. Got version %d but expected "
255+
"version %d.",
256+
version, PICKLE_VERSION);
257+
return NULL;
258+
}
259+
260+
Py_RETURN_NONE;
261+
}
262+
263+
static PyMethodDef StringDType_methods[] = {
264+
{
265+
"__reduce__",
266+
(PyCFunction)stringdtype__reduce__,
267+
METH_NOARGS,
268+
"Reduction method for an StringDType object",
269+
},
270+
{
271+
"__setstate__",
272+
(PyCFunction)stringdtype__setstate__,
273+
METH_O,
274+
"Unpickle an StringDType object",
275+
},
276+
{NULL},
277+
};
278+
203279
/*
204280
* This is the basic things that you need to create a Python Type/Class in C.
205281
* However, there is a slight difference here because we create a
@@ -215,6 +291,7 @@ PyArray_DTypeMeta StringDType = {
215291
.tp_dealloc = (destructor)stringdtype_dealloc,
216292
.tp_repr = (reprfunc)stringdtype_repr,
217293
.tp_str = (reprfunc)stringdtype_repr,
294+
.tp_methods = StringDType_methods,
218295
}},
219296
/* rest, filled in during DTypeMeta initialization */
220297
};

stringdtype/tests/test_stringdtype.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
import os
2+
import pickle
3+
import tempfile
4+
15
import numpy as np
26
import pytest
37

@@ -124,3 +128,17 @@ def test_memory_usage(string_list):
124128
_memory_usage("hello")
125129
with pytest.raises(TypeError):
126130
_memory_usage(np.array([1, 2, 3]))
131+
132+
133+
def test_pickle_dtype():
    """Round-trip a StringDType instance through a pickle file on disk.

    The dtype is dumped to a named temporary file, loaded back from that
    file by name, and the loaded instance must compare equal to the
    original. The temporary file is removed even when dumping, loading or
    the comparison fails.
    """
    dtype = StringDType()

    # delete=False so the file survives closing and can be reopened by name.
    tmp = tempfile.NamedTemporaryFile("wb", delete=False)
    path = tmp.name
    try:
        with tmp:
            pickle.dump(dtype, tmp)

        with open(path, "rb") as f:
            load_dtype = pickle.load(f)

        assert dtype == load_dtype
    finally:
        # Both handles are closed by now, so removal is safe on all platforms.
        os.remove(path)

0 commit comments

Comments
 (0)