Skip to content

Commit da0a46b

Browse files
committed
Add tests for reductions
1 parent 7acff4c commit da0a46b

File tree

6 files changed

+589
-258
lines changed

6 files changed

+589
-258
lines changed

include/kernel_float/binops.h

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -172,15 +172,46 @@ KERNEL_FLOAT_DEFINE_BINARY_ASSIGN_OP(bit_xor, ^=)
172172
KERNEL_FLOAT_DEFINE_BINARY_FUN(min)
173173
KERNEL_FLOAT_DEFINE_BINARY_FUN(max)
174174
KERNEL_FLOAT_DEFINE_BINARY_FUN(copysign)
175-
KERNEL_FLOAT_DEFINE_BINARY_FUN(hypot)
176175
KERNEL_FLOAT_DEFINE_BINARY_FUN(modf)
177176
KERNEL_FLOAT_DEFINE_BINARY_FUN(nextafter)
178177
KERNEL_FLOAT_DEFINE_BINARY_FUN(pow)
179178
KERNEL_FLOAT_DEFINE_BINARY_FUN(remainder)
180179

181-
#if KERNEL_FLOAT_CUDA_DEVICE
182-
KERNEL_FLOAT_DEFINE_BINARY_FUN(rhypot)
180+
KERNEL_FLOAT_DEFINE_BINARY(hypot, (ops::sqrt<T>()(left * left + right * right)))
181+
KERNEL_FLOAT_DEFINE_BINARY(rhypot, (T(1) / ops::hypot<T>()(left, right)))
182+
183+
namespace ops {
184+
template<>
185+
struct hypot<double> {
186+
KERNEL_FLOAT_INLINE double operator()(double left, double right) {
187+
return ::hypot(left, right);
188+
};
189+
};
190+
191+
template<>
192+
struct hypot<float> {
193+
KERNEL_FLOAT_INLINE float operator()(float left, float right) {
194+
return ::hypotf(left, right);
195+
};
196+
};
197+
198+
// rhypot is only support on the GPU
199+
#if KERNEL_FLOAT_IS_DEVICE
200+
template<>
201+
struct rhypot<double> {
202+
KERNEL_FLOAT_INLINE double operator()(double left, double right) {
203+
return ::rhypot(left, right);
204+
};
205+
};
206+
207+
template<>
208+
struct rhypot<float> {
209+
KERNEL_FLOAT_INLINE float operator()(float left, float right) {
210+
return ::rhypotf(left, right);
211+
};
212+
};
183213
#endif
214+
}; // namespace ops
184215

185216
#if KERNEL_FLOAT_IS_DEVICE
186217
#define KERNEL_FLOAT_DEFINE_BINARY_FAST(FUN_NAME, OP_NAME, FLOAT_FUN) \

include/kernel_float/fp16.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,6 @@ KERNEL_FLOAT_FP16_BINARY_FUN(multiply, __hmul, __hmul2)
199199
KERNEL_FLOAT_FP16_BINARY_FUN(divide, __hdiv, __h2div)
200200
KERNEL_FLOAT_FP16_BINARY_FUN(min, __hmin, __hmin2)
201201
KERNEL_FLOAT_FP16_BINARY_FUN(max, __hmax, __hmax2)
202-
203202
KERNEL_FLOAT_FP16_BINARY_FUN(fast_div, __hdiv, __h2div)
204203

205204
KERNEL_FLOAT_FP16_BINARY_FUN(equal_to, __heq, __heq2)

include/kernel_float/reduce.h

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,14 @@ template<typename T, size_t N>
144144
struct dot_impl {
145145
KERNEL_FLOAT_INLINE
146146
static T call(const vector_storage<T, N>& left, const vector_storage<T, N>& right) {
147-
return sum(zip(ops::multiply<T> {}, left, right));
147+
vector_storage<T, N> intermediate;
148+
detail::apply_impl<ops::multiply<T>, N, T, T, T>::call(
149+
ops::multiply<T>(),
150+
intermediate.data(),
151+
left.data(),
152+
right.data());
153+
154+
return detail::reduce_impl<ops::add<T>, N, T>::call(ops::add<T>(), intermediate.data());
148155
}
149156
};
150157
} // namespace detail
@@ -197,25 +204,25 @@ template<typename T>
197204
struct magnitude_impl<T, 2> {
198205
KERNEL_FLOAT_INLINE
199206
static T call(const vector_storage<T, 2>& input) {
200-
return ops::hypot<T> {}(input.data()[0], input.data()[1]);
207+
return ops::hypot<T>()(input.data()[0], input.data()[1]);
201208
}
202209
};
203210

204-
// The 3-argument overload of hypot is only available from C++17
205-
#ifdef __cpp_lib_hypot
211+
// The 3-argument overload of hypot is only available on host from C++17
212+
#if defined(__cpp_lib_hypot) && KERNEL_FLOAT_IS_HOST
206213
template<>
207214
struct magnitude_impl<float, 3> {
208215
KERNEL_FLOAT_INLINE
209216
static float call(const vector_storage<float, 3>& input) {
210-
return std::hypot(input.data()[0], input.data()[1], input.data()[2]);
217+
return ::hypot(input.data()[0], input.data()[1], input.data()[2]);
211218
}
212219
};
213220

214221
template<>
215222
struct magnitude_impl<double, 3> {
216223
KERNEL_FLOAT_INLINE
217224
static float call(const vector_storage<double, 3>& input) {
218-
return std::hypot(input.data()[0], input.data()[1], input.data()[2]);
225+
return ::hypot(input.data()[0], input.data()[1], input.data()[2]);
219226
}
220227
};
221228
#endif

0 commit comments

Comments
 (0)