Skip to content

Commit ae73fb0

Browse files
authored
Merge pull request #36 from ngoldbaum/memory-usage
add a helper to get memory usage for string arrays
2 parents ae04d0c + 67bb449 commit ae73fb0

File tree

4 files changed

+92
-3
lines changed

4 files changed

+92
-3
lines changed

stringdtype/stringdtype/__init__.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@
55
"""
66

77
from .scalar import StringScalar # isort: skip
8-
from ._main import StringDType
8+
from ._main import StringDType, _memory_usage
99

10-
__all__ = ["StringDType", "StringScalar"]
10+
__all__ = [
11+
"StringDType",
12+
"StringScalar",
13+
"_memory_usage",
14+
]

stringdtype/stringdtype/src/dtype.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,4 +26,7 @@ new_stringdtype_instance(void);
2626
int
2727
init_string_dtype(void);
2828

29+
// from dtypemeta.h, not public in numpy
30+
#define NPY_DTYPE(descr) ((PyArray_DTypeMeta *)Py_TYPE(descr))
31+
2932
#endif /*_NPY_DTYPE_H*/

stringdtype/stringdtype/src/main.c

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,81 @@
66
#include "numpy/experimental_dtype_api.h"
77

88
#include "dtype.h"
9+
#include "static_string.h"
910
#include "umath.h"
1011

12+
static PyObject *
13+
_memory_usage(PyObject *NPY_UNUSED(self), PyObject *obj)
14+
{
15+
if (!PyArray_Check(obj)) {
16+
PyErr_SetString(PyExc_TypeError,
17+
"can only be called with ndarray object");
18+
return NULL;
19+
}
20+
21+
PyArrayObject *arr = (PyArrayObject *)obj;
22+
23+
PyArray_Descr *descr = PyArray_DESCR(arr);
24+
PyArray_DTypeMeta *dtype = NPY_DTYPE(descr);
25+
26+
if (dtype != &StringDType) {
27+
PyErr_SetString(PyExc_TypeError,
28+
"can only be called with a StringDType array");
29+
return NULL;
30+
}
31+
32+
NpyIter *iter =
33+
NpyIter_New(arr, NPY_ITER_READONLY | NPY_ITER_EXTERNAL_LOOP,
34+
NPY_KEEPORDER, NPY_NO_CASTING, NULL);
35+
36+
if (iter == NULL) {
37+
return NULL;
38+
}
39+
40+
NpyIter_IterNextFunc *iternext = NpyIter_GetIterNext(iter, NULL);
41+
42+
if (iternext == NULL) {
43+
NpyIter_Deallocate(iter);
44+
return NULL;
45+
}
46+
47+
char **dataptr = NpyIter_GetDataPtrArray(iter);
48+
npy_intp *strideptr = NpyIter_GetInnerStrideArray(iter);
49+
npy_intp *innersizeptr = NpyIter_GetInnerLoopSizePtr(iter);
50+
51+
// initialize with the size of the internal buffer
52+
size_t memory_usage = PyArray_NBYTES(arr);
53+
size_t struct_size = sizeof(ss);
54+
55+
do {
56+
ss **in = (ss **)*dataptr;
57+
npy_intp stride = *strideptr / descr->elsize;
58+
npy_intp count = *innersizeptr;
59+
60+
while (count--) {
61+
// +1 byte for the null terminator
62+
memory_usage += (*in)->len + struct_size + 1;
63+
in += stride;
64+
}
65+
66+
} while (iternext(iter));
67+
68+
PyObject *ret = PyLong_FromSize_t(memory_usage);
69+
70+
return ret;
71+
}
72+
73+
static PyMethodDef string_methods[] = {
74+
{"_memory_usage", _memory_usage, METH_O,
75+
"get memory usage for an array"},
76+
{NULL},
77+
};
78+
1179
static struct PyModuleDef moduledef = {
1280
PyModuleDef_HEAD_INIT,
1381
.m_name = "stringdtype_main",
1482
.m_size = -1,
83+
.m_methods = string_methods,
1584
};
1685

1786
/* Module initialization function */

stringdtype/tests/test_stringdtype.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import numpy as np
22
import pytest
33

4-
from stringdtype import StringDType, StringScalar
4+
from stringdtype import StringDType, StringScalar, _memory_usage
55

66

77
@pytest.fixture
@@ -111,3 +111,16 @@ def test_isnan(string_list):
111111
np.testing.assert_array_equal(
112112
np.isnan(sarr), np.zeros_like(sarr, dtype=np.bool_)
113113
)
114+
115+
116+
def test_memory_usage(string_list):
117+
sarr = np.array(string_list, dtype=StringDType())
118+
# 4 bytes for each ASCII string buffer in string_list
119+
# (three characters and null terminator)
120+
# plus enough bytes for the size_t length
121+
# plus enough bytes for the pointer in the array buffer
122+
assert _memory_usage(sarr) == (4 + 2 * np.dtype(np.uintp).itemsize) * 3
123+
with pytest.raises(TypeError):
124+
_memory_usage("hello")
125+
with pytest.raises(TypeError):
126+
_memory_usage(np.array([1, 2, 3]))

0 commit comments

Comments
 (0)