Skip to content

Commit 81c4e15

Browse files
committed
add round-trip casts between StringDType and PandasStringDType
1 parent 048df43 commit 81c4e15

File tree

4 files changed, with 59 additions and 48 deletions.

stringdtype/stringdtype/src/casts.c

Lines changed: 25 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -75,6 +75,7 @@ static PyType_Slot s2s_slots[] = {
7575
{0, NULL}};
7676

7777
static char *s2s_name = "cast_StringDType_to_StringDType";
78+
static char *p2p_name = "cast_PandasStringDType_to_PandasStringDType";
7879
static char *s2p_name = "cast_StringDType_to_PandasStringDType";
7980
static char *p2s_name = "cast_PandasStringDType_to_StringDType";
8081

@@ -478,40 +479,42 @@ get_dtypes(PyArray_DTypeMeta *dt1, PyArray_DTypeMeta *dt2)
478479
}
479480

480481
PyArrayMethod_Spec **
481-
get_casts(PyArray_DTypeMeta *this, PyArray_DTypeMeta *other,
482-
int pandas_available)
482+
get_casts(PyArray_DTypeMeta *this, PyArray_DTypeMeta *other)
483483
{
484-
PyArray_DTypeMeta **s2s_dtypes = get_dtypes(this, this);
484+
char *t2t_name = NULL;
485485

486-
PyArrayMethod_Spec *StringToStringCastSpec =
487-
get_cast_spec(s2s_name, NPY_NO_CASTING,
488-
NPY_METH_SUPPORTS_UNALIGNED, s2s_dtypes, s2s_slots);
486+
if (this == (PyArray_DTypeMeta *)&StringDType) {
487+
t2t_name = s2s_name;
488+
}
489+
else {
490+
t2t_name = p2p_name;
491+
}
492+
493+
PyArray_DTypeMeta **t2t_dtypes = get_dtypes(this, this);
494+
495+
PyArrayMethod_Spec *ThisToThisCastSpec =
496+
get_cast_spec(t2t_name, NPY_NO_CASTING,
497+
NPY_METH_SUPPORTS_UNALIGNED, t2t_dtypes, s2s_slots);
489498

490499
PyArrayMethod_Spec *ThisToOtherCastSpec = NULL;
491500
PyArrayMethod_Spec *OtherToThisCastSpec = NULL;
492501

493-
if (pandas_available) {
494-
char *t2o_name = NULL;
495-
char *o2t_name = NULL;
502+
int is_pandas = (this == (PyArray_DTypeMeta *)&PandasStringDType);
496503

497-
if (this == (PyArray_DTypeMeta *)&StringDType) {
498-
t2o_name = s2p_name;
499-
o2t_name = p2s_name;
500-
}
501-
else {
502-
t2o_name = p2s_name;
503-
o2t_name = s2p_name;
504-
}
504+
int num_casts = 5;
505+
506+
if (is_pandas) {
507+
num_casts = 7;
505508

506509
PyArray_DTypeMeta **t2o_dtypes = get_dtypes(this, other);
507510

508-
ThisToOtherCastSpec = get_cast_spec(t2o_name, NPY_NO_CASTING,
511+
ThisToOtherCastSpec = get_cast_spec(p2s_name, NPY_NO_CASTING,
509512
NPY_METH_SUPPORTS_UNALIGNED,
510513
t2o_dtypes, s2s_slots);
511514

512515
PyArray_DTypeMeta **o2t_dtypes = get_dtypes(other, this);
513516

514-
OtherToThisCastSpec = get_cast_spec(o2t_name, NPY_NO_CASTING,
517+
OtherToThisCastSpec = get_cast_spec(s2p_name, NPY_NO_CASTING,
515518
NPY_METH_SUPPORTS_UNALIGNED,
516519
o2t_dtypes, s2s_slots);
517520
}
@@ -536,18 +539,13 @@ get_casts(PyArray_DTypeMeta *this, PyArray_DTypeMeta *other,
536539

537540
PyArrayMethod_Spec **casts = NULL;
538541

539-
if (pandas_available) {
540-
casts = malloc(7 * sizeof(PyArrayMethod_Spec *));
541-
}
542-
else {
543-
casts = malloc(5 * sizeof(PyArrayMethod_Spec *));
544-
}
542+
casts = malloc(num_casts * sizeof(PyArrayMethod_Spec *));
545543

546-
casts[0] = StringToStringCastSpec;
544+
casts[0] = ThisToThisCastSpec;
547545
casts[1] = UnicodeToStringCastSpec;
548546
casts[2] = StringToUnicodeCastSpec;
549547
casts[3] = StringToBoolCastSpec;
550-
if (pandas_available) {
548+
if (is_pandas) {
551549
casts[4] = ThisToOtherCastSpec;
552550
casts[5] = OtherToThisCastSpec;
553551
casts[6] = NULL;

stringdtype/stringdtype/src/casts.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,7 @@
1111
#include "numpy/ndarraytypes.h"
1212

1313
PyArrayMethod_Spec **
14-
get_casts(PyArray_DTypeMeta *this_dtype, PyArray_DTypeMeta *other_dtype,
15-
int pandas_available);
14+
get_casts(PyArray_DTypeMeta *this_dtype, PyArray_DTypeMeta *other_dtype);
1615

1716
size_t
1817
utf8_char_to_ucs4_code(unsigned char *, Py_UCS4 *);

stringdtype/stringdtype/src/dtype.c

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -533,9 +533,8 @@ init_string_dtype(void)
533533
PANDAS_AVAILABLE = 1;
534534
}
535535

536-
PyArrayMethod_Spec **StringDType_casts = get_casts(
537-
(PyArray_DTypeMeta *)&StringDType,
538-
(PyArray_DTypeMeta *)&PandasStringDType, PANDAS_AVAILABLE);
536+
PyArrayMethod_Spec **StringDType_casts =
537+
get_casts((PyArray_DTypeMeta *)&StringDType, NULL);
539538

540539
PyArrayDTypeMeta_Spec StringDType_DTypeSpec = {
541540
.typeobj = StringScalar_Type,
@@ -557,10 +556,6 @@ init_string_dtype(void)
557556
return -1;
558557
}
559558

560-
// Partially initialize PandasStringDType so cast setup succeeds
561-
((PyObject *)&PandasStringDType)->ob_type = &PyArrayDTypeMeta_Type;
562-
((PyTypeObject *)&PandasStringDType)->tp_base = &PyArrayDescr_Type;
563-
564559
if (PyArrayInitDTypeMeta_FromSpec((PyArray_DTypeMeta *)&StringDType,
565560
&StringDType_DTypeSpec) < 0) {
566561
return -1;
@@ -585,7 +580,7 @@ init_string_dtype(void)
585580
if (PANDAS_AVAILABLE) {
586581
PyArrayMethod_Spec **PandasStringDType_casts =
587582
get_casts((PyArray_DTypeMeta *)&PandasStringDType,
588-
(PyArray_DTypeMeta *)&StringDType, PANDAS_AVAILABLE);
583+
(PyArray_DTypeMeta *)&StringDType);
589584

590585
PyArrayDTypeMeta_Spec PandasStringDType_DTypeSpec = {
591586
.typeobj = PandasStringScalar_Type,
@@ -601,6 +596,8 @@ init_string_dtype(void)
601596
return -1;
602597
}
603598

599+
((PyObject *)&PandasStringDType)->ob_type = &PyArrayDTypeMeta_Type;
600+
((PyTypeObject *)&PandasStringDType)->tp_base = &PyArrayDescr_Type;
604601
((PyTypeObject *)&PandasStringDType)->tp_dict = PyDict_New();
605602
// C attribute for fast access
606603
Py_INCREF(pandas_na_obj);

stringdtype/tests/test_stringdtype.py

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -31,25 +31,18 @@ def dtype(request):
3131
if request.param == "StringDType":
3232
return StringDType()
3333
elif request.param == "PandasStringDType":
34-
try:
35-
from stringdtype import PandasStringDType
34+
pytest.importorskip("pandas")
35+
from stringdtype import PandasStringDType
3636

37-
return PandasStringDType()
38-
except ImportError:
39-
pytest.skip("cannot import pandas")
37+
return PandasStringDType()
4038

4139

4240
@pytest.fixture
4341
def scalar(dtype):
4442
if dtype == StringDType():
4543
return StringScalar
46-
try:
47-
from stringdtype import PandasStringDType
48-
49-
del PandasStringDType
44+
else:
5045
return PandasStringScalar
51-
except ImportError:
52-
pytest.skip("cannot import pandas")
5346

5447

5548
def test_scalar_creation(scalar):
@@ -77,6 +70,8 @@ def test_array_creation_utf8(dtype, data):
7770

7871

7972
def test_array_creation_scalars(string_list, scalar, dtype):
73+
if not issubclass(scalar, dtype.type):
74+
pytest.skip()
8075
arr = np.array([scalar(s) for s in string_list])
8176
assert repr(arr) == repr(np.array(string_list, dtype=dtype))
8277

@@ -383,3 +378,25 @@ def test_create_with_na(dtype, na_val):
383378
== f"array(['hello', {dtype.na_object}, 'world'], dtype={dtype})"
384379
)
385380
assert arr[1] is dtype.na_object
381+
382+
383+
def test_pandas_to_numpy_cast(string_list):
384+
pytest.importorskip("pandas")
385+
386+
from stringdtype import PandasStringDType
387+
388+
sarr = np.array(string_list, dtype=StringDType())
389+
390+
parr = sarr.astype(PandasStringDType())
391+
392+
np.testing.assert_array_equal(
393+
parr, np.array(string_list, dtype=PandasStringDType())
394+
)
395+
np.testing.assert_array_equal(sarr, parr.astype(StringDType()))
396+
397+
# check that NA converts correctly too
398+
sarr[1] = StringDType.na_object
399+
parr = sarr.astype(PandasStringDType())
400+
401+
assert parr[1] is PandasStringDType.na_object
402+
assert parr.astype(StringDType())[1] is StringDType.na_object

0 commit comments

Comments
 (0)