Skip to content

Commit 0d234bd

Browse files
committed
implement same_value casting for numpy <-> quadtype
1 parent c4c1def commit 0d234bd

File tree

3 files changed

+96
-9
lines changed

3 files changed

+96
-9
lines changed

quaddtype/numpy_quaddtype/src/casts.cpp

Lines changed: 84 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#define PY_ARRAY_UNIQUE_SYMBOL QuadPrecType_ARRAY_API
22
#define PY_UFUNC_UNIQUE_SYMBOL QuadPrecType_UFUNC_API
3-
#define NPY_NO_DEPRECATED_API NPY_2_0_API_VERSION
4-
#define NPY_TARGET_VERSION NPY_2_0_API_VERSION
3+
#define NPY_NO_DEPRECATED_API NPY_2_4_API_VERSION
4+
#define NPY_TARGET_VERSION NPY_2_4_API_VERSION
55
#define NO_IMPORT_ARRAY
66
#define NO_IMPORT_UFUNC
77

@@ -157,7 +157,7 @@ void_to_quad_resolve_descriptors(PyObject *NPY_UNUSED(self), PyArray_DTypeMeta *
157157
PyArray_Descr *given_descrs[2], PyArray_Descr *loop_descrs[2],
158158
npy_intp *view_offset)
159159
{
160-
PyErr_SetString(PyExc_TypeError,
160+
PyErr_SetString(PyExc_TypeError,
161161
"Void to QuadPrecision cast is not implemented");
162162
return (NPY_CASTING)-1;
163163
}
@@ -401,7 +401,7 @@ to_quad<long double>(long double x, QuadBackendType backend)
401401
}
402402

403403
template <typename T>
404-
static NPY_CASTING
404+
static int
405405
numpy_to_quad_resolve_descriptors(PyObject *NPY_UNUSED(self), PyArray_DTypeMeta *dtypes[2],
406406
PyArray_Descr *given_descrs[2], PyArray_Descr *loop_descrs[2],
407407
npy_intp *view_offset)
@@ -419,7 +419,11 @@ numpy_to_quad_resolve_descriptors(PyObject *NPY_UNUSED(self), PyArray_DTypeMeta
419419
}
420420

421421
loop_descrs[0] = PyArray_GetDefaultDescr(dtypes[0]);
422+
#if NPY_FEATURE_VERSION > NPY_2_3_API_VERSION
423+
return NPY_SAFE_CASTING | NPY_SAME_VALUE_CASTING_FLAG;
424+
#else
422425
return NPY_SAFE_CASTING;
426+
#endif
423427
}
424428

425429
template <typename T>
@@ -666,6 +670,28 @@ from_quad<long double>(quad_value x, QuadBackendType backend)
666670
}
667671
}
668672

