Skip to content

Commit b9c9cac

Browse files
authored
Merge pull request #71 from ngoldbaum/simplify-get-value
simplify stringdtype get_value implementation
2 parents 1f4ab30 + f94522b commit b9c9cac

File tree

2 files changed

+25
-53
lines changed

2 files changed

+25
-53
lines changed

stringdtype/stringdtype/src/dtype.c

Lines changed: 18 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -63,52 +63,32 @@ common_dtype(PyArray_DTypeMeta *cls, PyArray_DTypeMeta *other)
6363
return (PyArray_DTypeMeta *)Py_NotImplemented;
6464
}
6565

66+
// returns a new reference to the string "value" of
67+
// `scalar`. If scalar is not already a string, __str__
68+
// is called to convert it to a string. If the scalar
69+
// is the na_object for the dtype class, return
70+
// a new reference to the na_object.
71+
6672
static PyObject *
6773
get_value(PyObject *scalar, StringDType_type *cls)
6874
{
69-
PyObject *na_object = cls->na_object;
70-
PyObject *ret = NULL;
7175
PyTypeObject *expected_scalar_type = cls->base.scalar_type;
7276
PyTypeObject *scalar_type = Py_TYPE(scalar);
73-
// FIXME: handle bytes too
74-
if ((scalar_type == &PyUnicode_Type) ||
75-
(scalar_type == expected_scalar_type)) {
76-
// attempt to decode as UTF8
77-
ret = PyUnicode_AsUTF8String(scalar);
78-
if (ret == NULL) {
79-
PyErr_SetString(
80-
PyExc_TypeError,
81-
"Can only store UTF8 text in a StringDType array.");
77+
if (scalar == cls->na_object) {
78+
Py_INCREF(scalar);
79+
return scalar;
80+
}
81+
else if (!((scalar_type == &PyUnicode_Type) ||
82+
(scalar_type == expected_scalar_type))) {
83+
// attempt to coerce to str
84+
scalar = PyObject_Str(scalar);
85+
if (scalar == NULL) {
86+
// __str__ raised an exception
8287
return NULL;
8388
}
8489
}
85-
else if (scalar == na_object) {
86-
ret = scalar;
87-
Py_INCREF(ret);
88-
}
89-
// store np.nan as NA
90-
else if (scalar_type == &PyFloat_Type) {
91-
double scalar_val = PyFloat_AsDouble(scalar);
92-
if ((scalar_val == -1.0) && PyErr_Occurred()) {
93-
return NULL;
94-
}
95-
if (npy_isnan(scalar_val)) {
96-
ret = na_object;
97-
Py_INCREF(ret);
98-
}
99-
else {
100-
PyErr_SetString(
101-
PyExc_TypeError,
102-
"Can only store UTF8 text in a StringDType array.");
103-
return NULL;
104-
}
105-
}
106-
else {
107-
PyErr_SetString(PyExc_TypeError,
108-
"Can only store String text in a StringDType array.");
109-
return NULL;
110-
}
111-
return ret;
90+
// attempt to decode as UTF8
91+
return PyUnicode_AsUTF8String(scalar);
11292
}
11393

11494
// For a given python object, this function returns a borrowed reference

stringdtype/tests/test_stringdtype.py

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
import pytest
1414

1515
from stringdtype import (
16-
NA,
1716
PandasStringScalar,
1817
StringDType,
1918
StringScalar,
@@ -80,14 +79,15 @@ def test_array_creation_scalars(string_list, scalar, dtype):
8079
"data",
8180
[
8281
[1, 2, 3],
83-
[None, None, None],
8482
[b"abc", b"def", b"ghi"],
8583
[object, object, object],
8684
],
8785
)
88-
def test_bad_scalars(data):
89-
with pytest.raises(TypeError):
90-
np.array(data, dtype=StringDType())
86+
def test_scalars_string_conversion(data):
87+
np.testing.assert_array_equal(
88+
np.array(data, dtype=StringDType()),
89+
np.array([str(d) for d in data], dtype=StringDType()),
90+
)
9191

9292

9393
@pytest.mark.parametrize(
@@ -403,16 +403,8 @@ def test_ufunc_add(dtype, string_list, other_strings):
403403
)
404404

405405

406-
@pytest.mark.parametrize(
407-
"na_val", [float("nan"), np.nan, NA, getattr(pandas, "NA", None)]
408-
)
409-
def test_create_with_na(dtype, na_val):
410-
if not hasattr(pandas, "NA") or (
411-
dtype == StringDType() and na_val is pandas.NA
412-
):
413-
return
414-
if dtype != StringDType and na_val is NA:
415-
return
406+
def test_create_with_na(dtype):
407+
na_val = dtype.na_object
416408
string_list = ["hello", na_val, "world"]
417409
arr = np.array(string_list, dtype=dtype)
418410
assert (

0 commit comments

Comments
 (0)