Skip to content

Commit 30d7812

Browse files
authored
Merge pull request #81 from ngoldbaum/no-coerce
Optionally do not coerce non-strings to string in setitem
2 parents b387b2a + 0c4926b commit 30d7812

File tree

6 files changed

+143
-49
lines changed

6 files changed

+143
-49
lines changed

stringdtype/stringdtype/missing.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,13 @@
1-
class NAType:
1+
class Singleton:
2+
_instance = None
3+
4+
def __new__(cls):
5+
if cls._instance is None:
6+
cls._instance = super().__new__(cls)
7+
return cls._instance
8+
9+
10+
class NAType(Singleton):
211
def __repr__(self):
312
return "stringdtype.NA"
413

stringdtype/stringdtype/src/casts.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ gil_error(PyObject *type, const char *msg)
2121
{ \
2222
if (given_descrs[1] == NULL) { \
2323
PyArray_Descr *new = \
24-
(PyArray_Descr *)new_stringdtype_instance(NA_OBJ); \
24+
(PyArray_Descr *)new_stringdtype_instance(NA_OBJ, 0); \
2525
if (new == NULL) { \
2626
return (NPY_CASTING)-1; \
2727
} \

stringdtype/stringdtype/src/dtype.c

Lines changed: 74 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ PyObject *NA_OBJ = NULL;
1111
* Internal helper to create new instances
1212
*/
1313
PyObject *
14-
new_stringdtype_instance(PyObject *na_object)
14+
new_stringdtype_instance(PyObject *na_object, int coerce)
1515
{
1616
PyObject *new =
1717
PyArrayDescr_Type.tp_new((PyTypeObject *)&StringDType, NULL, NULL);
@@ -22,6 +22,7 @@ new_stringdtype_instance(PyObject *na_object)
2222

2323
Py_INCREF(na_object);
2424
((StringDTypeObject *)new)->na_object = na_object;
25+
((StringDTypeObject *)new)->coerce = coerce;
2526

2627
PyArray_Descr *base = (PyArray_Descr *)new;
2728
base->elsize = sizeof(ss);
@@ -67,23 +68,32 @@ common_dtype(PyArray_DTypeMeta *cls, PyArray_DTypeMeta *other)
6768
}
6869

6970
// returns a new reference to the string "value" of
70-
// `scalar`. If scalar is not already a string, __str__
71-
// is called to convert it to a string. If the scalar
72-
// is the na_object for the dtype class, return
73-
// a new reference to the na_object.
71+
// `scalar`. If scalar is not already a string and
72+
// coerce is nonzero, __str__ is called to convert it
73+
// to a string. If coerce is zero, raises an error for
74+
// non-string or non-NA input. If the scalar is the
75+
// na_object for the dtype class, return a new
76+
// reference to the na_object.
7477

7578
static PyObject *
76-
get_value(PyObject *scalar)
79+
get_value(PyObject *scalar, int coerce)
7780
{
7881
PyTypeObject *scalar_type = Py_TYPE(scalar);
7982
if (!((scalar_type == &PyUnicode_Type) ||
8083
(scalar_type == StringScalar_Type))) {
81-
// attempt to coerce to str
82-
scalar = PyObject_Str(scalar);
83-
if (scalar == NULL) {
84-
// __str__ raised an exception
84+
if (coerce == 0) {
85+
PyErr_SetString(PyExc_ValueError,
86+
"StringDType only allows string data");
8587
return NULL;
8688
}
89+
else {
90+
// attempt to coerce to str
91+
scalar = PyObject_Str(scalar);
92+
if (scalar == NULL) {
93+
// __str__ raised an exception
94+
return NULL;
95+
}
96+
}
8797
}
8898
// attempt to decode as UTF8
8999
return PyUnicode_AsUTF8String(scalar);
@@ -93,12 +103,12 @@ static PyArray_Descr *
93103
string_discover_descriptor_from_pyobject(PyTypeObject *NPY_UNUSED(cls),
94104
PyObject *obj)
95105
{
96-
PyObject *val = get_value(obj);
106+
PyObject *val = get_value(obj, 1);
97107
if (val == NULL) {
98108
return NULL;
99109
}
100110

101-
PyArray_Descr *ret = (PyArray_Descr *)new_stringdtype_instance(NA_OBJ);
111+
PyArray_Descr *ret = (PyArray_Descr *)new_stringdtype_instance(NA_OBJ, 1);
102112
if (ret == NULL) {
103113
return NULL;
104114
}
@@ -126,7 +136,7 @@ stringdtype_setitem(StringDTypeObject *descr, PyObject *obj, char **dataptr)
126136
// so it already contains a NA value
127137
}
128138
else {
129-
PyObject *val_obj = get_value(obj);
139+
PyObject *val_obj = get_value(obj, descr->coerce);
130140

131141
if (val_obj == NULL) {
132142
return -1;
@@ -334,21 +344,23 @@ static PyType_Slot StringDType_Slots[] = {
334344
static PyObject *
335345
stringdtype_new(PyTypeObject *NPY_UNUSED(cls), PyObject *args, PyObject *kwds)
336346
{
337-
static char *kwargs_strs[] = {"size", "na_object", NULL};
347+
static char *kwargs_strs[] = {"size", "na_object", "coerce", NULL};
338348

339349
long size = 0;
340350
PyObject *na_object = NULL;
351+
int coerce = 1;
341352

342-
if (!PyArg_ParseTupleAndKeywords(args, kwds, "|lO:StringDType",
343-
kwargs_strs, &size, &na_object)) {
353+
if (!PyArg_ParseTupleAndKeywords(args, kwds, "|lOp:StringDType",
354+
kwargs_strs, &size, &na_object,
355+
&coerce)) {
344356
return NULL;
345357
}
346358

347359
if (na_object == NULL) {
348360
na_object = NA_OBJ;
349361
}
350362

351-
PyObject *ret = new_stringdtype_instance(na_object);
363+
PyObject *ret = new_stringdtype_instance(na_object, coerce);
352364

353365
return ret;
354366
}
@@ -365,11 +377,18 @@ stringdtype_repr(StringDTypeObject *self)
365377
PyObject *ret = NULL;
366378
// borrow reference
367379
PyObject *na_object = self->na_object;
380+
int coerce = self->coerce;
368381

369382
// TODO: handle non-default NA
370-
if (na_object != NA_OBJ) {
371-
ret = PyUnicode_FromFormat("StringDType(na_object=%R)",
372-
self->na_object);
383+
if (na_object != NA_OBJ && coerce == 0) {
384+
ret = PyUnicode_FromFormat("StringDType(na_object=%R, coerce=False)",
385+
na_object);
386+
}
387+
else if (na_object != NA_OBJ) {
388+
ret = PyUnicode_FromFormat("StringDType(na_object=%R)", na_object);
389+
}
390+
else if (coerce == 0) {
391+
ret = PyUnicode_FromFormat("StringDType(coerce=False)", coerce);
373392
}
374393
else {
375394
ret = PyUnicode_FromString("StringDType()");
@@ -378,7 +397,7 @@ stringdtype_repr(StringDTypeObject *self)
378397
return ret;
379398
}
380399

381-
static int PICKLE_VERSION = 1;
400+
static int PICKLE_VERSION = 2;
382401

383402
static PyObject *
384403
stringdtype__reduce__(StringDTypeObject *self)
@@ -405,9 +424,9 @@ stringdtype__reduce__(StringDTypeObject *self)
405424

406425
PyTuple_SET_ITEM(ret, 0, obj);
407426

408-
PyTuple_SET_ITEM(
409-
ret, 1,
410-
Py_BuildValue("(NO)", PyLong_FromLong(0), self->na_object));
427+
PyTuple_SET_ITEM(ret, 1,
428+
Py_BuildValue("(NOi)", PyLong_FromLong(0),
429+
self->na_object, self->coerce));
411430

412431
PyTuple_SET_ITEM(ret, 2, Py_BuildValue("(l)", PICKLE_VERSION));
413432

@@ -456,9 +475,39 @@ static PyMemberDef StringDType_members[] = {
456475
{"na_object", T_OBJECT_EX, offsetof(StringDTypeObject, na_object),
457476
READONLY,
458477
"The missing value object associated with the dtype instance"},
478+
{"coerce", T_INT, offsetof(StringDTypeObject, coerce), READONLY,
479+
"Controls hether non-string values should be coerced to string"},
459480
{NULL, 0, 0, 0, NULL},
460481
};
461482

483+
static PyObject *
484+
StringDType_richcompare(PyObject *self, PyObject *other, int op)
485+
{
486+
if (!((op == Py_EQ) || (op == Py_NE)) ||
487+
(Py_TYPE(other) != Py_TYPE(self))) {
488+
Py_INCREF(Py_NotImplemented);
489+
return Py_NotImplemented;
490+
}
491+
492+
// we know both are instances of StringDType so this is safe
493+
StringDTypeObject *sself = (StringDTypeObject *)self;
494+
StringDTypeObject *sother = (StringDTypeObject *)other;
495+
496+
int eq = (sself->na_object == sother->na_object) &&
497+
(sself->coerce == sother->coerce);
498+
499+
PyObject *ret = Py_NotImplemented;
500+
if ((op == Py_EQ && eq) || (op == Py_NE && !eq)) {
501+
ret = Py_True;
502+
}
503+
else {
504+
ret = Py_False;
505+
}
506+
507+
Py_INCREF(ret);
508+
return ret;
509+
}
510+
462511
/*
463512
* This is the basic things that you need to create a Python Type/Class in C.
464513
* However, there is a slight difference here because we create a
@@ -476,6 +525,7 @@ StringDType_type StringDType = {
476525
.tp_str = (reprfunc)stringdtype_repr,
477526
.tp_methods = StringDType_methods,
478527
.tp_members = StringDType_members,
528+
.tp_richcompare = StringDType_richcompare,
479529
}}},
480530
/* rest, filled in during DTypeMeta initialization */
481531
};

stringdtype/stringdtype/src/dtype.h

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,21 +20,19 @@
2020
typedef struct {
2121
PyArray_Descr base;
2222
PyObject *na_object;
23+
int coerce;
2324
} StringDTypeObject;
2425

2526
typedef struct {
2627
PyArray_DTypeMeta base;
2728
} StringDType_type;
2829

2930
extern StringDType_type StringDType;
30-
extern StringDType_type PandasStringDType;
3131
extern PyTypeObject *StringScalar_Type;
32-
extern PyTypeObject *PandasStringScalar_Type;
3332
extern PyObject *NA_OBJ;
34-
extern int PANDAS_AVAILABLE;
3533

3634
PyObject *
37-
new_stringdtype_instance(PyObject *na_object);
35+
new_stringdtype_instance(PyObject *na_object, int coerce);
3836

3937
int
4038
init_string_dtype(void);

stringdtype/stringdtype/src/umath.c

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,17 +19,17 @@ multiply_resolve_descriptors(
1919
Py_INCREF(rdescr);
2020
loop_descrs[1] = rdescr;
2121

22-
PyArray_Descr *odescr = NULL;
22+
StringDTypeObject *odescr = NULL;
2323

2424
if (dtypes[0] == (PyArray_DTypeMeta *)&StringDType) {
25-
odescr = ldescr;
25+
odescr = (StringDTypeObject *)ldescr;
2626
}
2727
else {
28-
odescr = rdescr;
28+
odescr = (StringDTypeObject *)rdescr;
2929
}
3030

3131
loop_descrs[2] = (PyArray_Descr *)new_stringdtype_instance(
32-
((StringDTypeObject *)odescr)->na_object);
32+
odescr->na_object, odescr->coerce);
3333

3434
return NPY_NO_CASTING;
3535
}

stringdtype/tests/test_stringdtype.py

Lines changed: 52 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,35 @@ def string_list():
2626
)
2727

2828

29+
@pytest.fixture(params=[True, False])
30+
def coerce(request):
31+
return request.param
32+
33+
2934
@pytest.fixture(
3035
params=[None, NA, pd_param], ids=["None", "stringdtype.NA", "pandas.NA"]
3136
)
32-
def dtype(request):
33-
return StringDType(na_object=request.param)
37+
def na_object(request):
38+
return request.param
39+
40+
41+
@pytest.fixture()
42+
def dtype(na_object, coerce):
43+
return StringDType(na_object=na_object, coerce=coerce)
44+
45+
46+
def test_dtype_creation():
47+
dt = StringDType()
48+
assert dt.na_object is NA and dt.coerce == 1
49+
50+
dt = StringDType(na_object=None)
51+
assert dt.na_object is None and dt.coerce == 1
52+
53+
dt = StringDType(coerce=False)
54+
assert dt.na_object is NA and dt.coerce == 0
55+
56+
dt = StringDType(na_object=None, coerce=False)
57+
assert dt.na_object is None and dt.coerce == 0
3458

3559

3660
def test_scalar_creation():
@@ -44,10 +68,17 @@ def test_dtype_equality(dtype):
4468

4569

4670
def test_dtype_repr(dtype):
47-
if dtype.na_object is NA:
71+
if dtype.na_object is NA and dtype.coerce == 1:
4872
assert repr(dtype) == "StringDType()"
49-
else:
73+
elif dtype.coerce == 1:
5074
assert repr(dtype) == f"StringDType(na_object={dtype.na_object})"
75+
elif dtype.na_object is NA:
76+
assert repr(dtype) == "StringDType(coerce=False)"
77+
else:
78+
assert (
79+
repr(dtype)
80+
== f"StringDType(na_object={dtype.na_object}, coerce=False)"
81+
)
5182

5283

5384
@pytest.mark.parametrize(
@@ -61,12 +92,17 @@ def test_dtype_repr(dtype):
6192
)
6293
def test_array_creation_utf8(dtype, data):
6394
arr = np.array(data, dtype=dtype)
64-
assert repr(arr) == f"array({str(data)}, dtype={dtype})"
95+
assert str(arr) == "[" + " ".join(["'" + str(d) + "'" for d in data]) + "]"
96+
assert arr.dtype == dtype
6597

6698

6799
def test_array_creation_scalars(string_list):
68100
arr = np.array([StringScalar(s) for s in string_list])
69-
assert repr(arr) == repr(np.array(string_list, dtype=StringDType()))
101+
assert (
102+
str(arr)
103+
== "[" + " ".join(["'" + str(s) + "'" for s in string_list]) + "]"
104+
)
105+
assert arr.dtype == StringDType()
70106

71107

72108
@pytest.mark.parametrize(
@@ -77,11 +113,15 @@ def test_array_creation_scalars(string_list):
77113
[object, object, object],
78114
],
79115
)
80-
def test_scalars_string_conversion(data):
81-
np.testing.assert_array_equal(
82-
np.array(data, dtype=StringDType()),
83-
np.array([str(d) for d in data], dtype=StringDType()),
84-
)
116+
def test_scalars_string_conversion(data, dtype):
117+
if dtype.coerce != 0:
118+
np.testing.assert_array_equal(
119+
np.array(data, dtype=dtype),
120+
np.array([str(d) for d in data], dtype=dtype),
121+
)
122+
else:
123+
with pytest.raises(ValueError):
124+
np.array(data, dtype=dtype)
85125

86126

87127
@pytest.mark.parametrize(
@@ -516,10 +556,7 @@ def test_create_with_na(dtype):
516556
na_val = dtype.na_object
517557
string_list = ["hello", na_val, "world"]
518558
arr = np.array(string_list, dtype=dtype)
519-
assert (
520-
repr(arr)
521-
== f"array(['hello', {dtype.na_object}, 'world'], dtype={dtype})"
522-
)
559+
assert str(arr) == "[" + " ".join([repr(s) for s in string_list]) + "]"
523560
assert arr[1] is dtype.na_object
524561

525562

0 commit comments

Comments
 (0)