Skip to content

Commit e72a3c1

Browse files
committed
add round-trip casts for datetime64
1 parent 015c473 commit e72a3c1

File tree

2 files changed

+176
-5
lines changed

2 files changed

+176
-5
lines changed

stringdtype/stringdtype/src/casts.c

Lines changed: 158 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -907,17 +907,152 @@ STRING_TO_FLOAT_RESOLVE_DESCRIPTORS(float16, HALF)
907907
STRING_TO_FLOAT_CAST(float16, f16, npy_half_isinf, npy_double_to_half)
908908
FLOAT_TO_STRING_CAST(float16, f16, npy_half_to_double)
909909

910+
// string to datetime
911+
912+
static NPY_CASTING
913+
string_to_datetime_resolve_descriptors(
914+
PyObject *NPY_UNUSED(self), PyArray_DTypeMeta *NPY_UNUSED(dtypes[2]),
915+
PyArray_Descr *given_descrs[2], PyArray_Descr *loop_descrs[2],
916+
npy_intp *NPY_UNUSED(view_offset))
917+
{
918+
if (given_descrs[1] == NULL) {
919+
PyErr_SetString(PyExc_TypeError,
920+
"Casting from StringDType to datetimes without a unit "
921+
"is not currently supported");
922+
return (NPY_CASTING)-1;
923+
}
924+
else {
925+
Py_INCREF(given_descrs[1]);
926+
loop_descrs[1] = given_descrs[1];
927+
}
928+
929+
Py_INCREF(given_descrs[0]);
930+
loop_descrs[0] = given_descrs[0];
931+
932+
return NPY_UNSAFE_CASTING;
933+
}
934+
935+
static int
936+
string_to_datetime(PyArrayMethod_Context *context, char *const data[],
937+
npy_intp const dimensions[], npy_intp const strides[],
938+
NpyAuxData *NPY_UNUSED(auxdata))
939+
{
940+
npy_intp N = dimensions[0];
941+
char *in = data[0];
942+
npy_datetime *out = (npy_datetime *)data[1];
943+
944+
npy_intp in_stride = strides[0];
945+
npy_intp out_stride = strides[1] / sizeof(npy_datetime);
946+
947+
ss *s = NULL;
948+
npy_datetimestruct dts;
949+
NPY_DATETIMEUNIT in_unit = -1;
950+
PyArray_DatetimeMetaData in_meta = {0, 1};
951+
npy_bool out_special;
952+
953+
PyArray_Descr *dt_descr = context->descriptors[1];
954+
PyArray_DatetimeMetaData *dt_meta =
955+
&(((PyArray_DatetimeDTypeMetaData *)dt_descr->c_metadata)->meta);
956+
957+
while (N--) {
958+
s = (ss *)in;
959+
if (ss_isnull(s)) {
960+
*out = NPY_DATETIME_NAT;
961+
}
962+
if (NpyDatetime_ParseISO8601Datetime(
963+
(const char *)s->buf, s->len, in_unit, NPY_UNSAFE_CASTING,
964+
&dts, &in_meta.base, &out_special) < 0) {
965+
return -1;
966+
}
967+
if (NpyDatetime_ConvertDatetimeStructToDatetime64(dt_meta, &dts, out) <
968+
0) {
969+
return -1;
970+
}
971+
972+
in += in_stride;
973+
out += out_stride;
974+
}
975+
976+
return 0;
977+
}
978+
979+
static PyType_Slot s2dt_slots[] = {
980+
{NPY_METH_resolve_descriptors,
981+
&string_to_datetime_resolve_descriptors},
982+
{NPY_METH_strided_loop, &string_to_datetime},
983+
{0, NULL}};
984+
985+
// Name given to get_cast_spec() for the StringDType -> datetime cast.
static char *s2dt_name = "cast_StringDType_to_Datetime";
986+
987+
// datetime to string
988+
989+
static int
990+
datetime_to_string(PyArrayMethod_Context *context, char *const data[],
991+
npy_intp const dimensions[], npy_intp const strides[],
992+
NpyAuxData *NPY_UNUSED(auxdata))
993+
{
994+
npy_intp N = dimensions[0];
995+
npy_datetime *in = (npy_datetime *)data[0];
996+
char *out = data[1];
997+
998+
npy_intp in_stride = strides[0] / sizeof(npy_datetime);
999+
npy_intp out_stride = strides[1];
1000+
1001+
npy_datetimestruct dts;
1002+
PyArray_Descr *dt_descr = context->descriptors[0];
1003+
PyArray_DatetimeMetaData *dt_meta =
1004+
&(((PyArray_DatetimeDTypeMetaData *)dt_descr->c_metadata)->meta);
1005+
// buffer passed to numpy to build datetime string
1006+
char datetime_buf[NPY_DATETIME_MAX_ISO8601_STRLEN];
1007+
1008+
while (N--) {
1009+
ss *out_ss = (ss *)out;
1010+
ssfree(out_ss);
1011+
if (*in == NPY_DATETIME_NAT) {
1012+
/* convert to NA */
1013+
out_ss = NULL;
1014+
}
1015+
if (NpyDatetime_ConvertDatetime64ToDatetimeStruct(dt_meta, *in, &dts) <
1016+
0) {
1017+
return -1;
1018+
}
1019+
1020+
// zero out buffer
1021+
memset(datetime_buf, 0, NPY_DATETIME_MAX_ISO8601_STRLEN);
1022+
1023+
if (NpyDatetime_MakeISO8601Datetime(
1024+
&dts, datetime_buf, NPY_DATETIME_MAX_ISO8601_STRLEN, 0, 0,
1025+
dt_meta->base, -1, NPY_UNSAFE_CASTING) < 0) {
1026+
return -1;
1027+
}
1028+
1029+
if (ssnewlen(datetime_buf, strlen(datetime_buf), out_ss) < 0) {
1030+
PyErr_SetString(PyExc_MemoryError, "ssnewlen failed");
1031+
return -1;
1032+
}
1033+
1034+
in += in_stride;
1035+
out += out_stride;
1036+
}
1037+
1038+
return 0;
1039+
}
1040+
1041+
static PyType_Slot dt2s_slots[] = {
1042+
{NPY_METH_resolve_descriptors,
1043+
&any_to_string_UNSAFE_resolve_descriptors},
1044+
{NPY_METH_strided_loop, &datetime_to_string},
1045+
{0, NULL}};
1046+
1047+
// Name given to get_cast_spec() for the datetime -> StringDType cast.
static char *dt2s_name = "cast_Datetime_to_StringDType";
1048+
9101049
// TODO: longdouble
9111050
// punting on this one because numpy's C routines for handling
9121051
// longdouble are not public (specifically NumPyOS_ascii_strtold)
9131052
// also this type is kinda niche and is not needed by pandas
9141053
//
9151054
// cfloat, cdouble, and clongdouble
9161055
// not hard to do in principle but not needed by pandas.
917-
//
918-
// datetime
919-
// numpy's utilities for parsing a string into a datetime
920-
// are not public (specifically parse_iso_8601_datetime).
9211056

