Skip to content

Commit dc0d884

Browse files
authored
Merge pull request #68 from ngoldbaum/stringdtype-view-casts
Add casts between StringDType and PandasStringDType
2 parents 4f7de1d + 81c4e15 commit dc0d884

File tree

5 files changed, with 115 additions (+115) and 39 deletions (-39).

stringdtype/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -35,6 +35,6 @@ per-file-ignores = {"__init__.py" = ["F401"]}
3535

3636
[tool.meson-python.args]
3737
dist = []
38-
setup = ["-Ddebug=true", "-Doptimization=0"]
38+
setup = ["-Dbuildtype=debug"]
3939
compile = []
4040
install = []

stringdtype/stringdtype/src/casts.c

Lines changed: 56 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -75,6 +75,9 @@ static PyType_Slot s2s_slots[] = {
7575
{0, NULL}};
7676

7777
static char *s2s_name = "cast_StringDType_to_StringDType";
78+
static char *p2p_name = "cast_PandasStringDType_to_PandasStringDType";
79+
static char *s2p_name = "cast_StringDType_to_PandasStringDType";
80+
static char *p2s_name = "cast_PandasStringDType_to_StringDType";
7881

7982
// unicode to string
8083

@@ -476,38 +479,80 @@ get_dtypes(PyArray_DTypeMeta *dt1, PyArray_DTypeMeta *dt2)
476479
}
477480

478481
PyArrayMethod_Spec **
479-
get_casts(void)
482+
get_casts(PyArray_DTypeMeta *this, PyArray_DTypeMeta *other)
480483
{
481-
PyArray_DTypeMeta **s2s_dtypes = get_dtypes(NULL, NULL);
484+
char *t2t_name = NULL;
482485

483-
PyArrayMethod_Spec *StringToStringCastSpec =
484-
get_cast_spec(s2s_name, NPY_NO_CASTING,
485-
NPY_METH_SUPPORTS_UNALIGNED, s2s_dtypes, s2s_slots);
486+
if (this == (PyArray_DTypeMeta *)&StringDType) {
487+
t2t_name = s2s_name;
488+
}
489+
else {
490+
t2t_name = p2p_name;
491+
}
492+
493+
PyArray_DTypeMeta **t2t_dtypes = get_dtypes(this, this);
494+
495+
PyArrayMethod_Spec *ThisToThisCastSpec =
496+
get_cast_spec(t2t_name, NPY_NO_CASTING,
497+
NPY_METH_SUPPORTS_UNALIGNED, t2t_dtypes, s2s_slots);
498+
499+
PyArrayMethod_Spec *ThisToOtherCastSpec = NULL;
500+
PyArrayMethod_Spec *OtherToThisCastSpec = NULL;
501+
502+
int is_pandas = (this == (PyArray_DTypeMeta *)&PandasStringDType);
503+
504+
int num_casts = 5;
505+
506+
if (is_pandas) {
507+
num_casts = 7;
508+
509+
PyArray_DTypeMeta **t2o_dtypes = get_dtypes(this, other);
486510

487-
PyArray_DTypeMeta **u2s_dtypes = get_dtypes(&PyArray_UnicodeDType, NULL);
511+
ThisToOtherCastSpec = get_cast_spec(p2s_name, NPY_NO_CASTING,
512+
NPY_METH_SUPPORTS_UNALIGNED,
513+
t2o_dtypes, s2s_slots);
514+
515+
PyArray_DTypeMeta **o2t_dtypes = get_dtypes(other, this);
516+
517+
OtherToThisCastSpec = get_cast_spec(s2p_name, NPY_NO_CASTING,
518+
NPY_METH_SUPPORTS_UNALIGNED,
519+
o2t_dtypes, s2s_slots);
520+
}
521+
522+
PyArray_DTypeMeta **u2s_dtypes = get_dtypes(&PyArray_UnicodeDType, this);
488523

489524
PyArrayMethod_Spec *UnicodeToStringCastSpec = get_cast_spec(
490525
u2s_name, NPY_SAFE_CASTING, NPY_METH_NO_FLOATINGPOINT_ERRORS,
491526
u2s_dtypes, u2s_slots);
492527

493-
PyArray_DTypeMeta **s2u_dtypes = get_dtypes(NULL, &PyArray_UnicodeDType);
528+
PyArray_DTypeMeta **s2u_dtypes = get_dtypes(this, &PyArray_UnicodeDType);
494529

495530
PyArrayMethod_Spec *StringToUnicodeCastSpec = get_cast_spec(
496531
s2u_name, NPY_SAFE_CASTING, NPY_METH_NO_FLOATINGPOINT_ERRORS,
497532
s2u_dtypes, s2u_slots);
498533

499-
PyArray_DTypeMeta **s2b_dtypes = get_dtypes(NULL, &PyArray_BoolDType);
534+
PyArray_DTypeMeta **s2b_dtypes = get_dtypes(this, &PyArray_BoolDType);
500535

501536
PyArrayMethod_Spec *StringToBoolCastSpec = get_cast_spec(
502537
s2b_name, NPY_UNSAFE_CASTING, NPY_METH_NO_FLOATINGPOINT_ERRORS,
503538
s2b_dtypes, s2b_slots);
504539

505-
PyArrayMethod_Spec **casts = malloc(5 * sizeof(PyArrayMethod_Spec *));
506-
casts[0] = StringToStringCastSpec;
540+
PyArrayMethod_Spec **casts = NULL;
541+
542+
casts = malloc(num_casts * sizeof(PyArrayMethod_Spec *));
543+
544+
casts[0] = ThisToThisCastSpec;
507545
casts[1] = UnicodeToStringCastSpec;
508546
casts[2] = StringToUnicodeCastSpec;
509547
casts[3] = StringToBoolCastSpec;
510-
casts[4] = NULL;
548+
if (is_pandas) {
549+
casts[4] = ThisToOtherCastSpec;
550+
casts[5] = OtherToThisCastSpec;
551+
casts[6] = NULL;
552+
}
553+
else {
554+
casts[4] = NULL;
555+
}
511556

512557
return casts;
513558
}

