Skip to content

Commit 24703bd

Browse files
committed
Use template magic
1 parent 1c2f449 commit 24703bd

File tree

1 file changed

+22
-14
lines changed

1 file changed

+22
-14
lines changed

quaddtype/numpy_quaddtype/src/casts.cpp

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ extern "C" {
2020
#include "casts.h"
2121
#include "dtype.h"
2222

23-
#define NUM_CASTS 29 // 14 to_casts + 14 from_casts + 1 quad_to_quad
23+
#define NUM_CASTS 31 // 15 to_casts + 15 from_casts + 1 quad_to_quad
2424

2525
static NPY_CASTING
2626
quad_to_quad_resolve_descriptors(PyObject *NPY_UNUSED(self),
@@ -150,15 +150,23 @@ quad_to_quad_strided_loop_aligned(PyArrayMethod_Context *context, char *const da
150150
return 0;
151151
}
152152

153+
// Template magic to ensure npy_bool and npy_ubyte do not alias in templates
154+
struct my_npy_bool {}
155+
156+
template<typename T>
157+
struct NpyType { typedef T TYPE; };
158+
template<>
159+
struct NpyType<my_npy_bool>{ typedef npy_bool TYPE; };
160+
153161
// Casting from other types to QuadDType
154162

155163
template <typename T>
156164
static inline quad_value
157-
to_quad(T x, QuadBackendType backend);
165+
to_quad(NpyType<T>::TYPE x, QuadBackendType backend);
158166

159167
template <>
160168
inline quad_value
161-
to_quad<npy_bool>(npy_bool x, QuadBackendType backend)
169+
to_quad<my_npy_bool>(npy_bool x, QuadBackendType backend)
162170
{
163171
quad_value result;
164172
if (backend == BACKEND_SLEEF) {
@@ -388,10 +396,10 @@ numpy_to_quad_strided_loop_unaligned(PyArrayMethod_Context *context, char *const
388396
size_t elem_size = (backend == BACKEND_SLEEF) ? sizeof(Sleef_quad) : sizeof(long double);
389397

390398
while (N--) {
391-
T in_val;
399+
NpyType<T>::TYPE in_val;
392400
quad_value out_val;
393401

394-
memcpy(&in_val, in_ptr, sizeof(T));
402+
memcpy(&in_val, in_ptr, sizeof(NpyType<T>::TYPE));
395403
out_val = to_quad<T>(in_val, backend);
396404
memcpy(out_ptr, &out_val, elem_size);
397405

@@ -415,7 +423,7 @@ numpy_to_quad_strided_loop_aligned(PyArrayMethod_Context *context, char *const d
415423
QuadBackendType backend = descr_out->backend;
416424

417425
while (N--) {
418-
T in_val = *(T *)in_ptr;
426+
NpyType<T>::TYPE in_val = *(NpyType<T>::TYPE *)in_ptr;
419427
quad_value out_val = to_quad<T>(in_val, backend);
420428

421429
if (backend == BACKEND_SLEEF) {
@@ -434,12 +442,12 @@ numpy_to_quad_strided_loop_aligned(PyArrayMethod_Context *context, char *const d
434442
// Casting from QuadDType to other types
435443

436444
template <typename T>
437-
static inline T
445+
static inline NpyType<T>::TYPE
438446
from_quad(quad_value x, QuadBackendType backend);
439447

440448
template <>
441449
inline npy_bool
442-
from_quad<npy_bool>(quad_value x, QuadBackendType backend)
450+
from_quad<my_npy_bool>(quad_value x, QuadBackendType backend)
443451
{
444452
if (backend == BACKEND_SLEEF) {
445453
return Sleef_cast_to_int64q1(x.sleef_value) != 0;
@@ -637,8 +645,8 @@ quad_to_numpy_strided_loop_unaligned(PyArrayMethod_Context *context, char *const
637645
quad_value in_val;
638646
memcpy(&in_val, in_ptr, elem_size);
639647

640-
T out_val = from_quad<T>(in_val, backend);
641-
memcpy(out_ptr, &out_val, sizeof(T));
648+
NpyType<T>::TYPE out_val = from_quad<T>(in_val, backend);
649+
memcpy(out_ptr, &out_val, sizeof(NpyType<T>::TYPE));
642650

643651
in_ptr += strides[0];
644652
out_ptr += strides[1];
@@ -668,8 +676,8 @@ quad_to_numpy_strided_loop_aligned(PyArrayMethod_Context *context, char *const d
668676
in_val.longdouble_value = *(long double *)in_ptr;
669677
}
670678

671-
T out_val = from_quad<T>(in_val, backend);
672-
*(T *)(out_ptr) = out_val;
679+
NpyType<T>::TYPE out_val = from_quad<T>(in_val, backend);
680+
*(NpyType<T>::TYPE *)(out_ptr) = out_val;
673681

674682
in_ptr += strides[0];
675683
out_ptr += strides[1];
@@ -765,7 +773,7 @@ init_casts_internal(void)
765773

766774
add_spec(quad2quad_spec);
767775

768-
add_cast_to<npy_bool>(&PyArray_BoolDType);
776+
add_cast_to<my_npy_bool>(&PyArray_BoolDType);
769777
add_cast_to<npy_byte>(&PyArray_ByteDType);
770778
add_cast_to<npy_ubyte>(&PyArray_UByteDType);
771779
add_cast_to<npy_short>(&PyArray_ShortDType);
@@ -780,7 +788,7 @@ init_casts_internal(void)
780788
add_cast_to<double>(&PyArray_DoubleDType);
781789
add_cast_to<long double>(&PyArray_LongDoubleDType);
782790

783-
add_cast_from<npy_bool>(&PyArray_BoolDType);
791+
add_cast_from<my_npy_bool>(&PyArray_BoolDType);
784792
add_cast_from<npy_byte>(&PyArray_ByteDType);
785793
add_cast_from<npy_ubyte>(&PyArray_UByteDType);
786794
add_cast_from<npy_short>(&PyArray_ShortDType);

0 commit comments

Comments
 (0)