@@ -86,20 +86,36 @@ KERNEL_FLOAT_INLINE zip_common_type<F, L, R> zip_common(F fun, const L& left, co
8686 return zip_common (ops::NAME<C> {}, std::forward<L>(left), std::forward<R>(right)); \
8787 }
8888
/**
 * Defines the binary functor `ops::NAME` and then expands
 * `KERNEL_FLOAT_DEFINE_BINARY_FUN(NAME)` for it.
 *
 * Three variants of the functor are generated. Every expression argument may
 * refer to the two operands as `left` and `right`:
 *
 *  - primary template: evaluates `EXPR` on operands of type `T`;
 *  - `double` specialization: evaluates `EXPR_F64` on `double` operands;
 *  - specialization for types with `detail::allow_float_fallback<T>::value`:
 *    casts both operands to `float`, evaluates `EXPR_F32`, and casts the
 *    result back to `T`.
 *
 * In every variant the value of the expression is converted to the return
 * type through `ops::cast<decltype(expression), Result>`.
 *
 * NOTE: there must be no whitespace between the macro name and its parameter
 * list, otherwise the preprocessor treats this as an object-like macro.
 */
#define KERNEL_FLOAT_DEFINE_BINARY(NAME, EXPR, EXPR_F64, EXPR_F32)              \
    namespace ops {                                                             \
    template<typename T, typename = void>                                       \
    struct NAME {                                                               \
        KERNEL_FLOAT_INLINE T operator()(T left, T right) {                     \
            return ops::cast<decltype(EXPR), T> {}(EXPR);                       \
        }                                                                       \
    };                                                                          \
                                                                                \
    template<>                                                                  \
    struct NAME<double> {                                                       \
        KERNEL_FLOAT_INLINE double operator()(double left, double right) {      \
            return ops::cast<decltype(EXPR_F64), double> {}(EXPR_F64);          \
        }                                                                       \
    };                                                                          \
                                                                                \
    template<typename T>                                                        \
    struct NAME<T, enable_if_t<detail::allow_float_fallback<T>::value>> {       \
        KERNEL_FLOAT_INLINE T operator()(T left_, T right_) {                   \
            /* EXPR_F32 refers to `left`/`right`, so shadow them as floats. */  \
            float left = ops::cast<T, float> {}(left_);                         \
            float right = ops::cast<T, float> {}(right_);                       \
            return ops::cast<decltype(EXPR_F32), T> {}(EXPR_F32);               \
        }                                                                       \
    };                                                                          \
    }                                                                           \
                                                                                \
    KERNEL_FLOAT_DEFINE_BINARY_FUN(NAME)
