Skip to content

Commit d398884

Browse files
committed
Test casting for all dtypes
1 parent c7bbac1 commit d398884

File tree

2 files changed

+41
-15
lines changed

2 files changed

+41
-15
lines changed

quaddtype/numpy_quaddtype/src/casts.cpp

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -151,16 +151,17 @@ quad_to_quad_strided_loop_aligned(PyArrayMethod_Context *context, char *const da
151151
return 0;
152152
}
153153

154-
// Template magic to ensure npy_bool/npy_ubyte and npy_half/npy_ushort do not alias in templates
155-
struct my_npy_bool {};
156-
struct my_npy_half {};
154+
// Tag dispatching to ensure npy_bool/npy_ubyte and npy_half/npy_ushort do not alias in templates
155+
// see e.g. https://stackoverflow.com/q/32522279
156+
struct spec_npy_bool {};
157+
struct spec_npy_half {};
157158

158159
template<typename T>
159160
struct NpyType { typedef T TYPE; };
160161
template<>
161-
struct NpyType<my_npy_bool>{ typedef npy_bool TYPE; };
162+
struct NpyType<spec_npy_bool>{ typedef npy_bool TYPE; };
162163
template<>
163-
struct NpyType<my_npy_half>{ typedef npy_half TYPE; };
164+
struct NpyType<spec_npy_half>{ typedef npy_half TYPE; };
164165

165166
// Casting from other types to QuadDType
166167

@@ -170,7 +171,7 @@ to_quad(typename NpyType<T>::TYPE x, QuadBackendType backend);
170171

