@@ -11,7 +11,7 @@ PyObject *NA_OBJ = NULL;
11
11
* Internal helper to create new instances
12
12
*/
13
13
PyObject *
14
- new_stringdtype_instance (PyObject * na_object )
14
+ new_stringdtype_instance (PyObject * na_object , int coerce )
15
15
{
16
16
PyObject * new =
17
17
PyArrayDescr_Type .tp_new ((PyTypeObject * )& StringDType , NULL , NULL );
@@ -22,6 +22,7 @@ new_stringdtype_instance(PyObject *na_object)
22
22
23
23
Py_INCREF (na_object );
24
24
((StringDTypeObject * )new )-> na_object = na_object ;
25
+ ((StringDTypeObject * )new )-> coerce = coerce ;
25
26
26
27
PyArray_Descr * base = (PyArray_Descr * )new ;
27
28
base -> elsize = sizeof (ss );
@@ -67,23 +68,32 @@ common_dtype(PyArray_DTypeMeta *cls, PyArray_DTypeMeta *other)
67
68
}
68
69
69
70
// returns a new reference to the string "value" of
70
- // `scalar`. If scalar is not already a string, __str__
71
- // is called to convert it to a string. If the scalar
72
- // is the na_object for the dtype class, return
73
- // a new reference to the na_object.
71
+ // `scalar`. If scalar is not already a string and
72
+ // coerce is nonzero, __str__ is called to convert it
73
+ // to a string. If coerce is zero, raises an error for
74
+ // non-string or non-NA input. If the scalar is the
75
+ // na_object for the dtype class, return a new
76
+ // reference to the na_object.
74
77
75
78
static PyObject *
76
- get_value (PyObject * scalar )
79
+ get_value (PyObject * scalar , int coerce )
77
80
{
78
81
PyTypeObject * scalar_type = Py_TYPE (scalar );
79
82
if (!((scalar_type == & PyUnicode_Type ) ||
80
83
(scalar_type == StringScalar_Type ))) {
81
- // attempt to coerce to str
82
- scalar = PyObject_Str (scalar );
83
- if (scalar == NULL ) {
84
- // __str__ raised an exception
84
+ if (coerce == 0 ) {
85
+ PyErr_SetString (PyExc_ValueError ,
86
+ "StringDType only allows string data" );
85
87
return NULL ;
86
88
}
89
+ else {
90
+ // attempt to coerce to str
91
+ scalar = PyObject_Str (scalar );
92
+ if (scalar == NULL ) {
93
+ // __str__ raised an exception
94
+ return NULL ;
95
+ }
96
+ }
87
97
}
88
98
// attempt to decode as UTF8
89
99
return PyUnicode_AsUTF8String (scalar );
@@ -93,12 +103,12 @@ static PyArray_Descr *
93
103
string_discover_descriptor_from_pyobject (PyTypeObject * NPY_UNUSED (cls ),
94
104
PyObject * obj )
95
105
{
96
- PyObject * val = get_value (obj );
106
+ PyObject * val = get_value (obj , 1 );
97
107
if (val == NULL ) {
98
108
return NULL ;
99
109
}
100
110
101
- PyArray_Descr * ret = (PyArray_Descr * )new_stringdtype_instance (NA_OBJ );
111
+ PyArray_Descr * ret = (PyArray_Descr * )new_stringdtype_instance (NA_OBJ , 1 );
102
112
if (ret == NULL ) {
103
113
return NULL ;
104
114
}
@@ -126,7 +136,7 @@ stringdtype_setitem(StringDTypeObject *descr, PyObject *obj, char **dataptr)
126
136
// so it already contains a NA value
127
137
}
128
138
else {
129
- PyObject * val_obj = get_value (obj );
139
+ PyObject * val_obj = get_value (obj , descr -> coerce );
130
140
131
141
if (val_obj == NULL ) {
132
142
return -1 ;
@@ -334,21 +344,23 @@ static PyType_Slot StringDType_Slots[] = {
334
344
static PyObject *
335
345
stringdtype_new (PyTypeObject * NPY_UNUSED (cls ), PyObject * args , PyObject * kwds )
336
346
{
337
- static char * kwargs_strs [] = {"size" , "na_object" , NULL };
347
+ static char * kwargs_strs [] = {"size" , "na_object" , "coerce" , NULL };
338
348
339
349
long size = 0 ;
340
350
PyObject * na_object = NULL ;
351
+ int coerce = 1 ;
341
352
342
- if (!PyArg_ParseTupleAndKeywords (args , kwds , "|lO:StringDType" ,
343
- kwargs_strs , & size , & na_object )) {
353
+ if (!PyArg_ParseTupleAndKeywords (args , kwds , "|lOp:StringDType" ,
354
+ kwargs_strs , & size , & na_object ,
355
+ & coerce )) {
344
356
return NULL ;
345
357
}
346
358
347
359
if (na_object == NULL ) {
348
360
na_object = NA_OBJ ;
349
361
}
350
362
351
- PyObject * ret = new_stringdtype_instance (na_object );
363
+ PyObject * ret = new_stringdtype_instance (na_object , coerce );
352
364
353
365
return ret ;
354
366
}
@@ -365,11 +377,18 @@ stringdtype_repr(StringDTypeObject *self)
365
377
PyObject * ret = NULL ;
366
378
// borrow reference
367
379
PyObject * na_object = self -> na_object ;
380
+ int coerce = self -> coerce ;
368
381
369
382
// TODO: handle non-default NA
370
- if (na_object != NA_OBJ ) {
371
- ret = PyUnicode_FromFormat ("StringDType(na_object=%R)" ,
372
- self -> na_object );
383
+ if (na_object != NA_OBJ && coerce == 0 ) {
384
+ ret = PyUnicode_FromFormat ("StringDType(na_object=%R, coerce=False)" ,
385
+ na_object );
386
+ }
387
+ else if (na_object != NA_OBJ ) {
388
+ ret = PyUnicode_FromFormat ("StringDType(na_object=%R)" , na_object );
389
+ }
390
+ else if (coerce == 0 ) {
391
+ ret = PyUnicode_FromFormat ("StringDType(coerce=False)" , coerce );
373
392
}
374
393
else {
375
394
ret = PyUnicode_FromString ("StringDType()" );
@@ -378,7 +397,7 @@ stringdtype_repr(StringDTypeObject *self)
378
397
return ret ;
379
398
}
380
399
381
- static int PICKLE_VERSION = 1 ;
400
+ static int PICKLE_VERSION = 2 ;
382
401
383
402
static PyObject *
384
403
stringdtype__reduce__ (StringDTypeObject * self )
@@ -405,9 +424,9 @@ stringdtype__reduce__(StringDTypeObject *self)
405
424
406
425
PyTuple_SET_ITEM (ret , 0 , obj );
407
426
408
- PyTuple_SET_ITEM (
409
- ret , 1 ,
410
- Py_BuildValue ( "(NO)" , PyLong_FromLong ( 0 ), self -> na_object ));
427
+ PyTuple_SET_ITEM (ret , 1 ,
428
+ Py_BuildValue ( "(NOi)" , PyLong_FromLong ( 0 ) ,
429
+ self -> na_object , self -> coerce ));
411
430
412
431
PyTuple_SET_ITEM (ret , 2 , Py_BuildValue ("(l)" , PICKLE_VERSION ));
413
432
@@ -456,9 +475,39 @@ static PyMemberDef StringDType_members[] = {
456
475
{"na_object" , T_OBJECT_EX , offsetof(StringDTypeObject , na_object ),
457
476
READONLY ,
458
477
"The missing value object associated with the dtype instance" },
478
+ {"coerce" , T_INT , offsetof(StringDTypeObject , coerce ), READONLY ,
479
+ "Controls whether non-string values should be coerced to string" },
459
480
{NULL , 0 , 0 , 0 , NULL },
460
481
};
461
482
483
+ static PyObject *
484
+ StringDType_richcompare (PyObject * self , PyObject * other , int op )
485
+ {
486
+ if (!((op == Py_EQ ) || (op == Py_NE )) ||
487
+ (Py_TYPE (other ) != Py_TYPE (self ))) {
488
+ Py_INCREF (Py_NotImplemented );
489
+ return Py_NotImplemented ;
490
+ }
491
+
492
+ // we know both are instances of StringDType so this is safe
493
+ StringDTypeObject * sself = (StringDTypeObject * )self ;
494
+ StringDTypeObject * sother = (StringDTypeObject * )other ;
495
+
496
+ int eq = (sself -> na_object == sother -> na_object ) &&
497
+ (sself -> coerce == sother -> coerce );
498
+
499
+ PyObject * ret = Py_NotImplemented ;
500
+ if ((op == Py_EQ && eq ) || (op == Py_NE && !eq )) {
501
+ ret = Py_True ;
502
+ }
503
+ else {
504
+ ret = Py_False ;
505
+ }
506
+
507
+ Py_INCREF (ret );
508
+ return ret ;
509
+ }
510
+
462
511
/*
463
512
* This is the basic things that you need to create a Python Type/Class in C.
464
513
* However, there is a slight difference here because we create a
@@ -476,6 +525,7 @@ StringDType_type StringDType = {
476
525
.tp_str = (reprfunc )stringdtype_repr ,
477
526
.tp_methods = StringDType_methods ,
478
527
.tp_members = StringDType_members ,
528
+ .tp_richcompare = StringDType_richcompare ,
479
529
}}},
480
530
/* rest, filled in during DTypeMeta initialization */
481
531
};
0 commit comments