100116
101- #define KERNEL_FLOAT_DEFINE_BINARY_OP (NAME, OP ) \
102- KERNEL_FLOAT_DEFINE_BINARY (NAME, left OP right) \
117+ #define KERNEL_FLOAT_DEFINE_BINARY_OP_FALLBACK (NAME, OP, EXPR_F64, EXPR_F32 ) \
118+ KERNEL_FLOAT_DEFINE_BINARY (NAME, left OP right, EXPR_F64, EXPR_F32) \
103119 \
104120 template <typename L, typename R, typename C = promote_t <L, R>, typename E1 , typename E2 > \
105121 KERNEL_FLOAT_INLINE zip_common_type<ops::NAME<C>, vector<L, E1 >, vector<R, E2 >> operator OP ( \
@@ -120,11 +136,14 @@ KERNEL_FLOAT_INLINE zip_common_type<F, L, R> zip_common(F fun, const L& left, co
120136 return zip_common (ops::NAME<C> {}, left, right); \
121137 }
122138
/**
 * Shorthand for `KERNEL_FLOAT_DEFINE_BINARY_OP_FALLBACK` for operators where
 * the plain expression `left OP right` is also the right thing to evaluate for
 * the `double` and the float-fallback variants.
 */
#define KERNEL_FLOAT_DEFINE_BINARY_OP(NAME, OP) \
    KERNEL_FLOAT_DEFINE_BINARY_OP_FALLBACK(NAME, OP, left OP right, left OP right)
141+
123142KERNEL_FLOAT_DEFINE_BINARY_OP (add, +)
124143KERNEL_FLOAT_DEFINE_BINARY_OP (subtract, -)
125144KERNEL_FLOAT_DEFINE_BINARY_OP (divide, /)
126145KERNEL_FLOAT_DEFINE_BINARY_OP (multiply, *)
127- KERNEL_FLOAT_DEFINE_BINARY_OP (modulo, %)
146+ KERNEL_FLOAT_DEFINE_BINARY_OP_FALLBACK (modulo, %, ::fmod(left, right), ::fmodf(left, right) )
128147
129148KERNEL_FLOAT_DEFINE_BINARY_OP (equal_to, ==)
130149KERNEL_FLOAT_DEFINE_BINARY_OP(not_equal_to, !=)
@@ -133,9 +152,11 @@ KERNEL_FLOAT_DEFINE_BINARY_OP(less_equal, <=)
133152KERNEL_FLOAT_DEFINE_BINARY_OP(greater, >)
134153KERNEL_FLOAT_DEFINE_BINARY_OP(greater_equal, >=)
135154
136- KERNEL_FLOAT_DEFINE_BINARY_OP (bit_and, &)
137- KERNEL_FLOAT_DEFINE_BINARY_OP (bit_or, |)
138- KERNEL_FLOAT_DEFINE_BINARY_OP (bit_xor, ^)
155+ // clang-format off
156+ KERNEL_FLOAT_DEFINE_BINARY_OP_FALLBACK(bit_and, &, bool (left) && bool(right), bool(left) && bool(right))
157+ KERNEL_FLOAT_DEFINE_BINARY_OP_FALLBACK(bit_or, |, bool (left) | bool(right), bool(left) | bool(right))
158+ KERNEL_FLOAT_DEFINE_BINARY_OP_FALLBACK(bit_xor, ^, bool (left) ^ bool(right), bool(left) ^ bool(right))
159+ // clang-format on
139160
140161// clang-format off
141162template<template<typename> typename F, typename T, typename E, typename R>
@@ -247,56 +268,40 @@ KERNEL_FLOAT_DEFINE_BINARY_MATH(nextafter)
247268KERNEL_FLOAT_DEFINE_BINARY_MATH(pow)
248269KERNEL_FLOAT_DEFINE_BINARY_MATH(remainder)
249270
250- KERNEL_FLOAT_DEFINE_BINARY (hypot, (ops::sqrt<T>()(left * left + right * right)))
251- KERNEL_FLOAT_DEFINE_BINARY (rhypot, (T(1 ) / ops::hypot<T>()(left, right)))
252-
253- namespace ops {
254- template <>
255- struct hypot <double > {
256- KERNEL_FLOAT_INLINE double operator ()(double left, double right) {
257- return ::hypot (left, right);
258- };
259- };
260-
261- template <>
262- struct hypot <float > {
263- KERNEL_FLOAT_INLINE float operator ()(float left, float right) {
264- return ::hypotf (left, right);
265- };
266- };
271+ KERNEL_FLOAT_DEFINE_BINARY(
272+ hypot,
273+ ops::sqrt<T>()(left* left + right * right),
274+ ::hypot(left, right),
275+ ::hypotf(left, right))
267276
268- // rhypot is only support on the GPU
269277#if KERNEL_FLOAT_IS_DEVICE
270- template <>
271- struct rhypot <double > {
272- KERNEL_FLOAT_INLINE double operator ()(double left, double right) {
273- return ::rhypot (left, right);
274- };
275- };
276-
277- template <>
278- struct rhypot <float > {
279- KERNEL_FLOAT_INLINE float operator ()(float left, float right) {
280- return ::rhypotf (left, right);
281- };
282- };
278+ KERNEL_FLOAT_DEFINE_BINARY (
279+ rhypot,
280+ (T(1 ) / ops::hypot<T>()(left, right)),
281+ ::rhypot(left, right),
282+ ::rhypotf(left, right))
283+ #else
284+ KERNEL_FLOAT_DEFINE_BINARY (
285+ rhypot,
286+ (T(1 ) / ops::hypot<T>()(left, right)),
287+ (double (1 ) / ::hypot(left, right)),
288+ (float (1 ) / ::hypotf(left, right)))
283289#endif
284- }; // namespace ops
285290
/**
 * Defines `ops::FUN_NAME`, a "fast" counterpart of the existing functor
 * `ops::OP_NAME`.
 *
 * In device code the single-precision paths call `FLOAT_FUN(left, right)`:
 * the float-fallback variant (operands already converted to `float`) and an
 * explicit `float` specialization. All other types forward to `ops::OP_NAME`.
 * Outside of device code `FLOAT_FUN` is not used and every variant simply
 * forwards to `ops::OP_NAME`.
 *
 * NOTE(review): previously both branches were identical and `FLOAT_FUN` was
 * never referenced, so the fast functions never called the intrinsic. Confirm
 * that routing `float` through `FLOAT_FUN` is the intended behavior.
 */
#if KERNEL_FLOAT_IS_DEVICE
#define KERNEL_FLOAT_DEFINE_BINARY_FAST(FUN_NAME, OP_NAME, FLOAT_FUN)           \
    KERNEL_FLOAT_DEFINE_BINARY(                                                 \
        FUN_NAME,                                                               \
        ops::OP_NAME<T> {}(left, right),                                        \
        ops::OP_NAME<double> {}(left, right),                                   \
        FLOAT_FUN(left, right))                                                 \
    namespace ops {                                                             \
    template<>                                                                  \
    struct FUN_NAME<float> {                                                    \
        KERNEL_FLOAT_INLINE float operator()(float left, float right) {         \
            return FLOAT_FUN(left, right);                                      \
        }                                                                       \
    };                                                                          \
    }
#else
#define KERNEL_FLOAT_DEFINE_BINARY_FAST(FUN_NAME, OP_NAME, FLOAT_FUN)           \
    KERNEL_FLOAT_DEFINE_BINARY(                                                 \
        FUN_NAME,                                                               \
        ops::OP_NAME<T> {}(left, right),                                        \
        ops::OP_NAME<double> {}(left, right),                                   \
        ops::OP_NAME<float> {}(left, right))
#endif
301306
302307KERNEL_FLOAT_DEFINE_BINARY_FAST (fast_div, divide, __fdividef)
@@ -316,48 +321,6 @@ struct multiply<bool> {
316321 return left && right;
317322 }
318323};
319-
320- template <>
321- struct bit_and <float > {
322- KERNEL_FLOAT_INLINE float operator ()(float left, float right) {
323- return float (bool (left) && bool (right));
324- }
325- };
326-
327- template <>
328- struct bit_or <float > {
329- KERNEL_FLOAT_INLINE float operator ()(float left, float right) {
330- return float (bool (left) || bool (right));
331- }
332- };
333-
334- template <>
335- struct bit_xor <float > {
336- KERNEL_FLOAT_INLINE float operator ()(float left, float right) {
337- return float (bool (left) ^ bool (right));
338- }
339- };
340-
341- template <>
342- struct bit_and <double > {
343- KERNEL_FLOAT_INLINE double operator ()(double left, double right) {
344- return double (bool (left) && bool (right));
345- }
346- };
347-
348- template <>
349- struct bit_or <double > {
350- KERNEL_FLOAT_INLINE double operator ()(double left, double right) {
351- return double (bool (left) || bool (right));
352- }
353- };
354-
355- template <>
356- struct bit_xor <double > {
357- KERNEL_FLOAT_INLINE double operator ()(double left, double right) {
358- return double (bool (left) ^ bool (right));
359- }
360- };
361324}; // namespace ops
362325
363326namespace detail {
0 commit comments