Skip to content

Commit e6257e9

Browse files
committed
fix comparisons of unequal length strings
1 parent f879654 commit e6257e9

File tree

2 files changed

+20
-10
lines changed

2 files changed

+20
-10
lines changed

stringdtype/stringdtype/src/umath.c

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -226,8 +226,7 @@ string_greater_strided_loop(PyArrayMethod_Context *NPY_UNUSED(context),
226226
// s1 or s2 is NA
227227
*out = (npy_bool)0;
228228
}
229-
else if (s1->len == s2->len &&
230-
strncmp(s1->buf, s2->buf, s1->len) > 0) {
229+
else if (strcmp(s1->buf, s2->buf) > 0) {
231230
*out = (npy_bool)1;
232231
}
233232
else {
@@ -266,8 +265,7 @@ string_greater_equal_strided_loop(PyArrayMethod_Context *NPY_UNUSED(context),
266265
// s1 or s2 is NA
267266
*out = (npy_bool)0;
268267
}
269-
else if (s1->len == s2->len &&
270-
strncmp(s1->buf, s2->buf, s1->len) >= 0) {
268+
else if (strcmp(s1->buf, s2->buf) >= 0) {
271269
*out = (npy_bool)1;
272270
}
273271
else {
@@ -305,8 +303,7 @@ string_less_strided_loop(PyArrayMethod_Context *NPY_UNUSED(context),
305303
// s1 or s2 is NA
306304
*out = (npy_bool)0;
307305
}
308-
else if (s1->len == s2->len &&
309-
strncmp(s1->buf, s2->buf, s1->len) < 0) {
306+
else if (strcmp(s1->buf, s2->buf) < 0) {
310307
*out = (npy_bool)1;
311308
}
312309
else {
@@ -344,8 +341,7 @@ string_less_equal_strided_loop(PyArrayMethod_Context *NPY_UNUSED(context),
344341
// s1 or s2 is NA
345342
*out = (npy_bool)0;
346343
}
347-
else if (s1->len == s2->len &&
348-
strncmp(s1->buf, s2->buf, s1->len) <= 0) {
344+
else if (strcmp(s1->buf, s2->buf) <= 0) {
349345
*out = (npy_bool)1;
350346
}
351347
else {

stringdtype/tests/test_stringdtype.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,21 +154,35 @@ def test_insert_scalar(dtype, scalar, string_list):
154154

155155
@pytest.mark.parametrize("op", comparison_operators)
156156
@pytest.mark.parametrize("o_dtype", [np.str_, object])
157-
def test_comparison(string_list, dtype, op, o_dtype):
157+
def test_comparisons(string_list, dtype, op, o_dtype):
158158
sarr = np.array(string_list, dtype=dtype)
159159
oarr = np.array(string_list, dtype=o_dtype)
160160

161161
# test that comparison operators work
162162
res = op(sarr, sarr)
163163
ores = op(oarr, oarr)
164-
# test that promotion on the operator works as well
164+
# test that promotion works as well
165165
orres = op(sarr, oarr)
166166
olres = op(oarr, sarr)
167167

168168
np.testing.assert_array_equal(res, ores)
169169
np.testing.assert_array_equal(res, orres)
170170
np.testing.assert_array_equal(res, olres)
171171

172+
# test we get the correct answer for unequal length strings
173+
sarr2 = np.array([s + "2" for s in string_list], dtype=dtype)
174+
oarr2 = np.array([s + "2" for s in string_list], dtype=o_dtype)
175+
176+
res = op(sarr, sarr2)
177+
ores = op(oarr, oarr2)
178+
179+
np.testing.assert_array_equal(res, ores)
180+
181+
res = op(sarr2, sarr)
182+
ores = op(oarr2, oarr)
183+
184+
np.testing.assert_array_equal(res, ores)
185+
172186

173187
def test_isnan(dtype, string_list):
174188
sarr = np.array(string_list + [dtype.na_object], dtype=dtype)

0 commit comments

Comments
 (0)