Skip to content

Commit 5c96050

Browse files
committed
add support for pickling ASCIIDType arrays
1 parent ae73fb0 commit 5c96050

File tree

3 files changed

+106
-1
lines changed

3 files changed

+106
-1
lines changed

asciidtype/asciidtype/__init__.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,14 @@
77
from .scalar import ASCIIScalar # isort: skip
88
from ._asciidtype_main import ASCIIDType
99

10-
__all__ = ["ASCIIDType", "ASCIIScalar"]
10+
11+
def _reconstruct_ASCIIDType(*args):
12+
# this is needed for pickling instances because numpy overrides the pickling
13+
# behavior of the DTypeMeta class using copyreg. By pickling a wrapper
14+
# around the ASCIIDType initializer, we avoid triggering the code in numpy
15+
# that tries to handle pickling DTypeMeta instances. See
16+
# https://github.com/numpy/numpy/issues/23135#issuecomment-1410967842
17+
return ASCIIDType(*args)
18+
19+
20+
__all__ = ["ASCIIDType", "ASCIIScalar", "_reconstruct_ASCIIDType"]

asciidtype/asciidtype/src/dtype.c

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,82 @@ static PyMemberDef ASCIIDType_members[] = {
226226
{NULL},
227227
};
228228

229+
static int PICKLE_VERSION = 1;
230+
231+
static PyObject *
232+
asciidtype__reduce__(ASCIIDTypeObject *self)
233+
{
234+
PyObject *ret, *mod, *obj, *state;
235+
236+
ret = PyTuple_New(3);
237+
if (ret == NULL) {
238+
return NULL;
239+
}
240+
241+
mod = PyImport_ImportModule("asciidtype");
242+
if (mod == NULL) {
243+
Py_DECREF(ret);
244+
return NULL;
245+
}
246+
247+
obj = PyObject_GetAttrString(mod, "_reconstruct_ASCIIDType");
248+
Py_DECREF(mod);
249+
if (obj == NULL) {
250+
Py_DECREF(ret);
251+
return NULL;
252+
}
253+
254+
PyTuple_SET_ITEM(ret, 0, obj);
255+
256+
PyTuple_SET_ITEM(ret, 1, Py_BuildValue("(l)", self->size));
257+
258+
state = PyTuple_New(1);
259+
260+
PyTuple_SET_ITEM(state, 0, PyLong_FromLong(PICKLE_VERSION));
261+
262+
PyTuple_SET_ITEM(ret, 2, state);
263+
264+
return ret;
265+
}
266+
267+
static PyObject *
268+
asciidtype__setstate__(ASCIIDTypeObject *NPY_UNUSED(self), PyObject *args)
269+
{
270+
if (PyTuple_GET_SIZE(args) != 1 ||
271+
!(PyLong_Check(PyTuple_GET_ITEM(args, 0)))) {
272+
PyErr_BadInternalCall();
273+
return NULL;
274+
}
275+
276+
long version = PyLong_AsLong(PyTuple_GET_ITEM(args, 0));
277+
278+
if (version != PICKLE_VERSION) {
279+
PyErr_Format(PyExc_ValueError,
280+
"Pickle version mismatch. Got version %d but expected "
281+
"version %d.",
282+
version, PICKLE_VERSION);
283+
return NULL;
284+
}
285+
286+
Py_RETURN_NONE;
287+
}
288+
289+
static PyMethodDef ASCIIDType_methods[] = {
290+
{
291+
"__reduce__",
292+
(PyCFunction)asciidtype__reduce__,
293+
METH_NOARGS,
294+
"Reduction method for an ASCIIDType object",
295+
},
296+
{
297+
"__setstate__",
298+
(PyCFunction)asciidtype__setstate__,
299+
METH_O,
300+
"Unpickle an ASCIIDType object",
301+
},
302+
{NULL},
303+
};
304+
229305
/*
230306
* This is the basic things that you need to create a Python Type/Class in C.
231307
* However, there is a slight difference here because we create a
@@ -242,6 +318,7 @@ PyArray_DTypeMeta ASCIIDType = {
242318
.tp_repr = (reprfunc)asciidtype_repr,
243319
.tp_str = (reprfunc)asciidtype_repr,
244320
.tp_members = ASCIIDType_members,
321+
.tp_methods = ASCIIDType_methods,
245322
}},
246323
/* rest, filled in during DTypeMeta initialization */
247324
};

asciidtype/tests/test_asciidtype.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1+
import os
2+
import pickle
13
import re
4+
import tempfile
25

36
import numpy as np
47
import pytest
@@ -230,3 +233,18 @@ def test_insert_scalar_directly():
230233
val = arr[0]
231234
arr[1] = val
232235
np.testing.assert_array_equal(arr, np.array(["some", "some"], dtype=dtype))
236+
237+
238+
def test_pickle():
239+
dtype = ASCIIDType(6)
240+
arr = np.array(["this", "is", "an", "array"], dtype=dtype)
241+
with tempfile.NamedTemporaryFile("wb", delete=False) as f:
242+
pickle.dump([arr, dtype], f)
243+
244+
with open(f.name, "rb") as f:
245+
res = pickle.load(f)
246+
247+
np.testing.assert_array_equal(arr, res[0])
248+
assert res[1] == dtype
249+
250+
os.remove(f.name)

0 commit comments

Comments
 (0)