@@ -708,6 +708,16 @@ NPY_NO_EXPORT PyArrayMethod_StridedLoop *
708708
709709/************* STRIDED CASTING SPECIALIZED FUNCTIONS *************/
710710
711+ #if defined(NPY_HAVE_NEON_FP16) /* selects the representation of half used by the cast loops below */
712+ #define EMULATED_FP16 0
713+ #define NATIVE_FP16 1
714+ typedef _Float16 _npy_half; /* half is the compiler's _Float16 type */
715+ #else
716+ #define EMULATED_FP16 1 /* half casts go through the npy_half* / npy_*bits_to_*bits helpers */
717+ #define NATIVE_FP16 0
718+ typedef npy_half _npy_half; /* half keeps the npy_half type */
719+ #endif /* NPY_HAVE_NEON_FP16 */
720+
711721/**begin repeat
712722 *
713723 * #NAME1 = BOOL,
@@ -723,15 +733,16 @@ NPY_NO_EXPORT PyArrayMethod_StridedLoop *
723733 * #type1 = npy_bool,
724734 * npy_ubyte, npy_ushort, npy_uint, npy_ulong, npy_ulonglong,
725735 * npy_byte, npy_short, npy_int, npy_long, npy_longlong,
726- * npy_half , npy_float, npy_double, npy_longdouble,
736+ * _npy_half , npy_float, npy_double, npy_longdouble,
727737 * npy_cfloat, npy_cdouble, npy_clongdouble#
728738 * #rtype1 = npy_bool,
729739 * npy_ubyte, npy_ushort, npy_uint, npy_ulong, npy_ulonglong,
730740 * npy_byte, npy_short, npy_int, npy_long, npy_longlong,
731- * npy_half , npy_float, npy_double, npy_longdouble,
741+ * _npy_half , npy_float, npy_double, npy_longdouble,
732742 * npy_float, npy_double, npy_longdouble#
733743 * #is_bool1 = 1, 0*17#
734- * #is_half1 = 0*11, 1, 0*6#
744+ * #is_emu_half1 = 0*11, EMULATED_FP16, 0*6#
745+ * #is_native_half1 = 0*11, NATIVE_FP16, 0*6#
735746 * #is_float1 = 0*12, 1, 0, 0, 1, 0, 0#
736747 * #is_double1 = 0*13, 1, 0, 0, 1, 0#
737748 * #is_complex1 = 0*15, 1*3#
@@ -752,15 +763,16 @@ NPY_NO_EXPORT PyArrayMethod_StridedLoop *
752763 * #type2 = npy_bool,
753764 * npy_ubyte, npy_ushort, npy_uint, npy_ulong, npy_ulonglong,
754765 * npy_byte, npy_short, npy_int, npy_long, npy_longlong,
755- * npy_half , npy_float, npy_double, npy_longdouble,
766+ * _npy_half , npy_float, npy_double, npy_longdouble,
756767 * npy_cfloat, npy_cdouble, npy_clongdouble#
757768 * #rtype2 = npy_bool,
758769 * npy_ubyte, npy_ushort, npy_uint, npy_ulong, npy_ulonglong,
759770 * npy_byte, npy_short, npy_int, npy_long, npy_longlong,
760- * npy_half , npy_float, npy_double, npy_longdouble,
771+ * _npy_half , npy_float, npy_double, npy_longdouble,
761772 * npy_float, npy_double, npy_longdouble#
762773 * #is_bool2 = 1, 0*17#
763- * #is_half2 = 0*11, 1, 0*6#
774+ * #is_emu_half2 = 0*11, EMULATED_FP16, 0*6#
775+ * #is_native_half2 = 0*11, NATIVE_FP16, 0*6#
764776 * #is_float2 = 0*12, 1, 0, 0, 1, 0, 0#
765777 * #is_double2 = 0*13, 1, 0, 0, 1, 0#
766778 * #is_complex2 = 0*15, 1*3#
@@ -774,8 +786,8 @@ NPY_NO_EXPORT PyArrayMethod_StridedLoop *
774786
775787#if !(NPY_USE_UNALIGNED_ACCESS && !@aligned@)
776788
777- /* For half types, don't use actual double/float types in conversion */
778- #if @is_half1 @ || @is_half2 @
789+ /* For emulated half types, don't use actual double/float types in conversion */
790+ #if @is_emu_half1 @ || @is_emu_half2 @
779791
780792# if @is_float1@
781793# define _TYPE1 npy_uint32
@@ -801,27 +813,27 @@ NPY_NO_EXPORT PyArrayMethod_StridedLoop *
801813#endif
802814
803815/* Determine an appropriate casting conversion function */
804- #if @is_half1 @
816+ #if @is_emu_half1 @
805817
806818# if @is_float2@
807819# define _CONVERT_FN(x) npy_halfbits_to_floatbits(x)
808820# elif @is_double2@
809821# define _CONVERT_FN(x) npy_halfbits_to_doublebits(x)
810- # elif @is_half2 @
822+ # elif @is_emu_half2 @
811823# define _CONVERT_FN(x) (x)
812824# elif @is_bool2@
813825# define _CONVERT_FN(x) ((npy_bool)!npy_half_iszero(x))
814826# else
815827# define _CONVERT_FN(x) ((_TYPE2)npy_half_to_float(x))
816828# endif
817829
818- #elif @is_half2 @
830+ #elif @is_emu_half2 @
819831
820832# if @is_float1@
821833# define _CONVERT_FN(x) npy_floatbits_to_halfbits(x)
822834# elif @is_double1@
823835# define _CONVERT_FN(x) npy_doublebits_to_halfbits(x)
824- # elif @is_half1 @
836+ # elif @is_emu_half1 @
825837# define _CONVERT_FN(x) (x)
826838# elif @is_bool1@
827839# define _CONVERT_FN(x) npy_float_to_half((float)(x!=0))
@@ -839,7 +851,29 @@ NPY_NO_EXPORT PyArrayMethod_StridedLoop *
839851
840852#endif
841853
842- static NPY_GCC_OPT_3 int
854+ // Enable auto-vectorization for floating point casts with clang
855+ #if @is_native_half1@ || @is_float1@ || @is_double1@
856+ #if @is_native_half2@ || @is_float2@ || @is_double2@
857+ #if defined(__clang__) && !defined(__EMSCRIPTEN__)
858+ #if __clang_major__ >= 12
859+ _Pragma("clang fp exceptions(ignore)")
860+ #endif /* __clang_major__ >= 12 */
861+ #endif /* __clang__ && !__EMSCRIPTEN__ */
862+ #endif /* type2 is native half, float or double (the float/double masks include the complex variants) */
863+ #endif /* type1 is native half, float or double (the float/double masks include the complex variants) */
864+
865+ // Work around GCC bug for double->half casts. For SVE and
866+ // OPT_LEVEL > 1, it implements this as double->single->half, which
867+ // is incorrect as it double-rounds. Clang also defines __GNUC__ but has
868+ // no `optimize` attribute, so it is excluded and keeps NPY_GCC_OPT_3.
869+ #if (@is_double1@ && @is_native_half2@) && \
870+ defined(NPY_HAVE_SVE) && defined(__GNUC__) && !defined(__clang__)
871+ #define GCC_CAST_OPT_LEVEL __attribute__((optimize("O1")))
872+ #else
873+ #define GCC_CAST_OPT_LEVEL NPY_GCC_OPT_3
874+ #endif
875+
876+ static GCC_CAST_OPT_LEVEL int
843877@prefix@_cast_@name1@_to_@name2@(
844878 PyArrayMethod_Context *context, char *const *args,
845879 const npy_intp *dimensions, const npy_intp *strides,
@@ -933,6 +967,17 @@ static NPY_GCC_OPT_3 int
933967 return 0;
934968}
935969
970+ #if @is_native_half1@ || @is_float1@ || @is_double1@
971+ #if @is_native_half2@ || @is_float2@ || @is_double2@
972+ #if defined(__clang__) && !defined(__EMSCRIPTEN__)
973+ #if __clang_major__ >= 12
974+ _Pragma("clang fp exceptions(strict)") /* NOTE(review): assumes strict was the mode in force before the cast function - confirm */
975+ #endif /* __clang_major__ >= 12 */
976+ #endif /* __clang__ && !__EMSCRIPTEN__ */
977+ #endif /* type2 is native half, float or double (the float/double masks include the complex variants) */
978+ #endif /* type1 is native half, float or double (the float/double masks include the complex variants) */
979+
980+ #undef GCC_CAST_OPT_LEVEL
936981#undef _CONVERT_FN
937982#undef _TYPE2
938983#undef _TYPE1
0 commit comments