Skip to content

Commit e9449e2

Browse files
authored
Merge pull request #30 from ngoldbaum/string-array
store string data in a struct along with length
2 parents f007c97 + 563c116 commit e9449e2

File tree

6 files changed

+125
-31
lines changed

6 files changed

+125
-31
lines changed

stringdtype/meson.build

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ srcs = [
2626
'stringdtype/src/casts.h',
2727
'stringdtype/src/dtype.c',
2828
'stringdtype/src/main.c',
29+
'stringdtype/src/static_string.c',
30+
'stringdtype/src/static_string.h',
2931
'stringdtype/src/umath.c',
3032
'stringdtype/src/umath.h',
3133
]

stringdtype/stringdtype/src/casts.c

Lines changed: 32 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,16 @@
11
#include "casts.h"
22

33
#include "dtype.h"
4+
#include "static_string.h"
5+
6+
void
7+
gil_error(PyObject *type, const char *msg)
8+
{
9+
PyGILState_STATE gstate;
10+
gstate = PyGILState_Ensure();
11+
PyErr_SetString(type, msg);
12+
PyGILState_Release(gstate);
13+
}
414

515
static NPY_CASTING
616
string_to_string_resolve_descriptors(PyObject *NPY_UNUSED(self),
@@ -35,17 +45,19 @@ string_to_string(PyArrayMethod_Context *context, char *const data[],
3545
NpyAuxData *NPY_UNUSED(auxdata))
3646
{
3747
npy_intp N = dimensions[0];
38-
char **in = (char **)data[0];
39-
char **out = (char **)data[1];
48+
ss **in = (ss **)data[0];
49+
ss **out = (ss **)data[1];
4050
// strides are in bytes but pointer offsets are in pointer widths, so
4151
// divide by the element size (one pointer width) to get the pointer offset
4252
npy_intp in_stride = strides[0] / context->descriptors[0]->elsize;
4353
npy_intp out_stride = strides[1] / context->descriptors[1]->elsize;
4454

4555
while (N--) {
46-
size_t length = strlen(*in);
47-
out[0] = (char *)malloc((sizeof(char) * length) + 1);
48-
strncpy(*out, *in, length + 1);
56+
out[0] = ssdup(in[0]);
57+
if (out[0] == NULL) {
58+
gil_error(PyExc_MemoryError, "ssdup failed");
59+
return -1;
60+
}
4961
in += in_stride;
5062
out += out_stride;
5163
}
@@ -189,7 +201,7 @@ unicode_to_string(PyArrayMethod_Context *context, char *const data[],
189201

190202
npy_intp N = dimensions[0];
191203
Py_UCS4 *in = (Py_UCS4 *)data[0];
192-
char **out = (char **)data[1];
204+
ss **out = (ss **)data[1];
193205

194206
// 4 bytes per UCS4 character
195207
npy_intp in_stride = strides[0] / 4;
@@ -202,16 +214,14 @@ unicode_to_string(PyArrayMethod_Context *context, char *const data[],
202214
size_t num_codepoints = 0;
203215
if (utf8_size(in, max_in_size, &num_codepoints, &out_num_bytes) ==
204216
-1) {
205-
// invalid codepoint found so acquire GIL, set error, return
206-
PyGILState_STATE gstate;
207-
gstate = PyGILState_Ensure();
208-
PyErr_SetString(PyExc_TypeError,
209-
"Invalid unicode code point found");
210-
PyGILState_Release(gstate);
217+
gil_error(PyExc_TypeError, "Invalid unicode code point found");
211218
return -1;
212219
}
213-
// one extra byte for null terminator
214-
char *out_buf = malloc((out_num_bytes + 1) * sizeof(char));
220+
ss *out_ss = ssnewempty(out_num_bytes);
221+
if (out_ss == NULL) {
222+
gil_error(PyExc_MemoryError, "ssnewempty failed");
223+
}
224+
char *out_buf = out_ss->buf;
215225
for (int i = 0; i < num_codepoints; i++) {
216226
// get code point
217227
Py_UCS4 code = in[i];
@@ -237,7 +247,7 @@ unicode_to_string(PyArrayMethod_Context *context, char *const data[],
237247
out_buf[out_num_bytes] = '\0';
238248

239249
// set out to the address of the beginning of the string
240-
out[0] = out_buf;
250+
out[0] = out_ss;
241251

242252
in += in_stride;
243253
out += out_stride;
@@ -318,7 +328,7 @@ string_to_unicode(PyArrayMethod_Context *context, char *const data[],
318328
NpyAuxData *NPY_UNUSED(auxdata))
319329
{
320330
npy_intp N = dimensions[0];
321-
char **in = (char **)data[0];
331+
ss **in = (ss **)data[0];
322332
Py_UCS4 *out = (Py_UCS4 *)data[1];
323333
// strides are in bytes but pointer offsets are in pointer widths, so
324334
// divide by the element size (one pointer width) to get the pointer offset
@@ -329,7 +339,9 @@ string_to_unicode(PyArrayMethod_Context *context, char *const data[],
329339
long max_out_size = (context->descriptors[1]->elsize) / 4;
330340

331341
while (N--) {
332-
unsigned char *this_string = (unsigned char *)*in;
342+
unsigned char *this_string = (unsigned char *)((*in)->buf);
343+
size_t n_bytes = (*in)->len;
344+
size_t tot_n_bytes = 0;
333345

334346
for (int i = 0; i < max_out_size; i++) {
335347
Py_UCS4 code;
@@ -340,16 +352,13 @@ string_to_unicode(PyArrayMethod_Context *context, char *const data[],
340352

341353
// move to next character
342354
this_string += num_bytes;
355+
tot_n_bytes += num_bytes;
343356

344357
// set output codepoint
345358
out[i] = code;
346359

347-
// check if this is the null terminator
348-
if (code == 0) {
349-
// fill all remaining characters (if any) with zero
350-
for (int j = i + 1; j < max_out_size; j++) {
351-
out[j] = 0;
352-
}
360+
// stop if we've exhausted the input string
361+
if (tot_n_bytes >= n_bytes) {
353362
break;
354363
}
355364
}

stringdtype/stringdtype/src/dtype.c

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include "dtype.h"
22

33
#include "casts.h"
4+
#include "static_string.h"
45

56
PyTypeObject *StringScalar_Type = NULL;
67

@@ -15,8 +16,9 @@ new_stringdtype_instance(void)
1516
if (new == NULL) {
1617
return NULL;
1718
}
18-
new->base.elsize = sizeof(char *);
19-
new->base.alignment = _Alignof(char *);
19+
new->base.elsize = sizeof(ss *);
20+
new->base.alignment = _Alignof(ss *);
21+
new->base.flags |= NPY_NEEDS_INIT;
2022

2123
return new;
2224
}
@@ -113,16 +115,26 @@ stringdtype_setitem(StringDTypeObject *descr, PyObject *obj, char **dataptr)
113115
return -1;
114116
}
115117

116-
*dataptr = malloc(sizeof(char) * length + 1);
117-
strncpy(*dataptr, val, length + 1);
118+
ss *str_val = ssnewlen(val, length);
119+
if (str_val == NULL) {
120+
PyErr_SetString(PyExc_MemoryError, "ssnewlen failed");
121+
return -1;
122+
}
123+
// the dtype instance has the NPY_NEEDS_INIT flag set,
124+
// so if *dataptr is NULL, that means we're initializing
125+
// the array and don't need to free an existing string
126+
if (*dataptr != NULL) {
127+
free((ss *)*dataptr);
128+
}
129+
*dataptr = (char *)str_val;
118130
Py_DECREF(val_obj);
119131
return 0;
120132
}
121133

122134
static PyObject *
123135
stringdtype_getitem(StringDTypeObject *descr, char **dataptr)
124136
{
125-
PyObject *val_obj = PyUnicode_FromString(*dataptr);
137+
PyObject *val_obj = PyUnicode_FromString(((ss *)*dataptr)->buf);
126138

127139
if (val_obj == NULL) {
128140
return NULL;
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
#include "static_string.h"
2+
3+
// allocates a new ss string of length len, filling with the contents of init
4+
ss *
5+
ssnewlen(const char *init, size_t len)
6+
{
7+
// one extra byte for null terminator
8+
ss *ret = (ss *)malloc(sizeof(ss) + sizeof(char) * (len + 1));
9+
10+
if (ret == NULL) {
11+
return NULL;
12+
}
13+
14+
ret->len = len;
15+
16+
if (len > 0) {
17+
memcpy(ret->buf, init, len);
18+
}
19+
20+
ret->buf[len] = '\0';
21+
22+
return ret;
23+
}
24+
25+
// returns a new heap-allocated copy of input string *s*
26+
ss *
27+
ssdup(const ss *s)
28+
{
29+
return ssnewlen(s->buf, s->len);
30+
}
31+
32+
// returns a new, empty string of length len
33+
// does not do any initialization, the caller must
34+
// initialize and null-terminate the string
35+
ss *
36+
ssnewempty(size_t len)
37+
{
38+
ss *ret = (ss *)malloc(sizeof(ss) + sizeof(char) * (len + 1));
39+
ret->len = len;
40+
return ret;
41+
}
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
#ifndef _NPY_STATIC_STRING_H
2+
#define _NPY_STATIC_STRING_H
3+
4+
#include "stdlib.h"
5+
#include "string.h"
6+
7+
typedef struct ss {
8+
size_t len;
9+
char buf[];
10+
} ss;
11+
12+
// allocates a new ss string of length len, filling with the contents of init
13+
ss *
14+
ssnewlen(const char *init, size_t len);
15+
16+
// returns a new heap-allocated copy of input string *s*
17+
ss *
18+
ssdup(const ss *s);
19+
20+
// returns a new, empty string of length len
21+
// does not do any initialization, the caller must
22+
// initialize and null-terminate the string
23+
ss *
24+
ssnewempty(size_t len);
25+
26+
#endif /*_NPY_STATIC_STRING_H */

stringdtype/stringdtype/src/umath.c

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "numpy/ufuncobject.h"
1010

1111
#include "dtype.h"
12+
#include "static_string.h"
1213
#include "string.h"
1314
#include "umath.h"
1415

@@ -19,8 +20,8 @@ string_equal_strided_loop(PyArrayMethod_Context *context, char *const data[],
1920
NpyAuxData *NPY_UNUSED(auxdata))
2021
{
2122
npy_intp N = dimensions[0];
22-
char **in1 = (char **)data[0];
23-
char **in2 = (char **)data[1];
23+
ss **in1 = (ss **)data[0];
24+
ss **in2 = (ss **)data[1];
2425
npy_bool *out = (npy_bool *)data[2];
2526
// strides are in bytes but pointer offsets are in pointer widths, so
2627
// divide by the element size (one pointer width) to get the pointer offset
@@ -29,7 +30,10 @@ string_equal_strided_loop(PyArrayMethod_Context *context, char *const data[],
2930
npy_intp out_stride = strides[2];
3031

3132
while (N--) {
32-
if (strcmp(*in1, *in2) == 0) {
33+
ss *s1 = *in1;
34+
ss *s2 = *in2;
35+
36+
if (s1->len == s2->len && strncmp(s1->buf, s2->buf, s1->len) == 0) {
3337
*out = (npy_bool)1;
3438
}
3539
else {

0 commit comments

Comments
 (0)