stringdtype/stringdtype/src/casts.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
#include "numpy/ndarraytypes.h"
1212

1313
PyArrayMethod_Spec **
14-
get_casts(void);
14+
get_casts(PyArray_DTypeMeta *this_dtype, PyArray_DTypeMeta *other_dtype);
1515

1616
size_t
1717
utf8_char_to_ucs4_code(unsigned char *, Py_UCS4 *);

stringdtype/stringdtype/src/dtype.c

Lines changed: 29 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -523,12 +523,23 @@ StringDType_type PandasStringDType = {
523523
int
524524
init_string_dtype(void)
525525
{
526-
PyArrayMethod_Spec **casts = get_casts();
526+
PyObject *pandas_mod = PyImport_ImportModule("pandas");
527+
528+
if (pandas_mod == NULL) {
529+
// clear ImportError
530+
PyErr_Clear();
531+
}
532+
else {
533+
PANDAS_AVAILABLE = 1;
534+
}
535+
536+
PyArrayMethod_Spec **StringDType_casts =
537+
get_casts((PyArray_DTypeMeta *)&StringDType, NULL);
527538

528539
PyArrayDTypeMeta_Spec StringDType_DTypeSpec = {
529540
.typeobj = StringScalar_Type,
530541
.slots = StringDType_Slots,
531-
.casts = casts,
542+
.casts = StringDType_casts,
532543
};
533544

534545
/* Loaded dynamically, so may need to be set here: */
@@ -559,20 +570,27 @@ init_string_dtype(void)
559570

560571
StringDType.base.singleton = singleton;
561572

573+
for (int i = 0; StringDType_casts[i] != NULL; i++) {
574+
free(StringDType_casts[i]->dtypes);
575+
free(StringDType_casts[i]);
576+
}
577+
562578
/* and once again for PandasStringDType */
563579

564-
PyObject *mod = PyImport_ImportModule("pandas");
580+
if (PANDAS_AVAILABLE) {
581+
PyArrayMethod_Spec **PandasStringDType_casts =
582+
get_casts((PyArray_DTypeMeta *)&PandasStringDType,
583+
(PyArray_DTypeMeta *)&StringDType);
565584

566-
if (mod != NULL) {
567585
PyArrayDTypeMeta_Spec PandasStringDType_DTypeSpec = {
568586
.typeobj = PandasStringScalar_Type,
569587
.slots = StringDType_Slots,
570-
.casts = casts,
588+
.casts = PandasStringDType_casts,
571589
};
572590

573-
PyObject *pandas_na_obj = PyObject_GetAttrString(mod, "NA");
591+
PyObject *pandas_na_obj = PyObject_GetAttrString(pandas_mod, "NA");
574592

575-
Py_DECREF(mod);
593+
Py_DECREF(pandas_mod);
576594

577595
if (pandas_na_obj == NULL) {
578596
return -1;
@@ -605,15 +623,11 @@ init_string_dtype(void)
605623
}
606624

607625
PandasStringDType.base.singleton = singleton;
608-
PANDAS_AVAILABLE = 1;
609-
}
610-
else {
611-
PyErr_Clear();
612-
}
613626

614-
for (int i = 0; casts[i] != NULL; i++) {
615-
free(casts[i]->dtypes);
616-
free(casts[i]);
627+
for (int i = 0; PandasStringDType_casts[i] != NULL; i++) {
628+
free(PandasStringDType_casts[i]->dtypes);
629+
free(PandasStringDType_casts[i]);
630+
}
617631
}
618632

619633
return 0;

stringdtype/tests/test_stringdtype.py

Lines changed: 28 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -31,25 +31,18 @@ def dtype(request):
3131
if request.param == "StringDType":
3232
return StringDType()
3333
elif request.param == "PandasStringDType":
34-
try:
35-
from stringdtype import PandasStringDType
34+
pytest.importorskip("pandas")
35+
from stringdtype import PandasStringDType
3636

37-
return PandasStringDType()
38-
except ImportError:
39-
pytest.skip("cannot import pandas")
37+
return PandasStringDType()
4038

4139

4240
@pytest.fixture
4341
def scalar(dtype):
4442
if dtype == StringDType():
4543
return StringScalar
46-
try:
47-
from stringdtype import PandasStringDType
48-
49-
del PandasStringDType
44+
else:
5045
return PandasStringScalar
51-
except ImportError:
52-
pytest.skip("cannot import pandas")
5346

5447

5548
def test_scalar_creation(scalar):
@@ -77,6 +70,8 @@ def test_array_creation_utf8(dtype, data):
7770

7871

7972
def test_array_creation_scalars(string_list, scalar, dtype):
73+
if not issubclass(scalar, dtype.type):
74+
pytest.skip()
8075
arr = np.array([scalar(s) for s in string_list])
8176
assert repr(arr) == repr(np.array(string_list, dtype=dtype))
8277

@@ -383,3 +378,25 @@ def test_create_with_na(dtype, na_val):
383378
== f"array(['hello', {dtype.na_object}, 'world'], dtype={dtype})"
384379
)
385380
assert arr[1] is dtype.na_object
381+
382+
383+
def test_pandas_to_numpy_cast(string_list):
384+
pytest.importorskip("pandas")
385+
386+
from stringdtype import PandasStringDType
387+
388+
sarr = np.array(string_list, dtype=StringDType())
389+
390+
parr = sarr.astype(PandasStringDType())
391+
392+
np.testing.assert_array_equal(
393+
parr, np.array(string_list, dtype=PandasStringDType())
394+
)
395+
np.testing.assert_array_equal(sarr, parr.astype(StringDType()))
396+
397+
# check that NA converts correctly too
398+
sarr[1] = StringDType.na_object
399+
parr = sarr.astype(PandasStringDType())
400+
401+
assert parr[1] is PandasStringDType.na_object
402+
assert parr.astype(StringDType())[1] is StringDType.na_object

0 commit comments

Comments
 (0)