13
13
#include "umath.h"
14
14
15
15
static int
16
- ascii_add_strided_loop (PyArrayMethod_Context * context , char * const data [],
17
- npy_intp const dimensions [], npy_intp const strides [],
18
- NpyAuxData * NPY_UNUSED (auxdata ))
16
+ string_equal_strided_loop (PyArrayMethod_Context * context , char * const data [],
17
+ npy_intp const dimensions [],
18
+ npy_intp const strides [],
19
+ NpyAuxData * NPY_UNUSED (auxdata ))
19
20
{
20
- PyArray_Descr * * descrs = context -> descriptors ;
21
- long in1_size = ((ASCIIDTypeObject * )descrs [0 ])-> size ;
22
- long in2_size = ((ASCIIDTypeObject * )descrs [1 ])-> size ;
23
- long out_size = ((ASCIIDTypeObject * )descrs [2 ])-> size ;
24
-
25
- npy_intp N = dimensions [0 ];
26
- char * in1 = data [0 ], * in2 = data [1 ], * out = data [2 ];
27
- npy_intp in1_stride = strides [0 ], in2_stride = strides [1 ],
28
- out_stride = strides [2 ];
29
-
30
- while (N -- ) {
31
- size_t in1_len = strnlen (in1 , in1_size );
32
- size_t in2_len = strnlen (in2 , in2_size );
33
- strncpy (out , in1 , in1_len );
34
- strncpy (out + in1_len , in2 , in2_len );
35
- if (in1_len + in2_len < out_size ) {
36
- out [in1_len + in2_len ] = '\0' ;
37
- }
38
- in1 += in1_stride ;
39
- in2 += in2_stride ;
40
- out += out_stride ;
41
- }
42
-
43
- return 0 ;
44
- }
45
-
46
- static NPY_CASTING
47
- ascii_add_resolve_descriptors (PyObject * NPY_UNUSED (self ),
48
- PyArray_DTypeMeta * dtypes [],
49
- PyArray_Descr * given_descrs [],
50
- PyArray_Descr * loop_descrs [],
51
- npy_intp * NPY_UNUSED (view_offset ))
52
- {
53
- long op1_size = ((ASCIIDTypeObject * )given_descrs [0 ])-> size ;
54
- long op2_size = ((ASCIIDTypeObject * )given_descrs [1 ])-> size ;
55
- long out_size = op1_size + op2_size ;
56
-
57
- /* the input descriptors can be used as-is */
58
- Py_INCREF (given_descrs [0 ]);
59
- loop_descrs [0 ] = given_descrs [0 ];
60
- Py_INCREF (given_descrs [1 ]);
61
- loop_descrs [1 ] = given_descrs [1 ];
62
-
63
- /* create new DType instance to hold the output */
64
- loop_descrs [2 ] = (PyArray_Descr * )new_asciidtype_instance (out_size );
65
- if (loop_descrs [2 ] == NULL ) {
66
- return -1 ;
67
- }
68
-
69
- return NPY_SAFE_CASTING ;
70
- }
71
-
72
- int
73
- init_add_ufunc (PyObject * numpy )
74
- {
75
- PyObject * add = PyObject_GetAttrString (numpy , "add" );
76
- if (add == NULL ) {
77
- return -1 ;
78
- }
79
-
80
- /*
81
- * Initialize spec for addition
82
- */
83
- static PyArray_DTypeMeta * add_dtypes [3 ] = {& ASCIIDType , & ASCIIDType ,
84
- & ASCIIDType };
85
-
86
- static PyType_Slot add_slots [] = {
87
- {NPY_METH_resolve_descriptors , & ascii_add_resolve_descriptors },
88
- {NPY_METH_strided_loop , & ascii_add_strided_loop },
89
- {0 , NULL }};
90
-
91
- PyArrayMethod_Spec AddSpec = {
92
- .name = "ascii_add" ,
93
- .nin = 2 ,
94
- .nout = 1 ,
95
- .dtypes = add_dtypes ,
96
- .slots = add_slots ,
97
- .flags = 0 ,
98
- .casting = NPY_SAFE_CASTING ,
99
- };
100
-
101
- /* register ufunc */
102
- if (PyUFunc_AddLoopFromSpec (add , & AddSpec ) < 0 ) {
103
- Py_DECREF (add );
104
- return -1 ;
105
- }
106
- Py_DECREF (add );
107
- return 0 ;
108
- }
109
-
110
- static int
111
- ascii_equal_strided_loop (PyArrayMethod_Context * context , char * const data [],
112
- npy_intp const dimensions [], npy_intp const strides [],
113
- NpyAuxData * NPY_UNUSED (auxdata ))
114
- {
115
- PyArray_Descr * * descrs = context -> descriptors ;
116
- long in1_size = ((ASCIIDTypeObject * )descrs [0 ])-> size ;
117
- long in2_size = ((ASCIIDTypeObject * )descrs [1 ])-> size ;
118
-
119
21
npy_intp N = dimensions [0 ];
120
- char * in1 = data [0 ], * in2 = data [1 ];
22
+ char * * in1 = (char * * )data [0 ];
23
+ char * * in2 = (char * * )data [1 ];
121
24
npy_bool * out = (npy_bool * )data [2 ];
122
- npy_intp in1_stride = strides [0 ], in2_stride = strides [1 ],
123
- out_stride = strides [2 ];
25
+ // strides are in bytes but pointer offsets are in pointer widths, so
26
+ // divide by the element size (one pointer width) to get the pointer offset
27
+ npy_intp in1_stride = strides [0 ] / context -> descriptors [0 ]-> elsize ;
28
+ npy_intp in2_stride = strides [1 ] / context -> descriptors [1 ]-> elsize ;
29
+ npy_intp out_stride = strides [2 ];
124
30
125
31
while (N -- ) {
126
- * out = (npy_bool )1 ;
127
- char * _in1 = in1 ;
128
- char * _in2 = in2 ;
129
- npy_bool * _out = out ;
130
- in1 += in1_stride ;
131
- in2 += in2_stride ;
132
- out += out_stride ;
133
- if (in1_size > in2_size ) {
134
- if (_in1 [in2_size ] != '\0' ) {
135
- * _out = (npy_bool )0 ;
136
- continue ;
137
- }
138
- if (strncmp (_in1 , _in2 , in2_size ) != 0 ) {
139
- * _out = (npy_bool )0 ;
140
- }
32
+ if (strcmp (* in1 , * in2 ) == 0 ) {
33
+ * out = (npy_bool )1 ;
141
34
}
142
35
else {
143
- if (in2_size > in1_size ) {
144
- if (_in2 [in1_size ] != '\0' ) {
145
- * _out = (npy_bool )0 ;
146
- continue ;
147
- }
148
- }
149
- if (strncmp (_in1 , _in2 , in1_size ) != 0 ) {
150
- * _out = (npy_bool )0 ;
151
- }
36
+ * out = (npy_bool )0 ;
152
37
}
38
+ in1 += in1_stride ;
39
+ in2 += in2_stride ;
40
+ out += out_stride ;
153
41
}
154
42
155
43
return 0 ;
156
44
}
157
45
158
46
static NPY_CASTING
159
- ascii_equal_resolve_descriptors (PyObject * NPY_UNUSED (self ),
160
- PyArray_DTypeMeta * dtypes [],
161
- PyArray_Descr * given_descrs [],
162
- PyArray_Descr * loop_descrs [],
163
- npy_intp * NPY_UNUSED (view_offset ))
47
+ string_equal_resolve_descriptors (PyObject * NPY_UNUSED (self ),
48
+ PyArray_DTypeMeta * dtypes [],
49
+ PyArray_Descr * given_descrs [],
50
+ PyArray_Descr * loop_descrs [],
51
+ npy_intp * NPY_UNUSED (view_offset ))
164
52
{
165
53
Py_INCREF (given_descrs [0 ]);
166
54
loop_descrs [0 ] = given_descrs [0 ];
@@ -172,7 +60,7 @@ ascii_equal_resolve_descriptors(PyObject *NPY_UNUSED(self),
172
60
return NPY_SAFE_CASTING ;
173
61
}
174
62
175
- static char * equal_name = "ascii_equal " ;
63
+ static char * equal_name = "string_equal " ;
176
64
177
65
int
178
66
init_equal_ufunc (PyObject * numpy )
@@ -186,13 +74,13 @@ init_equal_ufunc(PyObject *numpy)
186
74
* Initialize spec for equality
187
75
*/
188
76
PyArray_DTypeMeta * * eq_dtypes = malloc (3 * sizeof (PyArray_DTypeMeta * ));
189
- eq_dtypes [0 ] = & ASCIIDType ;
190
- eq_dtypes [1 ] = & ASCIIDType ;
77
+ eq_dtypes [0 ] = & StringDType ;
78
+ eq_dtypes [1 ] = & StringDType ;
191
79
eq_dtypes [2 ] = & PyArray_BoolDType ;
192
80
193
81
static PyType_Slot eq_slots [] = {
194
- {NPY_METH_resolve_descriptors , & ascii_equal_resolve_descriptors },
195
- {NPY_METH_strided_loop , & ascii_equal_strided_loop },
82
+ {NPY_METH_resolve_descriptors , & string_equal_resolve_descriptors },
83
+ {NPY_METH_strided_loop , & string_equal_strided_loop },
196
84
{0 , NULL }};
197
85
198
86
PyArrayMethod_Spec * EqualSpec = malloc (sizeof (PyArrayMethod_Spec ));
@@ -226,10 +114,6 @@ init_ufuncs(void)
226
114
return -1 ;
227
115
}
228
116
229
- if (init_add_ufunc (numpy ) < 0 ) {
230
- goto error ;
231
- }
232
-
233
117
if (init_equal_ufunc (numpy ) < 0 ) {
234
118
goto error ;
235
119
}
0 commit comments