Skip to content

Commit 632c6d1

Browse files
committed
add casts for integer dtypes that don't have a sized alias
1 parent 7af3741 commit 632c6d1

File tree

2 files changed

+222
-26
lines changed

2 files changed

+222
-26
lines changed

stringdtype/stringdtype/src/casts.c

Lines changed: 211 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -689,6 +689,42 @@ INT_TO_STRING(int32, int, i32, long long)
689689
STRING_TO_INT(int64, int, i64, NPY_INT64, lli, npy_longlong)
690690
INT_TO_STRING(int64, int, i64, long long)
691691

692+
#if NPY_SIZEOF_BYTE == NPY_SIZEOF_SHORT
693+
// byte doesn't have a bitsized alias
694+
STRING_TO_INT(byte, int, byte, NPY_BYTE, lli, npy_byte)
695+
INT_TO_STRING(byte, int, byte, long long)
696+
697+
STRING_TO_INT(ubyte, uint, ubyte, NPY_UBYTE, llu, npy_ubyte)
698+
INT_TO_STRING(ubyte, uint, ubyte, unsigned long long)
699+
#endif
700+
701+
#if NPY_SIZEOF_SHORT == NPY_SIZEOF_INT
702+
// short doesn't have a bitsized alias
703+
STRING_TO_INT(short, int, short, NPY_SHORT, lli, npy_short)
704+
INT_TO_STRING(short, int, short, long long)
705+
706+
STRING_TO_INT(ushort, uint, ushort, NPY_USHORT, llu, npy_ushort)
707+
INT_TO_STRING(ushort, uint, ushort, unsigned long long)
708+
#endif
709+
710+
#if NPY_SIZEOF_INT == NPY_SIZEOF_LONG
711+
// int doesn't have a bitsized alias
712+
STRING_TO_INT(int, int, int, NPY_INT, lli, npy_int)
713+
INT_TO_STRING(int, int, int, long long)
714+
715+
STRING_TO_INT(uint, uint, uint, NPY_UINT, llu, npy_uint)
716+
INT_TO_STRING(uint, uint, uint, unsigned long long)
717+
#endif
718+
719+
#if NPY_SIZEOF_LONGLONG == NPY_SIZEOF_LONG
720+
// long long doesn't have a bitsized alias
721+
STRING_TO_INT(longlong, int, longlong, NPY_LONGLONG, lli, npy_longlong)
722+
INT_TO_STRING(longlong, int, longlong, long long)
723+
724+
STRING_TO_INT(ulonglong, uint, ulonglong, NPY_ULONGLONG, llu, npy_ulonglong)
725+
INT_TO_STRING(ulonglong, uint, ulonglong, unsigned long long)
726+
#endif
727+
692728
STRING_TO_INT(uint8, uint, ui8, NPY_UINT8, llu, npy_ulonglong)
693729
INT_TO_STRING(uint8, uint, ui8, unsigned long long)
694730

@@ -755,6 +791,19 @@ get_casts(PyArray_DTypeMeta *this, PyArray_DTypeMeta *other)
755791

756792
int num_casts = 21;
757793

