
Commit 4278106

Generalize implementation of dot_impl to remove fp16 specialization
1 parent 3b349be commit 4278106
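
In short: the hand-written `dot_impl` specializations for `__half` (fp16.h) and `__nv_bfloat16` (bf16.h) are deleted, and the generic `dot_impl` in reduce.h is rewritten to work in chunks of `preferred_vector_size<T>::value` elements, accumulating full chunks with `ops::fma<T>` and folding in any leftover tail elements one at a time. A minimal sketch of calling the rewritten entry point directly, assuming the usual single-include header name and device-side compilation:

#include <cuda_fp16.h>
#include "kernel_float.h"  // assumed single-include header name

// Illustration only: detail::dot_impl<T, N>::call(left, right) is the interface
// shown in this diff. Before this commit, dot_impl<__half, 4> was a dedicated
// __hmul2/__hfma2 specialization in fp16.h; afterwards the generic version in
// reduce.h is expected to produce equivalent pairwise code.
__global__ void dot4_fp16(const __half* x, const __half* y, __half* out) {
    *out = kernel_float::detail::dot_impl<__half, 4>::call(x, y);
}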

File tree

5 files changed, +244 -362 lines changed

include/kernel_float/bf16.h

Lines changed: 0 additions & 49 deletions

@@ -226,55 +226,6 @@ using bfloat16 = __nv_bfloat16;
 //KERNEL_FLOAT_TYPE_ALIAS(float16x, __nv_bfloat16)
 //KERNEL_FLOAT_TYPE_ALIAS(f16x, __nv_bfloat16)
 
-#if KERNEL_FLOAT_CUDA_ARCH >= 800
-namespace detail {
-template<>
-struct dot_impl<__nv_bfloat16, 0> {
-    KERNEL_FLOAT_INLINE
-    static __nv_bfloat16 call(const __nv_bfloat16* left, const __nv_bfloat16* right) {
-        return __nv_bfloat16(0);
-    }
-};
-
-template<>
-struct dot_impl<__nv_bfloat16, 1> {
-    KERNEL_FLOAT_INLINE
-    static __nv_bfloat16 call(const __nv_bfloat16* left, const __nv_bfloat16* right) {
-        return __hmul(left[0], right[0]);
-    }
-};
-
-template<size_t N>
-struct dot_impl<__nv_bfloat16, N> {
-    static_assert(N >= 2, "internal error");
-
-    KERNEL_FLOAT_INLINE
-    static __nv_bfloat16 call(const __nv_bfloat16* left, const __nv_bfloat16* right) {
-        __nv_bfloat162 first_a = {left[0], left[1]};
-        __nv_bfloat162 first_b = {right[0], right[1]};
-        __nv_bfloat162 accum = __hmul2(first_a, first_b);
-
-#pragma unroll
-        for (size_t i = 2; i + 1 < N; i += 2) {
-            __nv_bfloat162 a = {left[i], left[i + 1]};
-            __nv_bfloat162 b = {right[i], right[i + 1]};
-            accum = __hfma2(a, b, accum);
-        }
-
-        __nv_bfloat16 result = __hadd(accum.x, accum.y);
-
-        if (N % 2 != 0) {
-            __nv_bfloat16 a = left[N - 1];
-            __nv_bfloat16 b = right[N - 1];
-            result = __hfma(a, b, result);
-        }
-
-        return result;
-    }
-};
-} // namespace detail
-#endif
-
 } // namespace kernel_float
 
 #if KERNEL_FLOAT_FP16_AVAILABLE

include/kernel_float/fp16.h

Lines changed: 0 additions & 49 deletions

@@ -174,55 +174,6 @@ using half = __half;
 //KERNEL_FLOAT_TYPE_ALIAS(float16x, __half)
 //KERNEL_FLOAT_TYPE_ALIAS(f16x, __half)
 
-#if KERNEL_FLOAT_IS_DEVICE
-namespace detail {
-template<>
-struct dot_impl<__half, 0> {
-    KERNEL_FLOAT_INLINE
-    static __half call(const __half* left, const __half* right) {
-        return __half(0);
-    }
-};
-
-template<>
-struct dot_impl<__half, 1> {
-    KERNEL_FLOAT_INLINE
-    static __half call(const __half* left, const __half* right) {
-        return __hmul(left[0], right[0]);
-    }
-};
-
-template<size_t N>
-struct dot_impl<__half, N> {
-    static_assert(N >= 2, "internal error");
-
-    KERNEL_FLOAT_INLINE
-    static __half call(const __half* left, const __half* right) {
-        __half2 first_a = {left[0], left[1]};
-        __half2 first_b = {right[0], right[1]};
-        __half2 accum = __hmul2(first_a, first_b);
-
-#pragma unroll
-        for (size_t i = 2; i + 2 <= N; i += 2) {
-            __half2 a = {left[i], left[i + 1]};
-            __half2 b = {right[i], right[i + 1]};
-            accum = __hfma2(a, b, accum);
-        }
-
-        __half result = __hadd(accum.x, accum.y);
-
-        if (N % 2 != 0) {
-            __half a = left[N - 1];
-            __half b = right[N - 1];
-            result = __hfma(a, b, result);
-        }
-
-        return result;
-    }
-};
-} // namespace detail
-#endif
-
 } // namespace kernel_float
 
 #endif

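Removing these specializations only preserves the old pairwise `__half2` strategy if the generic path is told to work two elements at a time. The mechanism is a trait read by the new `dot_impl` as `preferred_vector_size<T>::value`; the sketch below is a standalone illustration of that trait (the names mirror the library, but this is not the library's code, and the value 2 for `__half` is an assumption):

#include <cuda_fp16.h>
#include <cstddef>

// Illustrative only: a trait advertising how many lanes of T to process per step.
template<typename T>
struct preferred_vector_size {
    static constexpr std::size_t value = 1;  // scalar fallback
};

template<>
struct preferred_vector_size<__half> {
    static constexpr std::size_t value = 2;  // pairs map onto __hmul2 / __hfma2 elsewhere
};

static_assert(preferred_vector_size<__half>::value == 2, "fp16 is processed in pairs");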
include/kernel_float/reduce.h

Lines changed: 33 additions & 8 deletions

@@ -2,6 +2,7 @@
 #define KERNEL_FLOAT_REDUCE_H
 
 #include "binops.h"
+#include "triops.h"
 
 namespace kernel_float {
 namespace detail {
@@ -177,14 +178,38 @@ template<typename T, size_t N>
 struct dot_impl {
     KERNEL_FLOAT_INLINE
     static T call(const T* left, const T* right) {
-        vector_storage<T, N> intermediate;
-        detail::map_impl<ops::multiply<T>, N, T, T, T>::call(
-            ops::multiply<T>(),
-            intermediate.data(),
-            left,
-            right);
-
-        return detail::reduce_impl<ops::add<T>, N, T>::call(ops::add<T>(), intermediate.data());
+        static constexpr size_t K = preferred_vector_size<T>::value;
+        T result = {};
+
+        if constexpr (N / K > 0) {
+            T accum[K] = {T {}};
+            apply_impl<ops::multiply<T>, K, T, T, T>::call({}, accum, left, right);
+
+#pragma unroll
+            for (size_t i = 1; i < N / K; i++) {
+                apply_impl<ops::fma<T>, K, T, T, T, T>::call(
+                    ops::fma<T> {},
+                    accum,
+                    left + i * K,
+                    right + i * K,
+                    accum);
+            }
+
+            result = reduce_impl<ops::add<T>, K, T>::call({}, accum);
+        }
+
+        if constexpr (N % K > 0) {
+            for (size_t i = N - N % K; i < N; i++) {
+                apply_impl<ops::fma<T>, 1, T, T, T, T>::call(
+                    {},
+                    &result,
+                    left + i,
+                    right + i,
+                    &result);
+            }
+        }
+
+        return result;
     }
 };
 } // namespace detail

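The new generic `dot_impl` splits the N inputs into chunks of K = `preferred_vector_size<T>::value`: the first full chunk is a plain elementwise multiply into a K-wide accumulator, each further full chunk is folded in with `ops::fma<T>`, the K lanes are then reduced with `ops::add<T>`, and the N % K tail elements are folded in one by one. The following standalone sketch (plain floats, explicit loops instead of `apply_impl`/`reduce_impl`, a fixed K so it runs on the host) mirrors that control flow:

#include <cstddef>

// Sketch of the chunk-plus-tail strategy used by the rewritten dot_impl.
// K stands in for preferred_vector_size<T>::value.
template<typename T, std::size_t N, std::size_t K = 4>
T dot_sketch(const T* left, const T* right) {
    T result = {};

    if constexpr (N / K > 0) {
        // K-wide accumulator: lane j holds the partial sum of elements j, j+K, j+2K, ...
        T accum[K] = {};
        for (std::size_t j = 0; j < K; ++j) {
            accum[j] = left[j] * right[j];              // first chunk: plain multiply
        }
        for (std::size_t i = 1; i < N / K; ++i) {
            for (std::size_t j = 0; j < K; ++j) {
                accum[j] += left[i * K + j] * right[i * K + j];  // later chunks: fma
            }
        }
        for (std::size_t j = 0; j < K; ++j) {
            result += accum[j];                         // horizontal reduction of the K lanes
        }
    }

    if constexpr (N % K > 0) {
        for (std::size_t i = N - N % K; i < N; ++i) {
            result += left[i] * right[i];               // scalar tail
        }
    }

    return result;
}

For example, with N = 7 and K = 4 this computes one 4-wide chunk plus a 3-element tail, which is exactly the shape of work the library version delegates to `apply_impl` and `reduce_impl`.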
include/kernel_float/triops.h

Lines changed: 16 additions & 2 deletions

@@ -92,11 +92,25 @@ namespace ops {
 template<typename T>
 struct fma {
     KERNEL_FLOAT_INLINE T operator()(T a, T b, T c) {
-        return a * b + c;
+        return ops::add<T> {}(ops::multiply<T> {}(a, b), c);
    }
 };
+} // namespace ops
+
+namespace detail {
+template<typename T, size_t N>
+struct apply_impl<ops::fma<T>, N, T, T, T, T> {
+    KERNEL_FLOAT_INLINE
+    static void call(ops::fma<T>, T* output, const T* a, const T* b, const T* c) {
+        T temp[N];
+        apply_impl<ops::multiply<T>, N, T, T, T>::call({}, temp, a, b);
+        apply_impl<ops::add<T>, N, T, T, T>::call({}, output, temp, c);
+    }
+};
+} // namespace detail
 
 #if KERNEL_FLOAT_IS_DEVICE
+namespace ops {
 template<>
 struct fma<float> {
     KERNEL_FLOAT_INLINE float operator()(float a, float b, float c) {
@@ -110,8 +124,8 @@ struct fma<double> {
         return __fma_rn(a, b, c);
     }
 };
-#endif
 } // namespace ops
+#endif
 
 /**
  * Computes the result of `a * b + c`. This is done in a single operation if possible for the given vector type.

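The triops.h change gives `ops::fma<T>` an N-wide fallback `apply_impl` that decomposes into an elementwise multiply into a temporary followed by an elementwise add, and routes the scalar `fma<T>` through `ops::multiply<T>` and `ops::add<T>` (presumably so element types whose arithmetic is only defined through those functors keep working). Backends that have a true vectorized fma can, it appears, still supply a more specific specialization that takes precedence. A standalone sketch of the two-stage fallback, with plain functors standing in for the library's `apply_impl` dispatch:

#include <cstddef>

// Sketch of the generic fma fallback: multiply into a temporary, then add.
// Mul and Add stand in for ops::multiply<T> and ops::add<T>.
template<typename T, std::size_t N, typename Mul, typename Add>
void fma_fallback(T* output, const T* a, const T* b, const T* c, Mul mul, Add add) {
    T temp[N];
    for (std::size_t i = 0; i < N; ++i) {
        temp[i] = mul(a[i], b[i]);       // stage 1: elementwise multiply
    }
    for (std::size_t i = 0; i < N; ++i) {
        output[i] = add(temp[i], c[i]);  // stage 2: elementwise add
    }
}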