Skip to content

Commit 34a36ab

Browse files
committed
Implement cast support for ubyte and half
Implement cast support for ubyte Use template magic to distinguish npy_bool and npy_half Implement cast support for half
1 parent b4a9429 commit 34a36ab

File tree

3 files changed

+98
-15
lines changed

3 files changed

+98
-15
lines changed

quaddtype/meson.build

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,13 @@ incdir_numpy = run_command(py,
2323
check : true
2424
).stdout().strip()
2525

26+
npymath_path = incdir_numpy / '..' / 'lib'
27+
npymath_lib = c.find_library('npymath', dirs: npymath_path)
28+
29+
dependencies = [sleef_dep, py_dep, npymath_lib]
30+
2631
# Add OpenMP dependency (optional, for threading)
2732
openmp_dep = dependency('openmp', required: false)
28-
dependencies = [sleef_dep, py_dep]
2933
if openmp_dep.found()
3034
dependencies += openmp_dep
3135
endif

quaddtype/numpy_quaddtype/src/casts.cpp

Lines changed: 83 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ extern "C" {
99
#include <Python.h>
1010

1111
#include "numpy/arrayobject.h"
12+
#include "numpy/halffloat.h"
1213
#include "numpy/ndarraytypes.h"
1314
#include "numpy/dtype_api.h"
1415
}
@@ -20,7 +21,7 @@ extern "C" {
2021
#include "casts.h"
2122
#include "dtype.h"
2223

23-
#define NUM_CASTS 29 // 14 to_casts + 14 from_casts + 1 quad_to_quad
24+
#define NUM_CASTS 33 // 16 to_casts + 16 from_casts + 1 quad_to_quad
2425

2526
static NPY_CASTING
2627
quad_to_quad_resolve_descriptors(PyObject *NPY_UNUSED(self),
@@ -150,15 +151,26 @@ quad_to_quad_strided_loop_aligned(PyArrayMethod_Context *context, char *const da
150151
return 0;
151152
}
152153

154+
// Template magic to ensure npy_bool/npy_ubyte and npy_half/npy_ushort do not alias in templates
155+
struct my_npy_bool {};
156+
struct my_npy_half {};
157+
158+
template<typename T>
159+
struct NpyType { typedef T TYPE; };
160+
template<>
161+
struct NpyType<my_npy_bool>{ typedef npy_bool TYPE; };
162+
template<>
163+
struct NpyType<my_npy_half>{ typedef npy_half TYPE; };
164+
153165
// Casting from other types to QuadDType
154166

155167
template <typename T>
156168
static inline quad_value
157-
to_quad(T x, QuadBackendType backend);
169+
to_quad(typename NpyType<T>::TYPE x, QuadBackendType backend);
158170

159171
template <>
160172
inline quad_value
161-
to_quad<npy_bool>(npy_bool x, QuadBackendType backend)
173+
to_quad<my_npy_bool>(npy_bool x, QuadBackendType backend)
162174
{
163175
quad_value result;
164176
if (backend == BACKEND_SLEEF) {
@@ -184,6 +196,20 @@ to_quad<npy_byte>(npy_byte x, QuadBackendType backend)
184196
return result;
185197
}
186198

199+
template <>
200+
inline quad_value
201+
to_quad<npy_ubyte>(npy_ubyte x, QuadBackendType backend)
202+
{
203+
quad_value result;
204+
if (backend == BACKEND_SLEEF) {
205+
result.sleef_value = Sleef_cast_from_uint64q1(x);
206+
}
207+
else {
208+
result.longdouble_value = (long double)x;
209+
}
210+
return result;
211+
}
212+
187213
template <>
188214
inline quad_value
189215
to_quad<npy_short>(npy_short x, QuadBackendType backend)
@@ -295,6 +321,21 @@ to_quad<npy_ulonglong>(npy_ulonglong x, QuadBackendType backend)
295321
}
296322
return result;
297323
}
324+
325+
template <>
326+
inline quad_value
327+
to_quad<my_npy_half>(npy_half x, QuadBackendType backend)
328+
{
329+
quad_value result;
330+
if (backend == BACKEND_SLEEF) {
331+
result.sleef_value = Sleef_cast_from_doubleq1(npy_half_to_double(x));
332+
}
333+
else {
334+
result.longdouble_value = (long double)npy_half_to_double(x);
335+
}
336+
return result;
337+
}
338+
298339
template <>
299340
inline quad_value
300341
to_quad<float>(float x, QuadBackendType backend)
@@ -374,10 +415,10 @@ numpy_to_quad_strided_loop_unaligned(PyArrayMethod_Context *context, char *const
374415
size_t elem_size = (backend == BACKEND_SLEEF) ? sizeof(Sleef_quad) : sizeof(long double);
375416

376417
while (N--) {
377-
T in_val;
418+
typename NpyType<T>::TYPE in_val;
378419
quad_value out_val;
379420

380-
memcpy(&in_val, in_ptr, sizeof(T));
421+
memcpy(&in_val, in_ptr, sizeof(typename NpyType<T>::TYPE));
381422
out_val = to_quad<T>(in_val, backend);
382423
memcpy(out_ptr, &out_val, elem_size);
383424

@@ -401,7 +442,7 @@ numpy_to_quad_strided_loop_aligned(PyArrayMethod_Context *context, char *const d
401442
QuadBackendType backend = descr_out->backend;
402443

403444
while (N--) {
404-
T in_val = *(T *)in_ptr;
445+
typename NpyType<T>::TYPE in_val = *(typename NpyType<T>::TYPE *)in_ptr;
405446
quad_value out_val = to_quad<T>(in_val, backend);
406447

407448
if (backend == BACKEND_SLEEF) {
@@ -420,12 +461,12 @@ numpy_to_quad_strided_loop_aligned(PyArrayMethod_Context *context, char *const d
420461
// Casting from QuadDType to other types
421462

422463
template <typename T>
423-
static inline T
464+
static inline typename NpyType<T>::TYPE
424465
from_quad(quad_value x, QuadBackendType backend);
425466

426467
template <>
427468
inline npy_bool
428-
from_quad<npy_bool>(quad_value x, QuadBackendType backend)
469+
from_quad<my_npy_bool>(quad_value x, QuadBackendType backend)
429470
{
430471
if (backend == BACKEND_SLEEF) {
431472
return Sleef_cast_to_int64q1(x.sleef_value) != 0;
@@ -447,6 +488,18 @@ from_quad<npy_byte>(quad_value x, QuadBackendType backend)
447488
}
448489
}
449490

491+
template <>
492+
inline npy_ubyte
493+
from_quad<npy_ubyte>(quad_value x, QuadBackendType backend)
494+
{
495+
if (backend == BACKEND_SLEEF) {
496+
return (npy_ubyte)Sleef_cast_to_uint64q1(x.sleef_value);
497+
}
498+
else {
499+
return (npy_ubyte)x.longdouble_value;
500+
}
501+
}
502+
450503
template <>
451504
inline npy_short
452505
from_quad<npy_short>(quad_value x, QuadBackendType backend)
@@ -543,6 +596,18 @@ from_quad<npy_ulonglong>(quad_value x, QuadBackendType backend)
543596
}
544597
}
545598

599+
template <>
600+
inline npy_half
601+
from_quad<my_npy_half>(quad_value x, QuadBackendType backend)
602+
{
603+
if (backend == BACKEND_SLEEF) {
604+
return npy_double_to_half(Sleef_cast_to_doubleq1(x.sleef_value));
605+
}
606+
else {
607+
return npy_double_to_half((double)x.longdouble_value);
608+
}
609+
}
610+
546611
template <>
547612
inline float
548613
from_quad<float>(quad_value x, QuadBackendType backend)
@@ -611,8 +676,8 @@ quad_to_numpy_strided_loop_unaligned(PyArrayMethod_Context *context, char *const
611676
quad_value in_val;
612677
memcpy(&in_val, in_ptr, elem_size);
613678

614-
T out_val = from_quad<T>(in_val, backend);
615-
memcpy(out_ptr, &out_val, sizeof(T));
679+
typename NpyType<T>::TYPE out_val = from_quad<T>(in_val, backend);
680+
memcpy(out_ptr, &out_val, sizeof(typename NpyType<T>::TYPE));
616681

617682
in_ptr += strides[0];
618683
out_ptr += strides[1];
@@ -642,8 +707,8 @@ quad_to_numpy_strided_loop_aligned(PyArrayMethod_Context *context, char *const d
642707
in_val.longdouble_value = *(long double *)in_ptr;
643708
}
644709

645-
T out_val = from_quad<T>(in_val, backend);
646-
*(T *)(out_ptr) = out_val;
710+
typename NpyType<T>::TYPE out_val = from_quad<T>(in_val, backend);
711+
*(typename NpyType<T>::TYPE *)(out_ptr) = out_val;
647712

648713
in_ptr += strides[0];
649714
out_ptr += strides[1];
@@ -739,8 +804,9 @@ init_casts_internal(void)
739804

740805
add_spec(quad2quad_spec);
741806

742-
add_cast_to<npy_bool>(&PyArray_BoolDType);
807+
add_cast_to<my_npy_bool>(&PyArray_BoolDType);
743808
add_cast_to<npy_byte>(&PyArray_ByteDType);
809+
add_cast_to<npy_ubyte>(&PyArray_UByteDType);
744810
add_cast_to<npy_short>(&PyArray_ShortDType);
745811
add_cast_to<npy_ushort>(&PyArray_UShortDType);
746812
add_cast_to<npy_int>(&PyArray_IntDType);
@@ -749,12 +815,14 @@ init_casts_internal(void)
749815
add_cast_to<npy_ulong>(&PyArray_ULongDType);
750816
add_cast_to<npy_longlong>(&PyArray_LongLongDType);
751817
add_cast_to<npy_ulonglong>(&PyArray_ULongLongDType);
818+
add_cast_to<my_npy_half>(&PyArray_HalfDType);
752819
add_cast_to<float>(&PyArray_FloatDType);
753820
add_cast_to<double>(&PyArray_DoubleDType);
754821
add_cast_to<long double>(&PyArray_LongDoubleDType);
755822

756-
add_cast_from<npy_bool>(&PyArray_BoolDType);
823+
add_cast_from<my_npy_bool>(&PyArray_BoolDType);
757824
add_cast_from<npy_byte>(&PyArray_ByteDType);
825+
add_cast_from<npy_ubyte>(&PyArray_UByteDType);
758826
add_cast_from<npy_short>(&PyArray_ShortDType);
759827
add_cast_from<npy_ushort>(&PyArray_UShortDType);
760828
add_cast_from<npy_int>(&PyArray_IntDType);
@@ -763,6 +831,7 @@ init_casts_internal(void)
763831
add_cast_from<npy_ulong>(&PyArray_ULongDType);
764832
add_cast_from<npy_longlong>(&PyArray_LongLongDType);
765833
add_cast_from<npy_ulonglong>(&PyArray_ULongLongDType);
834+
add_cast_from<my_npy_half>(&PyArray_HalfDType);
766835
add_cast_from<float>(&PyArray_FloatDType);
767836
add_cast_from<double>(&PyArray_DoubleDType);
768837
add_cast_from<long double>(&PyArray_LongDoubleDType);

quaddtype/tests/test_quaddtype.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,16 @@ def test_finfo_int_constant(name, value):
3939
assert getattr(numpy_quaddtype, name) == value
4040

4141

42+
@pytest.mark.parametrize("dtype", ["bool", "byte", "int8", "ubyte", "uint8", "short", "int16", "ushort", "uint16", "int", "int32", "uint", "uint32", "long", "ulong", "longlong", "int64", "ulonglong", "uint64", "half", "float16", "float", "float32", "double", "float64", "longdouble"])
43+
def test_astype(dtype):
44+
orig = np.array(1, dtype=dtype)
45+
quad = orig.astype(QuadPrecDType, casting="safe")
46+
back = quad.astype(dtype, casting="unsafe")
47+
48+
assert quad == 1
49+
assert back == orig
50+
51+
4252
def test_basic_equality():
4353
assert QuadPrecision("12") == QuadPrecision(
4454
"12.0") == QuadPrecision("12.00")

0 commit comments

Comments
 (0)