Skip to content

Commit 9aa25d5

Browse files
committed
Add option to prevent coercing data to strings
1 parent 777f636 commit 9aa25d5

File tree

5 files changed

+136
-48
lines changed

5 files changed

+136
-48
lines changed

stringdtype/stringdtype/src/casts.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ gil_error(PyObject *type, const char *msg)
2121
{ \
2222
if (given_descrs[1] == NULL) { \
2323
PyArray_Descr *new = \
24-
(PyArray_Descr *)new_stringdtype_instance(NA_OBJ); \
24+
(PyArray_Descr *)new_stringdtype_instance(NA_OBJ, 0); \
2525
if (new == NULL) { \
2626
return (NPY_CASTING)-1; \
2727
} \

stringdtype/stringdtype/src/dtype.c

Lines changed: 77 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ PyObject *NA_OBJ = NULL;
1111
* Internal helper to create new instances
1212
*/
1313
PyObject *
14-
new_stringdtype_instance(PyObject *na_object)
14+
new_stringdtype_instance(PyObject *na_object, int coerce)
1515
{
1616
PyObject *new =
1717
PyArrayDescr_Type.tp_new((PyTypeObject *)&StringDType, NULL, NULL);
@@ -22,6 +22,7 @@ new_stringdtype_instance(PyObject *na_object)
2222

2323
Py_INCREF(na_object);
2424
((StringDTypeObject *)new)->na_object = na_object;
25+
((StringDTypeObject *)new)->coerce = coerce;
2526

2627
PyArray_Descr *base = (PyArray_Descr *)new;
2728
base->elsize = sizeof(ss);
@@ -67,23 +68,32 @@ common_dtype(PyArray_DTypeMeta *cls, PyArray_DTypeMeta *other)
6768
}
6869

6970
// returns a new reference to the string "value" of
70-
// `scalar`. If scalar is not already a string, __str__
71-
// is called to convert it to a string. If the scalar
72-
// is the na_object for the dtype class, return
73-
// a new reference to the na_object.
71+
// `scalar`. If scalar is not already a string and
72+
// coerce is nonzero, __str__ is called to convert it
73+
// to a string. If coerce is zero, raises an error for
74+
// non-string or non-NA input. If the scalar is the
75+
// na_object for the dtype class, return a new
76+
// reference to the na_object.
7477

7578
static PyObject *
76-
get_value(PyObject *scalar)
79+
get_value(PyObject *scalar, int coerce)
7780
{
7881
PyTypeObject *scalar_type = Py_TYPE(scalar);
7982
if (!((scalar_type == &PyUnicode_Type) ||
8083
(scalar_type == StringScalar_Type))) {
81-
// attempt to coerce to str
82-
scalar = PyObject_Str(scalar);
83-
if (scalar == NULL) {
84-
// __str__ raised an exception
84+
if (coerce == 0) {
85+
PyErr_SetString(PyExc_ValueError,
86+
"StringDType only allows string data");
8587
return NULL;
8688
}
89+
else {
90+
// attempt to coerce to str
91+
scalar = PyObject_Str(scalar);
92+
if (scalar == NULL) {
93+
// __str__ raised an exception
94+
return NULL;
95+
}
96+
}
8797
}
8898
// attempt to decode as UTF8
8999
return PyUnicode_AsUTF8String(scalar);
@@ -93,12 +103,12 @@ static PyArray_Descr *
93103
string_discover_descriptor_from_pyobject(PyTypeObject *NPY_UNUSED(cls),
94104
PyObject *obj)
95105
{
96-
PyObject *val = get_value(obj);
106+
PyObject *val = get_value(obj, 1);
97107
if (val == NULL) {
98108
return NULL;
99109
}
100110

101-
PyArray_Descr *ret = (PyArray_Descr *)new_stringdtype_instance(NA_OBJ);
111+
PyArray_Descr *ret = (PyArray_Descr *)new_stringdtype_instance(NA_OBJ, 1);
102112
if (ret == NULL) {
103113
return NULL;
104114
}
@@ -126,7 +136,7 @@ stringdtype_setitem(StringDTypeObject *descr, PyObject *obj, char **dataptr)
126136
// so it already contains a NA value
127137
}
128138
else {
129-
PyObject *val_obj = get_value(obj);
139+
PyObject *val_obj = get_value(obj, descr->coerce);
130140

131141
if (val_obj == NULL) {
132142
return -1;
@@ -334,21 +344,23 @@ static PyType_Slot StringDType_Slots[] = {
334344
static PyObject *
335345
stringdtype_new(PyTypeObject *NPY_UNUSED(cls), PyObject *args, PyObject *kwds)
336346
{
337-
static char *kwargs_strs[] = {"size", "na_object", NULL};
347+
static char *kwargs_strs[] = {"size", "na_object", "coerce", NULL};
338348

339349
long size = 0;
340350
PyObject *na_object = NULL;
351+
int coerce = 1;
341352

342-
if (!PyArg_ParseTupleAndKeywords(args, kwds, "|lO:StringDType",
343-
kwargs_strs, &size, &na_object)) {
353+
if (!PyArg_ParseTupleAndKeywords(args, kwds, "|lOp:StringDType",
354+
kwargs_strs, &size, &na_object,
355+
&coerce)) {
344356
return NULL;
345357
}
346358

347359
if (na_object == NULL) {
348360
na_object = NA_OBJ;
349361
}
350362

351-
PyObject *ret = new_stringdtype_instance(na_object);
363+
PyObject *ret = new_stringdtype_instance(na_object, coerce);
352364

353365
return ret;
354366
}
@@ -365,11 +377,18 @@ stringdtype_repr(StringDTypeObject *self)
365377
PyObject *ret = NULL;
366378
// borrow reference
367379
PyObject *na_object = self->na_object;
380+
int coerce = self->coerce;
368381

369382
// TODO: handle non-default NA
370-
if (na_object != NA_OBJ) {
371-
ret = PyUnicode_FromFormat("StringDType(na_object=%R)",
372-
self->na_object);
383+
if (na_object != NA_OBJ && coerce == 0) {
384+
ret = PyUnicode_FromFormat("StringDType(na_object=%R, coerce=False)",
385+
na_object);
386+
}
387+
else if (na_object != NA_OBJ) {
388+
ret = PyUnicode_FromFormat("StringDType(na_object=%R)", na_object);
389+
}
390+
else if (coerce == 0) {
391+
ret = PyUnicode_FromFormat("StringDType(coerce=False)", coerce);
373392
}
374393
else {
375394
ret = PyUnicode_FromString("StringDType()");
@@ -378,7 +397,7 @@ stringdtype_repr(StringDTypeObject *self)
378397
return ret;
379398
}
380399

381-
static int PICKLE_VERSION = 1;
400+
static int PICKLE_VERSION = 2;
382401

383402
static PyObject *
384403
stringdtype__reduce__(StringDTypeObject *self)
@@ -405,9 +424,9 @@ stringdtype__reduce__(StringDTypeObject *self)
405424

406425
PyTuple_SET_ITEM(ret, 0, obj);
407426

408-
PyTuple_SET_ITEM(
409-
ret, 1,
410-
Py_BuildValue("(NO)", PyLong_FromLong(0), self->na_object));
427+
PyTuple_SET_ITEM(ret, 1,
428+
Py_BuildValue("(NOi)", PyLong_FromLong(0),
429+
self->na_object, self->coerce));
411430

412431
PyTuple_SET_ITEM(ret, 2, Py_BuildValue("(l)", PICKLE_VERSION));
413432

@@ -456,9 +475,42 @@ static PyMemberDef StringDType_members[] = {
456475
{"na_object", T_OBJECT_EX, offsetof(StringDTypeObject, na_object),
457476
READONLY,
458477
"The missing value object associated with the dtype instance"},
478+
{"coerce", T_INT, offsetof(StringDTypeObject, coerce), READONLY,
479+
"Controls whether non-string values should be coerced to string"},
459480
{NULL, 0, 0, 0, NULL},
460481
};
461482

483+
static PyObject *
484+
StringDType_richcompare(PyObject *self, PyObject *other, int op)
485+
{
486+
PyTypeObject *stype = Py_TYPE(self);
487+
PyTypeObject *otype = Py_TYPE(other);
488+
489+
if (stype != otype) {
490+
return Py_NotImplemented;
491+
}
492+
493+
StringDTypeObject *sself = (StringDTypeObject *)self;
494+
StringDTypeObject *sother = (StringDTypeObject *)other;
495+
496+
int eq = (sself->na_object == sother->na_object) &&
497+
(sself->coerce == sother->coerce);
498+
if (op == Py_EQ) {
499+
if (eq) {
500+
return Py_True;
501+
}
502+
return Py_False;
503+
}
504+
if (op == Py_NE) {
505+
if (eq) {
506+
return Py_False;
507+
}
508+
return Py_True;
509+
}
510+
511+
return Py_NotImplemented;
512+
}
513+
462514
/*
463515
* This is the basic things that you need to create a Python Type/Class in C.
464516
* However, there is a slight difference here because we create a
@@ -476,6 +528,7 @@ StringDType_type StringDType = {
476528
.tp_str = (reprfunc)stringdtype_repr,
477529
.tp_methods = StringDType_methods,
478530
.tp_members = StringDType_members,
531+
.tp_richcompare = StringDType_richcompare,
479532
}}},
480533
/* rest, filled in during DTypeMeta initialization */
481534
};

stringdtype/stringdtype/src/dtype.h

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,21 +17,19 @@
1717
typedef struct {
1818
PyArray_Descr base;
1919
PyObject *na_object;
20+
int coerce;
2021
} StringDTypeObject;
2122

2223
typedef struct {
2324
PyArray_DTypeMeta base;
2425
} StringDType_type;
2526

2627
extern StringDType_type StringDType;
27-
extern StringDType_type PandasStringDType;
2828
extern PyTypeObject *StringScalar_Type;
29-
extern PyTypeObject *PandasStringScalar_Type;
3029
extern PyObject *NA_OBJ;
31-
extern int PANDAS_AVAILABLE;
3230

3331
PyObject *
34-
new_stringdtype_instance(PyObject *na_object);
32+
new_stringdtype_instance(PyObject *na_object, int coerce);
3533

3634
int
3735
init_string_dtype(void);

stringdtype/stringdtype/src/umath.c

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,17 +26,17 @@ multiply_resolve_descriptors(
2626
Py_INCREF(rdescr);
2727
loop_descrs[1] = rdescr;
2828

29-
PyArray_Descr *odescr = NULL;
29+
StringDTypeObject *odescr = NULL;
3030

3131
if (dtypes[0] == (PyArray_DTypeMeta *)&StringDType) {
32-
odescr = ldescr;
32+
odescr = (StringDTypeObject *)ldescr;
3333
}
3434
else {
35-
odescr = rdescr;
35+
odescr = (StringDTypeObject *)rdescr;
3636
}
3737

3838
loop_descrs[2] = (PyArray_Descr *)new_stringdtype_instance(
39-
((StringDTypeObject *)odescr)->na_object);
39+
odescr->na_object, odescr->coerce);
4040

4141
return NPY_NO_CASTING;
4242
}

stringdtype/tests/test_stringdtype.py

Lines changed: 52 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,35 @@ def string_list():
2626
)
2727

2828

29+
@pytest.fixture(params=[True, False])
30+
def coerce(request):
31+
return request.param
32+
33+
2934
@pytest.fixture(
3035
params=[None, NA, pd_param], ids=["None", "stringdtype.NA", "pandas.NA"]
3136
)
32-
def dtype(request):
33-
return StringDType(na_object=request.param)
37+
def na_object(request):
38+
return request.param
39+
40+
41+
@pytest.fixture()
42+
def dtype(na_object, coerce):
43+
return StringDType(na_object=na_object, coerce=coerce)
44+
45+
46+
def test_dtype_creation():
47+
dt = StringDType()
48+
assert dt.na_object is NA and dt.coerce == 1
49+
50+
dt = StringDType(na_object=None)
51+
assert dt.na_object is None and dt.coerce == 1
52+
53+
dt = StringDType(coerce=False)
54+
assert dt.na_object is NA and dt.coerce == 0
55+
56+
dt = StringDType(na_object=None, coerce=False)
57+
assert dt.na_object is None and dt.coerce == 0
3458

3559

3660
def test_scalar_creation():
@@ -44,10 +68,17 @@ def test_dtype_equality(dtype):
4468

4569

4670
def test_dtype_repr(dtype):
47-
if dtype.na_object is NA:
71+
if dtype.na_object is NA and dtype.coerce == 1:
4872
assert repr(dtype) == "StringDType()"
49-
else:
73+
elif dtype.coerce == 1:
5074
assert repr(dtype) == f"StringDType(na_object={dtype.na_object})"
75+
elif dtype.na_object is NA:
76+
assert repr(dtype) == "StringDType(coerce=False)"
77+
else:
78+
assert (
79+
repr(dtype)
80+
== f"StringDType(na_object={dtype.na_object}, coerce=False)"
81+
)
5182

5283

5384
@pytest.mark.parametrize(
@@ -61,12 +92,17 @@ def test_dtype_repr(dtype):
6192
)
6293
def test_array_creation_utf8(dtype, data):
6394
arr = np.array(data, dtype=dtype)
64-
assert repr(arr) == f"array({str(data)}, dtype={dtype})"
95+
assert str(arr) == "[" + " ".join(["'" + str(d) + "'" for d in data]) + "]"
96+
assert arr.dtype == dtype
6597

6698

6799
def test_array_creation_scalars(string_list):
68100
arr = np.array([StringScalar(s) for s in string_list])
69-
assert repr(arr) == repr(np.array(string_list, dtype=StringDType()))
101+
assert (
102+
str(arr)
103+
== "[" + " ".join(["'" + str(s) + "'" for s in string_list]) + "]"
104+
)
105+
assert arr.dtype == StringDType()
70106

71107

72108
@pytest.mark.parametrize(
@@ -77,11 +113,15 @@ def test_array_creation_scalars(string_list):
77113
[object, object, object],
78114
],
79115
)
80-
def test_scalars_string_conversion(data):
81-
np.testing.assert_array_equal(
82-
np.array(data, dtype=StringDType()),
83-
np.array([str(d) for d in data], dtype=StringDType()),
84-
)
116+
def test_scalars_string_conversion(data, dtype):
117+
if dtype.coerce != 0:
118+
np.testing.assert_array_equal(
119+
np.array(data, dtype=dtype),
120+
np.array([str(d) for d in data], dtype=dtype),
121+
)
122+
else:
123+
with pytest.raises(ValueError):
124+
np.array(data, dtype=dtype)
85125

86126

87127
@pytest.mark.parametrize(
@@ -510,8 +550,5 @@ def test_create_with_na(dtype):
510550
na_val = dtype.na_object
511551
string_list = ["hello", na_val, "world"]
512552
arr = np.array(string_list, dtype=dtype)
513-
assert (
514-
repr(arr)
515-
== f"array(['hello', {dtype.na_object}, 'world'], dtype={dtype})"
516-
)
553+
assert str(arr) == "[" + " ".join([repr(s) for s in string_list]) + "]"
517554
assert arr[1] is dtype.na_object

0 commit comments

Comments
 (0)