9221057
PyArrayMethod_Spec *
9231058
get_cast_spec(const char *name, NPY_CASTING casting,
@@ -961,7 +1096,7 @@ get_casts()
9611096
get_cast_spec(t2t_name, NPY_NO_CASTING,
9621097
NPY_METH_SUPPORTS_UNALIGNED, t2t_dtypes, s2s_slots);
9631098

964-
int num_casts = 27;
1099+
int num_casts = 29;
9651100

9661101
#if NPY_SIZEOF_BYTE == NPY_SIZEOF_SHORT
9671102
num_casts += 4;
@@ -1033,6 +1168,22 @@ get_casts()
10331168
DTYPES_AND_CAST_SPEC(f32, Float)
10341169
DTYPES_AND_CAST_SPEC(f16, Half)
10351170

1171+
PyArray_DTypeMeta **s2dt_dtypes = get_dtypes(
1172+
(PyArray_DTypeMeta *)&StringDType, &PyArray_DatetimeDType);
1173+
1174+
PyArrayMethod_Spec *StringToDatetimeCastSpec = get_cast_spec(
1175+
s2dt_name, NPY_UNSAFE_CASTING,
1176+
NPY_METH_NO_FLOATINGPOINT_ERRORS | NPY_METH_REQUIRES_PYAPI,
1177+
s2dt_dtypes, s2dt_slots);
1178+
1179+
PyArray_DTypeMeta **dt2s_dtypes = get_dtypes(
1180+
&PyArray_DatetimeDType, (PyArray_DTypeMeta *)&StringDType);
1181+
1182+
PyArrayMethod_Spec *DatetimeToStringCastSpec = get_cast_spec(
1183+
dt2s_name, NPY_UNSAFE_CASTING,
1184+
NPY_METH_NO_FLOATINGPOINT_ERRORS | NPY_METH_REQUIRES_PYAPI,
1185+
dt2s_dtypes, dt2s_slots);
1186+
10361187
PyArrayMethod_Spec **casts =
10371188
malloc((num_casts + 1) * sizeof(PyArrayMethod_Spec *));
10381189

@@ -1089,6 +1240,8 @@ get_casts()
10891240
casts[cast_i++] = FloatToStringCastSpec;
10901241
casts[cast_i++] = StringToHalfCastSpec;
10911242
casts[cast_i++] = HalfToStringCastSpec;
1243+
casts[cast_i++] = StringToDatetimeCastSpec;
1244+
casts[cast_i++] = DatetimeToStringCastSpec;
10921245
casts[cast_i++] = NULL;
10931246

10941247
assert(casts[num_casts] == NULL);

stringdtype/tests/test_stringdtype.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -515,3 +515,21 @@ def test_create_with_na(dtype):
515515
== f"array(['hello', {dtype.na_object}, 'world'], dtype={dtype})"
516516
)
517517
assert arr[1] is dtype.na_object
518+
519+
520+
def test_datetime_cast(dtype):
521+
a = np.array(
522+
[
523+
np.datetime64("1923-04-14T12:43:12"),
524+
np.datetime64("1994-06-21T14:43:15"),
525+
np.datetime64("2001-10-15T04:10:32"),
526+
np.datetime64("1995-11-25T16:02:16"),
527+
np.datetime64("2005-01-04T03:14:12"),
528+
np.datetime64("2041-12-03T14:05:03"),
529+
]
530+
)
531+
sa = a.astype(dtype)
532+
ra = sa.astype(a.dtype)
533+
534+
np.testing.assert_array_equal(a, ra)
535+
np.testing.assert_array_equal(sa, a.astype("U"))

0 commit comments

Comments
 (0)