Skip to content

Commit a1c811f

Browse files
authored
Merge pull request #79 from ngoldbaum/datetime-cast
Add a roundtrip cast for datetimes
2 parents 676c2fd + afa076d commit a1c811f

File tree

6 files changed

+199
-22
lines changed

6 files changed

+199
-22
lines changed

stringdtype/stringdtype/src/casts.c

Lines changed: 160 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -907,17 +907,154 @@ STRING_TO_FLOAT_RESOLVE_DESCRIPTORS(float16, HALF)
907907
STRING_TO_FLOAT_CAST(float16, f16, npy_half_isinf, npy_double_to_half)
908908
FLOAT_TO_STRING_CAST(float16, f16, npy_half_to_double)
909909

910+
// string to datetime
911+
912+
static NPY_CASTING
913+
string_to_datetime_resolve_descriptors(
914+
PyObject *NPY_UNUSED(self), PyArray_DTypeMeta *NPY_UNUSED(dtypes[2]),
915+
PyArray_Descr *given_descrs[2], PyArray_Descr *loop_descrs[2],
916+
npy_intp *NPY_UNUSED(view_offset))
917+
{
918+
if (given_descrs[1] == NULL) {
919+
PyErr_SetString(PyExc_TypeError,
920+
"Casting from StringDType to datetimes without a unit "
921+
"is not currently supported");
922+
return (NPY_CASTING)-1;
923+
}
924+
else {
925+
Py_INCREF(given_descrs[1]);
926+
loop_descrs[1] = given_descrs[1];
927+
}
928+
929+
Py_INCREF(given_descrs[0]);
930+
loop_descrs[0] = given_descrs[0];
931+
932+
return NPY_UNSAFE_CASTING;
933+
}
934+
935+
static int
936+
string_to_datetime(PyArrayMethod_Context *context, char *const data[],
937+
npy_intp const dimensions[], npy_intp const strides[],
938+
NpyAuxData *NPY_UNUSED(auxdata))
939+
{
940+
npy_intp N = dimensions[0];
941+
char *in = data[0];
942+
npy_datetime *out = (npy_datetime *)data[1];
943+
944+
npy_intp in_stride = strides[0];
945+
npy_intp out_stride = strides[1] / sizeof(npy_datetime);
946+
947+
ss *s = NULL;
948+
npy_datetimestruct dts;
949+
NPY_DATETIMEUNIT in_unit = -1;
950+
PyArray_DatetimeMetaData in_meta = {0, 1};
951+
npy_bool out_special;
952+
953+
PyArray_Descr *dt_descr = context->descriptors[1];
954+
PyArray_DatetimeMetaData *dt_meta =
955+
&(((PyArray_DatetimeDTypeMetaData *)dt_descr->c_metadata)->meta);
956+
957+
while (N--) {
958+
s = (ss *)in;
959+
if (ss_isnull(s)) {
960+
*out = NPY_DATETIME_NAT;
961+
}
962+
if (NpyDatetime_ParseISO8601Datetime(
963+
(const char *)s->buf, s->len, in_unit, NPY_UNSAFE_CASTING,
964+
&dts, &in_meta.base, &out_special) < 0) {
965+
return -1;
966+
}
967+
if (NpyDatetime_ConvertDatetimeStructToDatetime64(dt_meta, &dts, out) <
968+
0) {
969+
return -1;
970+
}
971+
972+
in += in_stride;
973+
out += out_stride;
974+
}
975+
976+
return 0;
977+
}
978+
979+
static PyType_Slot s2dt_slots[] = {
980+
{NPY_METH_resolve_descriptors,
981+
&string_to_datetime_resolve_descriptors},
982+
{NPY_METH_strided_loop, &string_to_datetime},
983+
{0, NULL}};
984+
985+
static char *s2dt_name = "cast_StringDType_to_Datetime";
986+
987+
// datetime to string
988+
989+
static int
990+
datetime_to_string(PyArrayMethod_Context *context, char *const data[],
991+
npy_intp const dimensions[], npy_intp const strides[],
992+
NpyAuxData *NPY_UNUSED(auxdata))
993+
{
994+
npy_intp N = dimensions[0];
995+
npy_datetime *in = (npy_datetime *)data[0];
996+
char *out = data[1];
997+
998+
npy_intp in_stride = strides[0] / sizeof(npy_datetime);
999+
npy_intp out_stride = strides[1];
1000+
1001+
npy_datetimestruct dts;
1002+
PyArray_Descr *dt_descr = context->descriptors[0];
1003+
PyArray_DatetimeMetaData *dt_meta =
1004+
&(((PyArray_DatetimeDTypeMetaData *)dt_descr->c_metadata)->meta);
1005+
// buffer passed to numpy to build datetime string
1006+
char datetime_buf[NPY_DATETIME_MAX_ISO8601_STRLEN];
1007+
1008+
while (N--) {
1009+
ss *out_ss = (ss *)out;
1010+
ssfree(out_ss);
1011+
if (*in == NPY_DATETIME_NAT) {
1012+
/* convert to NA */
1013+
out_ss = NULL;
1014+
}
1015+
else {
1016+
if (NpyDatetime_ConvertDatetime64ToDatetimeStruct(dt_meta, *in,
1017+
&dts) < 0) {
1018+
return -1;
1019+
}
1020+
1021+
// zero out buffer
1022+
memset(datetime_buf, 0, NPY_DATETIME_MAX_ISO8601_STRLEN);
1023+
1024+
if (NpyDatetime_MakeISO8601Datetime(
1025+
&dts, datetime_buf, NPY_DATETIME_MAX_ISO8601_STRLEN, 0,
1026+
0, dt_meta->base, -1, NPY_UNSAFE_CASTING) < 0) {
1027+
return -1;
1028+
}
1029+
1030+
if (ssnewlen(datetime_buf, strlen(datetime_buf), out_ss) < 0) {
1031+
PyErr_SetString(PyExc_MemoryError, "ssnewlen failed");
1032+
return -1;
1033+
}
1034+
}
1035+
1036+
in += in_stride;
1037+
out += out_stride;
1038+
}
1039+
1040+
return 0;
1041+
}
1042+
1043+
static PyType_Slot dt2s_slots[] = {
1044+
{NPY_METH_resolve_descriptors,
1045+
&any_to_string_UNSAFE_resolve_descriptors},
1046+
{NPY_METH_strided_loop, &datetime_to_string},
1047+
{0, NULL}};
1048+
1049+
static char *dt2s_name = "cast_Datetime_to_StringDType";
1050+
9101051
// TODO: longdouble
9111052
// punting on this one because numpy's C routines for handling
9121053
// longdouble are not public (specifically NumPyOS_ascii_strtold)
9131054
// also this type is kinda niche and is not needed by pandas
9141055
//
9151056
// cfloat, cdouble, and clongdouble
9161057
// not hard to do in principle but not needed by pandas.
917-
//
918-
// datetime
919-
// numpy's utilities for parsing a string into a datetime
920-
// are not public (specifically parse_iso_8601_datetime).
9211058

