Skip to content

Commit b9340df

Browse files
committed
Implement cast support for half
1 parent 24703bd commit b9340df

File tree

3 files changed

+48
-15
lines changed

3 files changed

+48
-15
lines changed

quaddtype/meson.build

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,13 @@ incdir_numpy = run_command(py,
2323
check : true
2424
).stdout().strip()
2525

26+
npymath_path = incdir_numpy / '..' / 'lib'
27+
npymath_lib = cc.find_library('npymath', dirs: npymath_path)
28+
29+
dependencies = [sleef_dep, py_dep, npymath_lib]
30+
2631
# Add OpenMP dependency (optional, for threading)
2732
openmp_dep = dependency('openmp', required: false)
28-
dependencies = [sleef_dep, py_dep]
2933
if openmp_dep.found()
3034
dependencies += openmp_dep
3135
endif

quaddtype/numpy_quaddtype/src/casts.cpp

Lines changed: 43 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ extern "C" {
99
#include <Python.h>
1010

1111
#include "numpy/arrayobject.h"
12+
#include "numpy/halffloat.h"
1213
#include "numpy/ndarraytypes.h"
1314
#include "numpy/dtype_api.h"
1415
}
@@ -20,7 +21,7 @@ extern "C" {
2021
#include "casts.h"
2122
#include "dtype.h"
2223

23-
#define NUM_CASTS 31 // 15 to_casts + 15 from_casts + 1 quad_to_quad
24+
#define NUM_CASTS 33 // 16 to_casts + 16 from_casts + 1 quad_to_quad
2425

2526
static NPY_CASTING
2627
quad_to_quad_resolve_descriptors(PyObject *NPY_UNUSED(self),
@@ -150,19 +151,21 @@ quad_to_quad_strided_loop_aligned(PyArrayMethod_Context *context, char *const da
150151
return 0;
151152
}
152153

153-
// Template magic to ensure npy_bool and npy_ubyte do not alias in templates
154+
// Tag types: npy_bool/npy_ubyte and npy_half/npy_ushort would otherwise alias as template arguments, so distinct empty structs are used to select separate specializations
154155
// Empty tag type used as the template argument for bool casts; NpyType maps it back to npy_bool.
struct my_npy_bool {};
156+
// Empty tag type used as the template argument for half casts; NpyType maps it back to npy_half.
struct my_npy_half {};
155157

156158
// Maps a cast template argument T to the C type actually stored in the array.
// Primary template: T is used as-is; tag types get their own specializations.
template<typename T>
157159
struct NpyType { typedef T TYPE; };
158160
// Specialization: the my_npy_bool tag resolves to npy_bool as the element type.
template<>
159161
struct NpyType<my_npy_bool>{ typedef npy_bool TYPE; };
162+
struct NpyType<my_npy_half>{ typedef npy_half TYPE; };
160163

161164
// Casting from other types to QuadDType
162165

163166
template <typename T>
164167
static inline quad_value
165-
to_quad(NpyType<T>::TYPE x, QuadBackendType backend);
168+
to_quad(typename NpyType<T>::TYPE x, QuadBackendType backend);
166169

167170
template <>
168171
inline quad_value
@@ -317,6 +320,21 @@ to_quad<npy_ulonglong>(npy_ulonglong x, QuadBackendType backend)
317320
}
318321
return result;
319322
}
323+
324+
// Converts a NumPy half value to a quad_value.
// The half is first widened with npy_half_to_double; the resulting double is
// then stored in the field selected by `backend`: sleef_value (via
// Sleef_cast_from_doubleq1) for BACKEND_SLEEF, longdouble_value otherwise.
// NOTE(review): presumably lossless, since every half value should be exactly
// representable as a double -- confirm.
template <>
325+
inline quad_value
326+
to_quad<my_npy_half>(npy_half x, QuadBackendType backend)
327+
{
328+
quad_value result;
329+
if (backend == BACKEND_SLEEF) {
330+
result.sleef_value = Sleef_cast_from_doubleq1(npy_half_to_double(x));
331+
}
332+
else {
333+
result.longdouble_value = (long double)npy_half_to_double(x);
334+
}
335+
return result;
336+
}
337+
320338
template <>
321339
inline quad_value
322340
to_quad<float>(float x, QuadBackendType backend)
@@ -396,10 +414,10 @@ numpy_to_quad_strided_loop_unaligned(PyArrayMethod_Context *context, char *const
396414
size_t elem_size = (backend == BACKEND_SLEEF) ? sizeof(Sleef_quad) : sizeof(long double);
397415

398416
while (N--) {
399-
NpyType<T>::TYPE in_val;
417+
typename NpyType<T>::TYPE in_val;
400418
quad_value out_val;
401419

402-
memcpy(&in_val, in_ptr, sizeof(NpyType<T>::TYPE));
420+
memcpy(&in_val, in_ptr, sizeof(typename NpyType<T>::TYPE));
403421
out_val = to_quad<T>(in_val, backend);
404422
memcpy(out_ptr, &out_val, elem_size);
405423

@@ -423,7 +441,7 @@ numpy_to_quad_strided_loop_aligned(PyArrayMethod_Context *context, char *const d
423441
QuadBackendType backend = descr_out->backend;
424442

425443
while (N--) {
426-
NpyType<T>::TYPE in_val = *(NpyType<T>::TYPE *)in_ptr;
444+
typename NpyType<T>::TYPE in_val = *(typename NpyType<T>::TYPE *)in_ptr;
427445
quad_value out_val = to_quad<T>(in_val, backend);
428446

429447
if (backend == BACKEND_SLEEF) {
@@ -442,7 +460,7 @@ numpy_to_quad_strided_loop_aligned(PyArrayMethod_Context *context, char *const d
442460
// Casting from QuadDType to other types
443461

444462
template <typename T>
445-
static inline NpyType<T>::TYPE
463+
static inline typename NpyType<T>::TYPE
446464
from_quad(quad_value x, QuadBackendType backend);
447465

448466
template <>
@@ -577,6 +595,18 @@ from_quad<npy_ulonglong>(quad_value x, QuadBackendType backend)
577595
}
578596
}
579597

598+
// Converts a quad_value to a NumPy half value.
// The field selected by `backend` (sleef_value for BACKEND_SLEEF,
// longdouble_value otherwise) is narrowed to double and then passed to
// npy_double_to_half.
// NOTE(review): this rounds twice (quad -> double, then double -> half), which
// may differ from a single direct rounding in rare tie cases -- confirm this
// is acceptable.
template <>
599+
inline npy_half
600+
from_quad<my_npy_half>(quad_value x, QuadBackendType backend)
601+
{
602+
if (backend == BACKEND_SLEEF) {
603+
return npy_double_to_half(Sleef_cast_to_doubleq1(x.sleef_value));
604+
}
605+
else {
606+
return npy_double_to_half((double)x.longdouble_value);
607+
}
608+
}
609+
580610
template <>
581611
inline float
582612
from_quad<float>(quad_value x, QuadBackendType backend)
@@ -645,8 +675,8 @@ quad_to_numpy_strided_loop_unaligned(PyArrayMethod_Context *context, char *const
645675
quad_value in_val;
646676
memcpy(&in_val, in_ptr, elem_size);
647677

648-
NpyType<T>::TYPE out_val = from_quad<T>(in_val, backend);
649-
memcpy(out_ptr, &out_val, sizeof(NpyType<T>::TYPE));
678+
typename NpyType<T>::TYPE out_val = from_quad<T>(in_val, backend);
679+
memcpy(out_ptr, &out_val, sizeof(typename NpyType<T>::TYPE));
650680

651681
in_ptr += strides[0];
652682
out_ptr += strides[1];
@@ -676,8 +706,8 @@ quad_to_numpy_strided_loop_aligned(PyArrayMethod_Context *context, char *const d
676706
in_val.longdouble_value = *(long double *)in_ptr;
677707
}
678708

679-
NpyType<T>::TYPE out_val = from_quad<T>(in_val, backend);
680-
*(NpyType<T>::TYPE *)(out_ptr) = out_val;
709+
typename NpyType<T>::TYPE out_val = from_quad<T>(in_val, backend);
710+
*(typename NpyType<T>::TYPE *)(out_ptr) = out_val;
681711

682712
in_ptr += strides[0];
683713
out_ptr += strides[1];
@@ -784,6 +814,7 @@ init_casts_internal(void)
784814
add_cast_to<npy_ulong>(&PyArray_ULongDType);
785815
add_cast_to<npy_longlong>(&PyArray_LongLongDType);
786816
add_cast_to<npy_ulonglong>(&PyArray_ULongLongDType);
817+
add_cast_to<my_npy_half>(&PyArray_HalfDType);
787818
add_cast_to<float>(&PyArray_FloatDType);
788819
add_cast_to<double>(&PyArray_DoubleDType);
789820
add_cast_to<long double>(&PyArray_LongDoubleDType);
@@ -799,6 +830,7 @@ init_casts_internal(void)
799830
add_cast_from<npy_ulong>(&PyArray_ULongDType);
800831
add_cast_from<npy_longlong>(&PyArray_LongLongDType);
801832
add_cast_from<npy_ulonglong>(&PyArray_ULongLongDType);
833+
add_cast_from<my_npy_half>(&PyArray_HalfDType);
802834
add_cast_from<float>(&PyArray_FloatDType);
803835
add_cast_from<double>(&PyArray_DoubleDType);
804836
add_cast_from<long double>(&PyArray_LongDoubleDType);

quaddtype/tests/test_quaddtype.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,6 @@ def test_finfo_int_constant(name, value):
4141

4242
@pytest.mark.parametrize("dtype", ["bool", "byte", "int8", "ubyte", "uint8", "short", "int16", "ushort", "uint16", "int", "int32", "uint", "uint32", "long", "ulong", "longlong", "int64", "ulonglong", "uint64", "half", "float16", "float", "float32", "double", "float64", "longdouble"])
4343
def test_astype(dtype):
44-
if dtype in ("half", "float16"):
45-
pytest.xfail(f"{dtype} astype not yet supported")
46-
4744
orig = np.array(1, dtype=dtype)
4845
quad = orig.astype(QuadPrecDType, casting="safe")
4946
back = quad.astype(dtype, casting="unsafe")

0 commit comments

Comments
 (0)