Skip to content

Commit 0dd1dfa

Browse files
committed
add error handling for string allocation failures
1 parent 7499253 commit 0dd1dfa

File tree

2 files changed

+23
-7
lines changed

2 files changed

+23
-7
lines changed

stringdtype/stringdtype/src/casts.c

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,15 @@
33
#include "dtype.h"
44
#include "static_string.h"
55

6+
void
7+
gil_error(PyObject *type, const char *msg)
8+
{
9+
PyGILState_STATE gstate;
10+
gstate = PyGILState_Ensure();
11+
PyErr_SetString(type, msg);
12+
PyGILState_Release(gstate);
13+
}
14+
615
static NPY_CASTING
716
string_to_string_resolve_descriptors(PyObject *NPY_UNUSED(self),
817
PyArray_DTypeMeta *NPY_UNUSED(dtypes[2]),
@@ -45,6 +54,10 @@ string_to_string(PyArrayMethod_Context *context, char *const data[],
4554

4655
while (N--) {
4756
out[0] = ssdup(in[0]);
57+
if (out[0] == NULL) {
58+
gil_error(PyExc_MemoryError, "ssdup failed");
59+
return -1;
60+
}
4861
in += in_stride;
4962
out += out_stride;
5063
}
@@ -201,15 +214,13 @@ unicode_to_string(PyArrayMethod_Context *context, char *const data[],
201214
size_t num_codepoints = 0;
202215
if (utf8_size(in, max_in_size, &num_codepoints, &out_num_bytes) ==
203216
-1) {
204-
// invalid codepoint found so acquire GIL, set error, return
205-
PyGILState_STATE gstate;
206-
gstate = PyGILState_Ensure();
207-
PyErr_SetString(PyExc_TypeError,
208-
"Invalid unicode code point found");
209-
PyGILState_Release(gstate);
217+
gil_error(PyExc_TypeError, "Invalid unicode code point found");
210218
return -1;
211219
}
212220
ss *out_ss = ssnewempty(out_num_bytes);
221+
if (out_ss == NULL) {
222+
gil_error(PyExc_MemoryError, "ssnewempty failed");
223+
}
213224
char *out_buf = out_ss->buf;
214225
for (int i = 0; i < num_codepoints; i++) {
215226
// get code point

stringdtype/stringdtype/src/dtype.c

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,12 @@ stringdtype_setitem(StringDTypeObject *descr, PyObject *obj, char **dataptr)
114114
return -1;
115115
}
116116

117-
*dataptr = (char *)ssnewlen(val, length);
117+
ss *str_val = ssnewlen(val, length);
118+
if (str_val == NULL) {
119+
PyErr_SetString(PyExc_MemoryError, "ssnewlen failed");
120+
return -1;
121+
}
122+
*dataptr = (char *)str_val;
118123
Py_DECREF(val_obj);
119124
return 0;
120125
}

0 commit comments

Comments
 (0)