171172
template <>
172173
inline quad_value
173-
to_quad<my_npy_bool>(npy_bool x, QuadBackendType backend)
174+
to_quad<spec_npy_bool>(npy_bool x, QuadBackendType backend)
174175
{
175176
quad_value result;
176177
if (backend == BACKEND_SLEEF) {
@@ -324,7 +325,7 @@ to_quad<npy_ulonglong>(npy_ulonglong x, QuadBackendType backend)
324325

325326
template <>
326327
inline quad_value
327-
to_quad<my_npy_half>(npy_half x, QuadBackendType backend)
328+
to_quad<spec_npy_half>(npy_half x, QuadBackendType backend)
328329
{
329330
quad_value result;
330331
if (backend == BACKEND_SLEEF) {
@@ -466,7 +467,7 @@ from_quad(quad_value x, QuadBackendType backend);
466467

467468
template <>
468469
inline npy_bool
469-
from_quad<my_npy_bool>(quad_value x, QuadBackendType backend)
470+
from_quad<spec_npy_bool>(quad_value x, QuadBackendType backend)
470471
{
471472
if (backend == BACKEND_SLEEF) {
472473
return Sleef_cast_to_int64q1(x.sleef_value) != 0;
@@ -598,7 +599,7 @@ from_quad<npy_ulonglong>(quad_value x, QuadBackendType backend)
598599

599600
template <>
600601
inline npy_half
601-
from_quad<my_npy_half>(quad_value x, QuadBackendType backend)
602+
from_quad<spec_npy_half>(quad_value x, QuadBackendType backend)
602603
{
603604
if (backend == BACKEND_SLEEF) {
604605
return npy_double_to_half(Sleef_cast_to_doubleq1(x.sleef_value));
@@ -804,7 +805,7 @@ init_casts_internal(void)
804805

805806
add_spec(quad2quad_spec);
806807

807-
add_cast_to<my_npy_bool>(&PyArray_BoolDType);
808+
add_cast_to<spec_npy_bool>(&PyArray_BoolDType);
808809
add_cast_to<npy_byte>(&PyArray_ByteDType);
809810
add_cast_to<npy_ubyte>(&PyArray_UByteDType);
810811
add_cast_to<npy_short>(&PyArray_ShortDType);
@@ -815,12 +816,12 @@ init_casts_internal(void)
815816
add_cast_to<npy_ulong>(&PyArray_ULongDType);
816817
add_cast_to<npy_longlong>(&PyArray_LongLongDType);
817818
add_cast_to<npy_ulonglong>(&PyArray_ULongLongDType);
818-
add_cast_to<my_npy_half>(&PyArray_HalfDType);
819+
add_cast_to<spec_npy_half>(&PyArray_HalfDType);
819820
add_cast_to<float>(&PyArray_FloatDType);
820821
add_cast_to<double>(&PyArray_DoubleDType);
821822
add_cast_to<long double>(&PyArray_LongDoubleDType);
822823

823-
add_cast_from<my_npy_bool>(&PyArray_BoolDType);
824+
add_cast_from<spec_npy_bool>(&PyArray_BoolDType);
824825
add_cast_from<npy_byte>(&PyArray_ByteDType);
825826
add_cast_from<npy_ubyte>(&PyArray_UByteDType);
826827
add_cast_from<npy_short>(&PyArray_ShortDType);
@@ -831,7 +832,7 @@ init_casts_internal(void)
831832
add_cast_from<npy_ulong>(&PyArray_ULongDType);
832833
add_cast_from<npy_longlong>(&PyArray_LongLongDType);
833834
add_cast_from<npy_ulonglong>(&PyArray_ULongLongDType);
834-
add_cast_from<my_npy_half>(&PyArray_HalfDType);
835+
add_cast_from<spec_npy_half>(&PyArray_HalfDType);
835836
add_cast_from<float>(&PyArray_FloatDType);
836837
add_cast_from<double>(&PyArray_DoubleDType);
837838
add_cast_from<long double>(&PyArray_LongDoubleDType);

quaddtype/tests/test_quaddtype.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,22 @@ def test_finfo_int_constant(name, value):
3939
assert getattr(numpy_quaddtype, name) == value
4040

4141

42-
@pytest.mark.parametrize("dtype", ["bool", "byte", "int8", "ubyte", "uint8", "short", "int16", "ushort", "uint16", "int", "int32", "uint", "uint32", "long", "ulong", "longlong", "int64", "ulonglong", "uint64", "half", "float16", "float", "float32", "double", "float64", "longdouble"])
43-
def test_astype(dtype):
42+
@pytest.mark.parametrize("dtype", [
43+
"bool",
44+
"byte", "int8", "ubyte", "uint8",
45+
"short", "int16", "ushort", "uint16",
46+
"int", "int32", "uint", "uint32",
47+
"long", "ulong",
48+
"longlong", "int64", "ulonglong", "uint64",
49+
"half", "float16",
50+
"float", "float32",
51+
"double", "float64",
52+
"longdouble", "float96", "float128",
53+
])
54+
def test_supported_astype(dtype):
55+
if dtype in ("float96", "float128") and getattr(np, dtype, None) is None:
56+
pytest.skip(f"{dtype} is unsupported on the current platform")
57+
4458
orig = np.array(1, dtype=dtype)
4559
quad = orig.astype(QuadPrecDType, casting="safe")
4660
back = quad.astype(dtype, casting="unsafe")
@@ -49,6 +63,17 @@ def test_astype(dtype):
4963
assert back == orig
5064

5165

66+
@pytest.mark.parametrize("dtype", ["S10", "U10", "T", "V10", "datetime64[ms]", "timedelta64[ms]"])
67+
def test_unsupported_astype(dtype):
68+
val = 1 if dtype != "V10" else b"1"
69+
70+
with pytest.raises(TypeError if dtype != "V10" else ValueError, match="cast"):
71+
np.array(val, dtype=dtype).astype(QuadPrecDType, casting="unsafe")
72+
73+
with pytest.raises(TypeError if dtype != "V10" else ValueError, match="cast"):
74+
np.array(QuadPrecision(1)).astype(dtype, casting="unsafe")
75+
76+
5277
def test_basic_equality():
5378
assert QuadPrecision("12") == QuadPrecision(
5479
"12.0") == QuadPrecision("12.00")

0 commit comments

Comments
 (0)