9221059
PyArrayMethod_Spec *
9231060
get_cast_spec(const char *name, NPY_CASTING casting,
@@ -961,7 +1098,7 @@ get_casts()
9611098
get_cast_spec(t2t_name, NPY_NO_CASTING,
9621099
NPY_METH_SUPPORTS_UNALIGNED, t2t_dtypes, s2s_slots);
9631100

964-
int num_casts = 27;
1101+
int num_casts = 29;
9651102

9661103
#if NPY_SIZEOF_BYTE == NPY_SIZEOF_SHORT
9671104
num_casts += 4;
@@ -1033,6 +1170,22 @@ get_casts()
10331170
DTYPES_AND_CAST_SPEC(f32, Float)
10341171
DTYPES_AND_CAST_SPEC(f16, Half)
10351172

1173+
PyArray_DTypeMeta **s2dt_dtypes = get_dtypes(
1174+
(PyArray_DTypeMeta *)&StringDType, &PyArray_DatetimeDType);
1175+
1176+
PyArrayMethod_Spec *StringToDatetimeCastSpec = get_cast_spec(
1177+
s2dt_name, NPY_UNSAFE_CASTING,
1178+
NPY_METH_NO_FLOATINGPOINT_ERRORS | NPY_METH_REQUIRES_PYAPI,
1179+
s2dt_dtypes, s2dt_slots);
1180+
1181+
PyArray_DTypeMeta **dt2s_dtypes = get_dtypes(
1182+
&PyArray_DatetimeDType, (PyArray_DTypeMeta *)&StringDType);
1183+
1184+
PyArrayMethod_Spec *DatetimeToStringCastSpec = get_cast_spec(
1185+
dt2s_name, NPY_UNSAFE_CASTING,
1186+
NPY_METH_NO_FLOATINGPOINT_ERRORS | NPY_METH_REQUIRES_PYAPI,
1187+
dt2s_dtypes, dt2s_slots);
1188+
10361189
PyArrayMethod_Spec **casts =
10371190
malloc((num_casts + 1) * sizeof(PyArrayMethod_Spec *));
10381191

@@ -1089,6 +1242,8 @@ get_casts()
10891242
casts[cast_i++] = FloatToStringCastSpec;
10901243
casts[cast_i++] = StringToHalfCastSpec;
10911244
casts[cast_i++] = HalfToStringCastSpec;
1245+
casts[cast_i++] = StringToDatetimeCastSpec;
1246+
casts[cast_i++] = DatetimeToStringCastSpec;
10921247
casts[cast_i++] = NULL;
10931248

10941249
assert(casts[num_casts] == NULL);

stringdtype/stringdtype/src/casts.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
#ifndef _NPY_CASTS_H
22
#define _NPY_CASTS_H
33

4+
// needed for Py_UCS4
45
#include <Python.h>
56

7+
// need these defines and includes for PyArrayMethod_Spec
68
#define PY_ARRAY_UNIQUE_SYMBOL stringdtype_ARRAY_API
7-
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
9+
#define NPY_NO_DEPRECATED_API NPY_2_0_API_VERSION
10+
#define NPY_TARGET_VERSION NPY_2_0_API_VERSION
811
#define NO_IMPORT_ARRAY
9-
#include "numpy/arrayobject.h"
1012
#include "numpy/experimental_dtype_api.h"
11-
#include "numpy/halffloat.h"
12-
#include "numpy/ndarraytypes.h"
1313

1414
PyArrayMethod_Spec **
1515
get_casts();

stringdtype/stringdtype/src/dtype.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,15 @@
77
// clang-format on
88

99
#define PY_ARRAY_UNIQUE_SYMBOL stringdtype_ARRAY_API
10-
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
10+
#define NPY_NO_DEPRECATED_API NPY_2_0_API_VERSION
11+
#define NPY_TARGET_VERSION NPY_2_0_API_VERSION
1112
#define NO_IMPORT_ARRAY
1213
#include "numpy/arrayobject.h"
1314
#include "numpy/experimental_dtype_api.h"
15+
#include "numpy/halffloat.h"
1416
#include "numpy/ndarraytypes.h"
1517
#include "numpy/npy_math.h"
18+
#include "numpy/ufuncobject.h"
1619

1720
typedef struct {
1821
PyArray_Descr base;

stringdtype/stringdtype/src/main.c

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
#include <Python.h>
22

33
#define PY_ARRAY_UNIQUE_SYMBOL stringdtype_ARRAY_API
4-
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
4+
#define NPY_NO_DEPRECATED_API NPY_2_0_API_VERSION
5+
#define NPY_TARGET_VERSION NPY_2_0_API_VERSION
56
#include "numpy/arrayobject.h"
67
#include "numpy/experimental_dtype_api.h"
78

@@ -88,9 +89,8 @@ static struct PyModuleDef moduledef = {
8889
PyMODINIT_FUNC
8990
PyInit__main(void)
9091
{
91-
if (_import_array() < 0) {
92-
return NULL;
93-
}
92+
import_array();
93+
9494
if (import_experimental_dtype_api(13) < 0) {
9595
return NULL;
9696
}

stringdtype/stringdtype/src/umath.c

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,10 @@
11
#include <Python.h>
22

3-
#define PY_ARRAY_UNIQUE_SYMBOL stringdtype_ARRAY_API
4-
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
5-
#define NO_IMPORT_ARRAY
6-
#include "numpy/arrayobject.h"
7-
#include "numpy/experimental_dtype_api.h"
8-
#include "numpy/ndarraytypes.h"
9-
#include "numpy/ufuncobject.h"
3+
#include "umath.h"
104

115
#include "dtype.h"
126
#include "static_string.h"
137
#include "string.h"
14-
#include "umath.h"
158

169
static NPY_CASTING
1710
multiply_resolve_descriptors(

stringdtype/tests/test_stringdtype.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -515,3 +515,29 @@ def test_create_with_na(dtype):
515515
== f"array(['hello', {dtype.na_object}, 'world'], dtype={dtype})"
516516
)
517517
assert arr[1] is dtype.na_object
518+
519+
520+
def test_datetime_cast(dtype):
521+
a = np.array(
522+
[
523+
np.datetime64("1923-04-14T12:43:12"),
524+
np.datetime64("1994-06-21T14:43:15"),
525+
np.datetime64("2001-10-15T04:10:32"),
526+
np.datetime64("NaT"),
527+
np.datetime64("1995-11-25T16:02:16"),
528+
np.datetime64("2005-01-04T03:14:12"),
529+
np.datetime64("2041-12-03T14:05:03"),
530+
]
531+
)
532+
sa = a.astype(dtype)
533+
assert sa[3] is dtype.na_object
534+
535+
ra = sa.astype(a.dtype)
536+
assert np.isnat(ra[3])
537+
538+
np.testing.assert_array_equal(a, ra)
539+
540+
# don't worry about comparing how NaT is converted
541+
sa = np.delete(sa, 3)
542+
a = np.delete(a, 3)
543+
np.testing.assert_array_equal(sa, a.astype("U"))

0 commit comments

Comments
 (0)