673+
template <typename T>
674+
static inline int
675+
from_quad_checked(quad_value x, QuadBackendType backend, typename NpyType<T>::TYPE *ret) {
676+
*ret = from_quad<typename NpyType<T>::TYPE>(x, backend);
677+
quad_value check = to_quad<typename NpyType<T>::TYPE>(*ret, backend);
678+
if (backend == BACKEND_SLEEF) {
679+
if (check.sleef_value == x.sleef_value) {
680+
return 0;
681+
}
682+
}
683+
else {
684+
if (check.longdouble_value == x.longdouble_value) {
685+
return 0;
686+
}
687+
}
688+
NPY_ALLOW_C_API_DEF;
689+
NPY_ALLOW_C_API;
690+
PyErr_SetString(PyExc_ValueError, "could not cast 'same_value' to QuadType");
691+
NPY_DISABLE_C_API;
692+
return -1;
693+
}
694+
669695
template <typename T>
670696
static NPY_CASTING
671697
quad_to_numpy_resolve_descriptors(PyObject *NPY_UNUSED(self), PyArray_DTypeMeta *dtypes[2],
@@ -685,6 +711,9 @@ quad_to_numpy_strided_loop_unaligned(PyArrayMethod_Context *context, char *const
685711
npy_intp const dimensions[], npy_intp const strides[],
686712
void *NPY_UNUSED(auxdata))
687713
{
714+
#if NPY_FEATURE_VERSION > NPY_2_3_API_VERSION
715+
int same_value_casting = ((context->flags & NPY_SAME_VALUE_CONTEXT_FLAG) == NPY_SAME_VALUE_CONTEXT_FLAG);
716+
#endif
688717
npy_intp N = dimensions[0];
689718
char *in_ptr = data[0];
690719
char *out_ptr = data[1];
@@ -694,6 +723,24 @@ quad_to_numpy_strided_loop_unaligned(PyArrayMethod_Context *context, char *const
694723

695724
size_t elem_size = (backend == BACKEND_SLEEF) ? sizeof(Sleef_quad) : sizeof(long double);
696725

726+
#if NPY_FEATURE_VERSION > NPY_2_3_API_VERSION
727+
if (same_value_casting) {
728+
while (N--) {
729+
quad_value in_val;
730+
memcpy(&in_val, in_ptr, elem_size);
731+
typename NpyType<T>::TYPE out_val;
732+
if (from_quad_checked<T>(in_val, backend, &out_val) < 0) {
733+
return -1;
734+
}
735+
memcpy(out_ptr, &out_val, sizeof(typename NpyType<T>::TYPE));
736+
737+
in_ptr += strides[0];
738+
out_ptr += strides[1];
739+
}
740+
} else {
741+
#else
742+
{
743+
#endif
697744
while (N--) {
698745
quad_value in_val;
699746
memcpy(&in_val, in_ptr, elem_size);
@@ -703,7 +750,7 @@ quad_to_numpy_strided_loop_unaligned(PyArrayMethod_Context *context, char *const
703750

704751
in_ptr += strides[0];
705752
out_ptr += strides[1];
706-
}
753+
}}
707754
return 0;
708755
}
709756

@@ -716,10 +763,36 @@ quad_to_numpy_strided_loop_aligned(PyArrayMethod_Context *context, char *const d
716763
npy_intp N = dimensions[0];
717764
char *in_ptr = data[0];
718765
char *out_ptr = data[1];
766+
#if NPY_FEATURE_VERSION > NPY_2_3_API_VERSION
767+
int same_value_casting = ((context->flags & NPY_SAME_VALUE_CONTEXT_FLAG) == NPY_SAME_VALUE_CONTEXT_FLAG);
768+
#endif
719769

720770
QuadPrecDTypeObject *quad_descr = (QuadPrecDTypeObject *)context->descriptors[0];
721771
QuadBackendType backend = quad_descr->backend;
722772

773+
#if NPY_FEATURE_VERSION > NPY_2_3_API_VERSION
774+
if (same_value_casting) {
775+
while (N--) {
776+
quad_value in_val;
777+
if (backend == BACKEND_SLEEF) {
778+
in_val.sleef_value = *(Sleef_quad *)in_ptr;
779+
}
780+
else {
781+
in_val.longdouble_value = *(long double *)in_ptr;
782+
}
783+
784+
typename NpyType<T>::TYPE out_val;
785+
if (from_quad_checked<T>(in_val, backend, &out_val) < 0) {
786+
return -1;
787+
}
788+
*(typename NpyType<T>::TYPE *)(out_ptr) = out_val;
789+
790+
in_ptr += strides[0];
791+
out_ptr += strides[1];
792+
}} else {
793+
#else
794+
{
795+
#endif
723796
while (N--) {
724797
quad_value in_val;
725798
if (backend == BACKEND_SLEEF) {
@@ -734,7 +807,7 @@ quad_to_numpy_strided_loop_aligned(PyArrayMethod_Context *context, char *const d
734807

735808
in_ptr += strides[0];
736809
out_ptr += strides[1];
737-
}
810+
}}
738811
return 0;
739812
}
740813

@@ -771,7 +844,11 @@ add_cast_from(PyArray_DTypeMeta *to)
771844
.name = "cast_QuadPrec_to_NumPy",
772845
.nin = 1,
773846
.nout = 1,
847+
#if NPY_FEATURE_VERSION > NPY_2_3_API_VERSION
848+
.casting = NPY_SAME_VALUE_CASTING,
849+
#else
774850
.casting = NPY_UNSAFE_CASTING,
851+
#endif
775852
.flags = NPY_METH_SUPPORTS_UNALIGNED,
776853
.dtypes = dtypes,
777854
.slots = slots,
@@ -904,4 +981,4 @@ free_casts(void)
904981
}
905982
}
906983
spec_count = 0;
907-
}
984+
}

quaddtype/numpy_quaddtype/src/quadblas_interface.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
#include <algorithm>
44

55
#ifndef DISABLE_QUADBLAS
6-
#include "../subprojects/qblas/include/quadblas/quadblas.hpp"
6+
#include "quadblas/quadblas.hpp"
77
#endif // DISABLE_QUADBLAS
88

99
extern "C" {
@@ -230,4 +230,4 @@ _quadblas_get_num_threads(void)
230230

231231
#endif // DISABLE_QUADBLAS
232232

233-
} // extern "C"
233+
} // extern "C"

quaddtype/tests/test_quaddtype.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,16 @@ def test_unsupported_astype(dtype):
7676
np.array(QuadPrecision(1)).astype(dtype, casting="unsafe")
7777

7878

79+
def test_same_value_cast():
80+
# This will fail if compiled with NPY_TARGET_VERSION NPY<2_4_API_VERSION
81+
a = np.arange(30, dtype=np.float32)
82+
# upcasting can never fail
83+
b = a.astype(QuadPrecision, casting='same_value')
84+
c = b.astype(np.float32, casting='same_value')
85+
assert np.all(c == a)
86+
with pytest.raises(ValueError, match="could not cast 'same_value'"):
87+
(b + 1e22).astype(np.float32, casting='same_value')
88+
7989
def test_basic_equality():
8090
assert QuadPrecision("12") == QuadPrecision(
8191
"12.0") == QuadPrecision("12.00")

0 commit comments

Comments
 (0)