Skip to content

Commit db6dcb2

Browse files
committed
refactor stringdtype cast setup to reduce boilerplate
1 parent 0bd2328 commit db6dcb2

File tree

1 file changed

+54
-41
lines changed

1 file changed

+54
-41
lines changed

stringdtype/stringdtype/src/casts.c

Lines changed: 54 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ gil_error(PyObject *type, const char *msg)
1212
PyGILState_Release(gstate);
1313
}
1414

15+
// string to string
16+
1517
static NPY_CASTING
1618
string_to_string_resolve_descriptors(PyObject *NPY_UNUSED(self),
1719
PyArray_DTypeMeta *NPY_UNUSED(dtypes[2]),
@@ -67,23 +69,17 @@ string_to_string(PyArrayMethod_Context *NPY_UNUSED(context),
6769
return 0;
6870
}
6971

70-
static PyArray_DTypeMeta *s2s_dtypes[2] = {NULL, NULL};
71-
7272
static PyType_Slot s2s_slots[] = {
7373
{NPY_METH_resolve_descriptors, &string_to_string_resolve_descriptors},
7474
{NPY_METH_strided_loop, &string_to_string},
7575
{NPY_METH_unaligned_strided_loop, &string_to_string},
7676
{0, NULL}};
7777

78-
PyArrayMethod_Spec StringToStringCastSpec = {
79-
.name = "cast_StringDType_to_StringDType",
80-
.nin = 1,
81-
.nout = 1,
82-
.casting = NPY_NO_CASTING,
83-
.flags = NPY_METH_SUPPORTS_UNALIGNED,
84-
.dtypes = s2s_dtypes,
85-
.slots = s2s_slots,
86-
};
78+
static char *s2s_name = "cast_StringDType_to_StringDType";
79+
80+
static PyArray_DTypeMeta *s2s_dtypes[] = {NULL, NULL};
81+
82+
// unicode to string
8783

8884
static NPY_CASTING
8985
unicode_to_string_resolve_descriptors(PyObject *NPY_UNUSED(self),
@@ -261,6 +257,8 @@ static PyType_Slot u2s_slots[] = {
261257

262258
static char *u2s_name = "cast_Unicode_to_StringDType";
263259

260+
// string to unicode
261+
264262
static NPY_CASTING
265263
string_to_unicode_resolve_descriptors(PyObject *NPY_UNUSED(self),
266264
PyArray_DTypeMeta *NPY_UNUSED(dtypes[2]),
@@ -376,41 +374,56 @@ static PyType_Slot s2u_slots[] = {
376374

377375
static char *s2u_name = "cast_StringDType_to_Unicode";
378376

377+
PyArrayMethod_Spec *
378+
get_cast_spec(const char *name, NPY_CASTING casting,
379+
NPY_ARRAYMETHOD_FLAGS flags, PyArray_DTypeMeta **dtypes,
380+
PyType_Slot *slots)
381+
{
382+
PyArrayMethod_Spec *ret = malloc(sizeof(PyArrayMethod_Spec));
383+
384+
ret->name = name;
385+
ret->nin = 1;
386+
ret->nout = 1;
387+
ret->casting = casting;
388+
ret->flags = flags;
389+
ret->dtypes = dtypes;
390+
ret->slots = slots;
391+
392+
return ret;
393+
}
394+
395+
PyArray_DTypeMeta **
396+
get_dtypes(PyArray_DTypeMeta *dt1, PyArray_DTypeMeta *dt2)
397+
{
398+
PyArray_DTypeMeta **ret = malloc(2 * sizeof(PyArray_DTypeMeta *));
399+
400+
ret[0] = dt1;
401+
ret[1] = dt2;
402+
403+
return ret;
404+
}
405+
379406
PyArrayMethod_Spec **
380407
get_casts(void)
381408
{
382-
PyArray_DTypeMeta **u2s_dtypes = malloc(2 * sizeof(PyArray_DTypeMeta *));
383-
u2s_dtypes[0] = &PyArray_UnicodeDType;
384-
u2s_dtypes[1] = NULL;
385-
386-
PyArrayMethod_Spec *UnicodeToStringCastSpec =
387-
malloc(sizeof(PyArrayMethod_Spec));
388-
389-
UnicodeToStringCastSpec->name = u2s_name;
390-
UnicodeToStringCastSpec->nin = 1;
391-
UnicodeToStringCastSpec->nout = 1;
392-
UnicodeToStringCastSpec->casting = NPY_SAFE_CASTING;
393-
UnicodeToStringCastSpec->flags = NPY_METH_NO_FLOATINGPOINT_ERRORS;
394-
UnicodeToStringCastSpec->dtypes = u2s_dtypes;
395-
UnicodeToStringCastSpec->slots = u2s_slots;
396-
397-
PyArray_DTypeMeta **s2u_dtypes = malloc(2 * sizeof(PyArray_DTypeMeta *));
398-
s2u_dtypes[0] = NULL;
399-
s2u_dtypes[1] = &PyArray_UnicodeDType;
400-
401-
PyArrayMethod_Spec *StringToUnicodeCastSpec =
402-
malloc(sizeof(PyArrayMethod_Spec));
403-
404-
StringToUnicodeCastSpec->name = s2u_name;
405-
StringToUnicodeCastSpec->nin = 1;
406-
StringToUnicodeCastSpec->nout = 1;
407-
StringToUnicodeCastSpec->casting = NPY_SAFE_CASTING;
408-
StringToUnicodeCastSpec->flags = NPY_METH_NO_FLOATINGPOINT_ERRORS;
409-
StringToUnicodeCastSpec->dtypes = s2u_dtypes;
410-
StringToUnicodeCastSpec->slots = s2u_slots;
409+
PyArrayMethod_Spec *StringToStringCastSpec =
410+
get_cast_spec(s2s_name, NPY_NO_CASTING,
411+
NPY_METH_SUPPORTS_UNALIGNED, s2s_dtypes, s2s_slots);
412+
413+
PyArray_DTypeMeta **u2s_dtypes = get_dtypes(&PyArray_UnicodeDType, NULL);
414+
415+
PyArrayMethod_Spec *UnicodeToStringCastSpec = get_cast_spec(
416+
u2s_name, NPY_SAFE_CASTING, NPY_METH_NO_FLOATINGPOINT_ERRORS,
417+
u2s_dtypes, u2s_slots);
418+
419+
PyArray_DTypeMeta **s2u_dtypes = get_dtypes(NULL, &PyArray_UnicodeDType);
420+
421+
PyArrayMethod_Spec *StringToUnicodeCastSpec = get_cast_spec(
422+
s2u_name, NPY_SAFE_CASTING, NPY_METH_NO_FLOATINGPOINT_ERRORS,
423+
s2u_dtypes, s2u_slots);
411424

412425
PyArrayMethod_Spec **casts = malloc(4 * sizeof(PyArrayMethod_Spec *));
413-
casts[0] = &StringToStringCastSpec;
426+
casts[0] = StringToStringCastSpec;
414427
casts[1] = UnicodeToStringCastSpec;
415428
casts[2] = StringToUnicodeCastSpec;
416429
casts[3] = NULL;

0 commit comments

Comments
 (0)