Skip to content

Commit 7c17d57

Browse files
committed
enable the equal ufunc
1 parent 310217f commit 7c17d57

File tree

3 files changed

+33
-149
lines changed

3 files changed

+33
-149
lines changed

stringdtype/meson.build

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ srcs = [
2626
'stringdtype/src/casts.h',
2727
'stringdtype/src/dtype.c',
2828
'stringdtype/src/main.c',
29-
# 'stringdtype/src/umath.c',
30-
# 'stringdtype/src/umath.h',
29+
'stringdtype/src/umath.c',
30+
'stringdtype/src/umath.h',
3131
]
3232

3333
py.install_sources(

stringdtype/stringdtype/src/main.c

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
#include "numpy/experimental_dtype_api.h"
77

88
#include "dtype.h"
9-
// #include "umath.h"
9+
#include "umath.h"
1010

1111
static struct PyModuleDef moduledef = {
1212
PyModuleDef_HEAD_INIT,
@@ -50,9 +50,9 @@ PyInit__main(void)
5050
goto error;
5151
}
5252

53-
// if (init_ufuncs() < 0) {
54-
// goto error;
55-
// }
53+
if (init_ufuncs() < 0) {
54+
goto error;
55+
}
5656

5757
return m;
5858

stringdtype/stringdtype/src/umath.c

Lines changed: 27 additions & 143 deletions
Original file line numberDiff line numberDiff line change
@@ -13,154 +13,42 @@
1313
#include "umath.h"
1414

1515
static int
16-
ascii_add_strided_loop(PyArrayMethod_Context *context, char *const data[],
17-
npy_intp const dimensions[], npy_intp const strides[],
18-
NpyAuxData *NPY_UNUSED(auxdata))
16+
string_equal_strided_loop(PyArrayMethod_Context *context, char *const data[],
17+
npy_intp const dimensions[],
18+
npy_intp const strides[],
19+
NpyAuxData *NPY_UNUSED(auxdata))
1920
{
20-
PyArray_Descr **descrs = context->descriptors;
21-
long in1_size = ((ASCIIDTypeObject *)descrs[0])->size;
22-
long in2_size = ((ASCIIDTypeObject *)descrs[1])->size;
23-
long out_size = ((ASCIIDTypeObject *)descrs[2])->size;
24-
25-
npy_intp N = dimensions[0];
26-
char *in1 = data[0], *in2 = data[1], *out = data[2];
27-
npy_intp in1_stride = strides[0], in2_stride = strides[1],
28-
out_stride = strides[2];
29-
30-
while (N--) {
31-
size_t in1_len = strnlen(in1, in1_size);
32-
size_t in2_len = strnlen(in2, in2_size);
33-
strncpy(out, in1, in1_len);
34-
strncpy(out + in1_len, in2, in2_len);
35-
if (in1_len + in2_len < out_size) {
36-
out[in1_len + in2_len] = '\0';
37-
}
38-
in1 += in1_stride;
39-
in2 += in2_stride;
40-
out += out_stride;
41-
}
42-
43-
return 0;
44-
}
45-
46-
static NPY_CASTING
47-
ascii_add_resolve_descriptors(PyObject *NPY_UNUSED(self),
48-
PyArray_DTypeMeta *dtypes[],
49-
PyArray_Descr *given_descrs[],
50-
PyArray_Descr *loop_descrs[],
51-
npy_intp *NPY_UNUSED(view_offset))
52-
{
53-
long op1_size = ((ASCIIDTypeObject *)given_descrs[0])->size;
54-
long op2_size = ((ASCIIDTypeObject *)given_descrs[1])->size;
55-
long out_size = op1_size + op2_size;
56-
57-
/* the input descriptors can be used as-is */
58-
Py_INCREF(given_descrs[0]);
59-
loop_descrs[0] = given_descrs[0];
60-
Py_INCREF(given_descrs[1]);
61-
loop_descrs[1] = given_descrs[1];
62-
63-
/* create new DType instance to hold the output */
64-
loop_descrs[2] = (PyArray_Descr *)new_asciidtype_instance(out_size);
65-
if (loop_descrs[2] == NULL) {
66-
return -1;
67-
}
68-
69-
return NPY_SAFE_CASTING;
70-
}
71-
72-
int
73-
init_add_ufunc(PyObject *numpy)
74-
{
75-
PyObject *add = PyObject_GetAttrString(numpy, "add");
76-
if (add == NULL) {
77-
return -1;
78-
}
79-
80-
/*
81-
* Initialize spec for addition
82-
*/
83-
static PyArray_DTypeMeta *add_dtypes[3] = {&ASCIIDType, &ASCIIDType,
84-
&ASCIIDType};
85-
86-
static PyType_Slot add_slots[] = {
87-
{NPY_METH_resolve_descriptors, &ascii_add_resolve_descriptors},
88-
{NPY_METH_strided_loop, &ascii_add_strided_loop},
89-
{0, NULL}};
90-
91-
PyArrayMethod_Spec AddSpec = {
92-
.name = "ascii_add",
93-
.nin = 2,
94-
.nout = 1,
95-
.dtypes = add_dtypes,
96-
.slots = add_slots,
97-
.flags = 0,
98-
.casting = NPY_SAFE_CASTING,
99-
};
100-
101-
/* register ufunc */
102-
if (PyUFunc_AddLoopFromSpec(add, &AddSpec) < 0) {
103-
Py_DECREF(add);
104-
return -1;
105-
}
106-
Py_DECREF(add);
107-
return 0;
108-
}
109-
110-
static int
111-
ascii_equal_strided_loop(PyArrayMethod_Context *context, char *const data[],
112-
npy_intp const dimensions[], npy_intp const strides[],
113-
NpyAuxData *NPY_UNUSED(auxdata))
114-
{
115-
PyArray_Descr **descrs = context->descriptors;
116-
long in1_size = ((ASCIIDTypeObject *)descrs[0])->size;
117-
long in2_size = ((ASCIIDTypeObject *)descrs[1])->size;
118-
11921
npy_intp N = dimensions[0];
120-
char *in1 = data[0], *in2 = data[1];
22+
char **in1 = (char **)data[0];
23+
char **in2 = (char **)data[1];
12124
npy_bool *out = (npy_bool *)data[2];
122-
npy_intp in1_stride = strides[0], in2_stride = strides[1],
123-
out_stride = strides[2];
25+
// strides are in bytes but pointer offsets are in pointer widths, so
26+
// divide by the element size (one pointer width) to get the pointer offset
27+
npy_intp in1_stride = strides[0] / context->descriptors[0]->elsize;
28+
npy_intp in2_stride = strides[1] / context->descriptors[1]->elsize;
29+
npy_intp out_stride = strides[2];
12430

12531
while (N--) {
126-
*out = (npy_bool)1;
127-
char *_in1 = in1;
128-
char *_in2 = in2;
129-
npy_bool *_out = out;
130-
in1 += in1_stride;
131-
in2 += in2_stride;
132-
out += out_stride;
133-
if (in1_size > in2_size) {
134-
if (_in1[in2_size] != '\0') {
135-
*_out = (npy_bool)0;
136-
continue;
137-
}
138-
if (strncmp(_in1, _in2, in2_size) != 0) {
139-
*_out = (npy_bool)0;
140-
}
32+
if (strcmp(*in1, *in2) == 0) {
33+
*out = (npy_bool)1;
14134
}
14235
else {
143-
if (in2_size > in1_size) {
144-
if (_in2[in1_size] != '\0') {
145-
*_out = (npy_bool)0;
146-
continue;
147-
}
148-
}
149-
if (strncmp(_in1, _in2, in1_size) != 0) {
150-
*_out = (npy_bool)0;
151-
}
36+
*out = (npy_bool)0;
15237
}
38+
in1 += in1_stride;
39+
in2 += in2_stride;
40+
out += out_stride;
15341
}
15442

15543
return 0;
15644
}
15745

15846
static NPY_CASTING
159-
ascii_equal_resolve_descriptors(PyObject *NPY_UNUSED(self),
160-
PyArray_DTypeMeta *dtypes[],
161-
PyArray_Descr *given_descrs[],
162-
PyArray_Descr *loop_descrs[],
163-
npy_intp *NPY_UNUSED(view_offset))
47+
string_equal_resolve_descriptors(PyObject *NPY_UNUSED(self),
48+
PyArray_DTypeMeta *dtypes[],
49+
PyArray_Descr *given_descrs[],
50+
PyArray_Descr *loop_descrs[],
51+
npy_intp *NPY_UNUSED(view_offset))
16452
{
16553
Py_INCREF(given_descrs[0]);
16654
loop_descrs[0] = given_descrs[0];
@@ -172,7 +60,7 @@ ascii_equal_resolve_descriptors(PyObject *NPY_UNUSED(self),
17260
return NPY_SAFE_CASTING;
17361
}
17462

175-
static char *equal_name = "ascii_equal";
63+
static char *equal_name = "string_equal";
17664

17765
int
17866
init_equal_ufunc(PyObject *numpy)
@@ -186,13 +74,13 @@ init_equal_ufunc(PyObject *numpy)
18674
* Initialize spec for equality
18775
*/
18876
PyArray_DTypeMeta **eq_dtypes = malloc(3 * sizeof(PyArray_DTypeMeta *));
189-
eq_dtypes[0] = &ASCIIDType;
190-
eq_dtypes[1] = &ASCIIDType;
77+
eq_dtypes[0] = &StringDType;
78+
eq_dtypes[1] = &StringDType;
19179
eq_dtypes[2] = &PyArray_BoolDType;
19280

19381
static PyType_Slot eq_slots[] = {
194-
{NPY_METH_resolve_descriptors, &ascii_equal_resolve_descriptors},
195-
{NPY_METH_strided_loop, &ascii_equal_strided_loop},
82+
{NPY_METH_resolve_descriptors, &string_equal_resolve_descriptors},
83+
{NPY_METH_strided_loop, &string_equal_strided_loop},
19684
{0, NULL}};
19785

19886
PyArrayMethod_Spec *EqualSpec = malloc(sizeof(PyArrayMethod_Spec));
@@ -226,10 +114,6 @@ init_ufuncs(void)
226114
return -1;
227115
}
228116

229-
if (init_add_ufunc(numpy) < 0) {
230-
goto error;
231-
}
232-
233117
if (init_equal_ufunc(numpy) < 0) {
234118
goto error;
235119
}

0 commit comments

Comments
 (0)