794+
#if NPY_SIZEOF_BYTE == NPY_SIZEOF_SHORT
795+
num_casts += 4;
796+
#endif
797+
#if NPY_SIZEOF_SHORT == NPY_SIZEOF_INT
798+
num_casts += 4;
799+
#endif
800+
#if NPY_SIZEOF_INT == NPY_SIZEOF_LONG
801+
num_casts += 4;
802+
#endif
803+
#if NPY_SIZEOF_LONGLONG == NPY_SIZEOF_LONG
804+
num_casts += 4;
805+
#endif
806+
758807
if (is_pandas) {
759808
num_casts += 2;
760809

@@ -879,6 +928,116 @@ get_casts(PyArray_DTypeMeta *this, PyArray_DTypeMeta *other)
879928
get_cast_spec(i642s_name, NPY_UNSAFE_CASTING,
880929
NPY_METH_REQUIRES_PYAPI, i642s_dtypes, i642s_slots);
881930

931+
#if NPY_SIZEOF_BYTE == NPY_SIZEOF_SHORT
932+
PyArray_DTypeMeta **s2byte_dtypes = get_dtypes(this, &PyArray_ByteDType);
933+
934+
PyArrayMethod_Spec *StringToByteCastSpec = get_cast_spec(
935+
s2byte_name, NPY_UNSAFE_CASTING, NPY_METH_REQUIRES_PYAPI,
936+
s2byte_dtypes, s2byte_slots);
937+
938+
PyArray_DTypeMeta **byte2s_dtypes = get_dtypes(&PyArray_ByteDType, this);
939+
940+
PyArrayMethod_Spec *ByteToStringCastSpec = get_cast_spec(
941+
byte2s_name, NPY_UNSAFE_CASTING, NPY_METH_REQUIRES_PYAPI,
942+
byte2s_dtypes, byte2s_slots);
943+
944+
PyArray_DTypeMeta **s2ubyte_dtypes = get_dtypes(this, &PyArray_UByteDType);
945+
946+
PyArrayMethod_Spec *StringToUByteCastSpec = get_cast_spec(
947+
s2ubyte_name, NPY_UNSAFE_CASTING, NPY_METH_REQUIRES_PYAPI,
948+
s2ubyte_dtypes, s2ubyte_slots);
949+
950+
PyArray_DTypeMeta **ubyte2s_dtypes = get_dtypes(&PyArray_UByteDType, this);
951+
952+
PyArrayMethod_Spec *UByteToStringCastSpec = get_cast_spec(
953+
ubyte2s_name, NPY_UNSAFE_CASTING, NPY_METH_REQUIRES_PYAPI,
954+
ubyte2s_dtypes, ubyte2s_slots);
955+
#endif
956+
957+
#if NPY_SIZEOF_SHORT == NPY_SIZEOF_INT
958+
PyArray_DTypeMeta **s2short_dtypes = get_dtypes(this, &PyArray_ShortDType);
959+
960+
PyArrayMethod_Spec *StringToShortCastSpec = get_cast_spec(
961+
s2short_name, NPY_UNSAFE_CASTING, NPY_METH_REQUIRES_PYAPI,
962+
s2short_dtypes, s2short_slots);
963+
964+
PyArray_DTypeMeta **short2s_dtypes = get_dtypes(&PyArray_ShortDType, this);
965+
966+
PyArrayMethod_Spec *ShortToStringCastSpec = get_cast_spec(
967+
short2s_name, NPY_UNSAFE_CASTING, NPY_METH_REQUIRES_PYAPI,
968+
short2s_dtypes, short2s_slots);
969+
970+
PyArray_DTypeMeta **s2ushort_dtypes =
971+
get_dtypes(this, &PyArray_UShortDType);
972+
973+
PyArrayMethod_Spec *StringToUShortCastSpec = get_cast_spec(
974+
s2ushort_name, NPY_UNSAFE_CASTING, NPY_METH_REQUIRES_PYAPI,
975+
s2ushort_dtypes, s2ushort_slots);
976+
977+
PyArray_DTypeMeta **ushort2s_dtypes =
978+
get_dtypes(&PyArray_UShortDType, this);
979+
980+
PyArrayMethod_Spec *UShortToStringCastSpec = get_cast_spec(
981+
ushort2s_name, NPY_UNSAFE_CASTING, NPY_METH_REQUIRES_PYAPI,
982+
ushort2s_dtypes, ushort2s_slots);
983+
#endif
984+
985+
#if NPY_SIZEOF_INT == NPY_SIZEOF_LONG
986+
PyArray_DTypeMeta **s2int_dtypes = get_dtypes(this, &PyArray_IntDType);
987+
988+
PyArrayMethod_Spec *StringToIntCastSpec =
989+
get_cast_spec(s2int_name, NPY_UNSAFE_CASTING,
990+
NPY_METH_REQUIRES_PYAPI, s2int_dtypes, s2int_slots);
991+
992+
PyArray_DTypeMeta **int2s_dtypes = get_dtypes(&PyArray_IntDType, this);
993+
994+
PyArrayMethod_Spec *IntToStringCastSpec =
995+
get_cast_spec(int2s_name, NPY_UNSAFE_CASTING,
996+
NPY_METH_REQUIRES_PYAPI, int2s_dtypes, int2s_slots);
997+
998+
PyArray_DTypeMeta **s2uint_dtypes = get_dtypes(this, &PyArray_UIntDType);
999+
1000+
PyArrayMethod_Spec *StringToUIntCastSpec = get_cast_spec(
1001+
s2uint_name, NPY_UNSAFE_CASTING, NPY_METH_REQUIRES_PYAPI,
1002+
s2uint_dtypes, s2uint_slots);
1003+
1004+
PyArray_DTypeMeta **uint2s_dtypes = get_dtypes(&PyArray_UIntDType, this);
1005+
1006+
PyArrayMethod_Spec *UIntToStringCastSpec = get_cast_spec(
1007+
uint2s_name, NPY_UNSAFE_CASTING, NPY_METH_REQUIRES_PYAPI,
1008+
uint2s_dtypes, uint2s_slots);
1009+
#endif
1010+
1011+
#if NPY_SIZEOF_LONGLONG == NPY_SIZEOF_LONG
1012+
PyArray_DTypeMeta **s2longlong_dtypes =
1013+
get_dtypes(this, &PyArray_LongLongDType);
1014+
1015+
PyArrayMethod_Spec *StringToLongLongCastSpec = get_cast_spec(
1016+
s2longlong_name, NPY_UNSAFE_CASTING, NPY_METH_REQUIRES_PYAPI,
1017+
s2longlong_dtypes, s2longlong_slots);
1018+
1019+
PyArray_DTypeMeta **longlong2s_dtypes =
1020+
get_dtypes(&PyArray_LongLongDType, this);
1021+
1022+
PyArrayMethod_Spec *LongLongToStringCastSpec = get_cast_spec(
1023+
longlong2s_name, NPY_UNSAFE_CASTING, NPY_METH_REQUIRES_PYAPI,
1024+
longlong2s_dtypes, longlong2s_slots);
1025+
1026+
PyArray_DTypeMeta **s2ulonglong_dtypes =
1027+
get_dtypes(this, &PyArray_ULongLongDType);
1028+
1029+
PyArrayMethod_Spec *StringToULongLongCastSpec = get_cast_spec(
1030+
s2ulonglong_name, NPY_UNSAFE_CASTING, NPY_METH_REQUIRES_PYAPI,
1031+
s2ulonglong_dtypes, s2ulonglong_slots);
1032+
1033+
PyArray_DTypeMeta **ulonglong2s_dtypes =
1034+
get_dtypes(&PyArray_ULongLongDType, this);
1035+
1036+
PyArrayMethod_Spec *ULongLongToStringCastSpec = get_cast_spec(
1037+
ulonglong2s_name, NPY_UNSAFE_CASTING, NPY_METH_REQUIRES_PYAPI,
1038+
ulonglong2s_dtypes, ulonglong2s_slots);
1039+
#endif
1040+
8821041
PyArray_DTypeMeta **s2ui64_dtypes = get_dtypes(this, &PyArray_UInt64DType);
8831042

8841043
PyArrayMethod_Spec *StringToUInt64CastSpec = get_cast_spec(
@@ -895,37 +1054,64 @@ get_casts(PyArray_DTypeMeta *this, PyArray_DTypeMeta *other)
8951054

8961055
casts = malloc((num_casts + 1) * sizeof(PyArrayMethod_Spec *));
8971056

898-
casts[0] = ThisToThisCastSpec;
899-
casts[1] = UnicodeToStringCastSpec;
900-
casts[2] = StringToUnicodeCastSpec;
901-
casts[3] = StringToBoolCastSpec;
902-
casts[4] = BoolToStringCastSpec;
903-
casts[5] = StringToInt8CastSpec;
904-
casts[6] = Int8ToStringCastSpec;
905-
casts[7] = StringToInt16CastSpec;
906-
casts[8] = Int16ToStringCastSpec;
907-
casts[9] = StringToInt32CastSpec;
908-
casts[10] = Int32ToStringCastSpec;
909-
casts[11] = StringToInt64CastSpec;
910-
casts[12] = Int64ToStringCastSpec;
911-
casts[13] = StringToUInt8CastSpec;
912-
casts[14] = UInt8ToStringCastSpec;
913-
casts[15] = StringToUInt16CastSpec;
914-
casts[16] = UInt16ToStringCastSpec;
915-
casts[17] = StringToUInt32CastSpec;
916-
casts[18] = UInt32ToStringCastSpec;
917-
casts[19] = StringToUInt64CastSpec;
918-
casts[20] = UInt64ToStringCastSpec;
1057+
int cast_i = 0;
1058+
1059+
casts[cast_i++] = ThisToThisCastSpec;
1060+
casts[cast_i++] = UnicodeToStringCastSpec;
1061+
casts[cast_i++] = StringToUnicodeCastSpec;
1062+
casts[cast_i++] = StringToBoolCastSpec;
1063+
casts[cast_i++] = BoolToStringCastSpec;
1064+
casts[cast_i++] = StringToInt8CastSpec;
1065+
casts[cast_i++] = Int8ToStringCastSpec;
1066+
casts[cast_i++] = StringToInt16CastSpec;
1067+
casts[cast_i++] = Int16ToStringCastSpec;
1068+
casts[cast_i++] = StringToInt32CastSpec;
1069+
casts[cast_i++] = Int32ToStringCastSpec;
1070+
casts[cast_i++] = StringToInt64CastSpec;
1071+
casts[cast_i++] = Int64ToStringCastSpec;
1072+
casts[cast_i++] = StringToUInt8CastSpec;
1073+
casts[cast_i++] = UInt8ToStringCastSpec;
1074+
casts[cast_i++] = StringToUInt16CastSpec;
1075+
casts[cast_i++] = UInt16ToStringCastSpec;
1076+
casts[cast_i++] = StringToUInt32CastSpec;
1077+
casts[cast_i++] = UInt32ToStringCastSpec;
1078+
casts[cast_i++] = StringToUInt64CastSpec;
1079+
casts[cast_i++] = UInt64ToStringCastSpec;
1080+
#if NPY_SIZEOF_BYTE == NPY_SIZEOF_SHORT
1081+
casts[cast_i++] = StringToByteCastSpec;
1082+
casts[cast_i++] = ByteToStringCastSpec;
1083+
casts[cast_i++] = StringToUByteCastSpec;
1084+
casts[cast_i++] = UByteToStringCastSpec;
1085+
#endif
1086+
#if NPY_SIZEOF_SHORT == NPY_SIZEOF_INT
1087+
casts[cast_i++] = StringToShortCastSpec;
1088+
casts[cast_i++] = ShortToStringCastSpec;
1089+
casts[cast_i++] = StringToUShortCastSpec;
1090+
casts[cast_i++] = UShortToStringCastSpec;
1091+
#endif
1092+
#if NPY_SIZEOF_INT == NPY_SIZEOF_LONG
1093+
casts[cast_i++] = StringToIntCastSpec;
1094+
casts[cast_i++] = IntToStringCastSpec;
1095+
casts[cast_i++] = StringToUIntCastSpec;
1096+
casts[cast_i++] = UIntToStringCastSpec;
1097+
#endif
1098+
#if NPY_SIZEOF_LONGLONG == NPY_SIZEOF_LONG
1099+
casts[cast_i++] = StringToLongLongCastSpec;
1100+
casts[cast_i++] = LongLongToStringCastSpec;
1101+
casts[cast_i++] = StringToULongLongCastSpec;
1102+
casts[cast_i++] = ULongLongToStringCastSpec;
1103+
#endif
9191104
if (is_pandas) {
920-
casts[21] = ThisToOtherCastSpec;
921-
casts[22] = OtherToThisCastSpec;
922-
casts[23] = NULL;
1105+
casts[cast_i++] = ThisToOtherCastSpec;
1106+
casts[cast_i++] = OtherToThisCastSpec;
1107+
casts[cast_i++] = NULL;
9231108
}
9241109
else {
925-
casts[21] = NULL;
1110+
casts[cast_i++] = NULL;
9261111
}
9271112

9281113
assert(casts[num_casts] == NULL);
1114+
assert(cast_i == num_casts + 1);
9291115

9301116
return casts;
9311117
}

stringdtype/tests/test_stringdtype.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -376,7 +376,7 @@ def test_cast_from_bool(dtype, strings, cast_answer):
376376

377377
@pytest.mark.parametrize("bitsize", [8, 16, 32, 64])
378378
@pytest.mark.parametrize("signed", [True, False])
379-
def test_integer_casts(dtype, bitsize, signed):
379+
def test_sized_integer_casts(dtype, bitsize, signed):
380380
idtype = f"int{bitsize}"
381381
if signed:
382382
inp = [-(2**p - 1) for p in reversed(range(bitsize - 1))]
@@ -398,6 +398,16 @@ def test_integer_casts(dtype, bitsize, signed):
398398
np.array(oob, dtype=dtype).astype(idtype)
399399

400400

401+
@pytest.mark.parametrize("typename", ["byte", "short", "int", "longlong"])
402+
@pytest.mark.parametrize("signed", ["", "u"])
403+
def test_unsized_integer_casts(dtype, typename, signed):
404+
idtype = f"{signed}{typename}"
405+
406+
inp = [1, 2, 3, 4]
407+
ainp = np.array(inp, dtype=idtype)
408+
np.testing.assert_array_equal(ainp, ainp.astype(dtype).astype(idtype))
409+
410+
401411
def test_take(dtype, string_list):
402412
sarr = np.array(string_list, dtype=dtype)
403413
out = np.empty(len(string_list), dtype=dtype)

0 commit comments

Comments
 (0)