Skip to content

Commit ebd0967

Browse files
committed
Rename several helper structs from X_helper to X_impl
1 parent 90372b2 commit ebd0967

File tree

9 files changed

+165
-129
lines changed

9 files changed

+165
-129
lines changed

include/kernel_float/bf16.h

Lines changed: 39 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ struct apply_impl<F, N, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16> {
9494
};
9595

9696
template<typename F, size_t N>
97-
struct reduce_helper<F, N, __nv_bfloat16, enabled_t<(N >= 2)>> {
97+
struct reduce_impl<F, N, __nv_bfloat16, enable_if_t<(N >= 2)>> {
9898
KERNEL_FLOAT_INLINE static __nv_bfloat16
9999
call(F fun, const vector_storage<__nv_bfloat16, N>& input) {
100100
__nv_bfloat162 accum = {input.data()[0], input.data()[1]};
@@ -276,38 +276,54 @@ using bfloat16 = __nv_bfloat16;
276276

277277
#if KERNEL_FLOAT_IS_DEVICE
278278
namespace detail {
279+
template<>
280+
struct dot_impl<__nv_bfloat16, 0> {
281+
KERNEL_FLOAT_INLINE
282+
static __nv_bfloat16 call(
283+
const vector_storage<__nv_bfloat16, 0>& left,
284+
const vector_storage<__nv_bfloat16, 0>& right) {
285+
return __nv_bfloat16(0);
286+
}
287+
};
288+
289+
template<>
290+
struct dot_impl<__nv_bfloat16, 1> {
291+
KERNEL_FLOAT_INLINE
292+
static __nv_bfloat16 call(
293+
const vector_storage<__nv_bfloat16, 1>& left,
294+
const vector_storage<__nv_bfloat16, 1>& right) {
295+
return __hmul(left.data()[0], right.data()[0]);
296+
}
297+
};
298+
279299
template<size_t N>
280-
struct dot_helper<__nv_bfloat16, N> {
300+
struct dot_impl<__nv_bfloat16, N> {
301+
static_assert(N >= 2, "internal error");
302+
281303
KERNEL_FLOAT_INLINE
282304
static __nv_bfloat16 call(
283305
const vector_storage<__nv_bfloat16, N>& left,
284306
const vector_storage<__nv_bfloat16, N>& right) {
285-
if (N == 0) {
286-
return __nv_bfloat16(0);
287-
} else if (N == 1) {
288-
return __hmul(left.data()[0], right.data()[0]);
289-
} else {
290-
__nv_bfloat162 first_a = {left.data()[0], left.data()[1]};
291-
__nv_bfloat162 first_b = {right.data()[0], right.data()[1]};
292-
__nv_bfloat162 accum = __hmul2(first_a, first_b);
307+
__nv_bfloat162 first_a = {left.data()[0], left.data()[1]};
308+
__nv_bfloat162 first_b = {right.data()[0], right.data()[1]};
309+
__nv_bfloat162 accum = __hmul2(first_a, first_b);
293310

294311
#pragma unroll
295-
for (size_t i = 2; i + 2 <= N; i += 2) {
296-
__nv_bfloat162 a = {left.data()[i], left.data()[i + 1]};
297-
__nv_bfloat162 b = {right.data()[i], right.data()[i + 1]};
298-
accum = __hfma2(a, b, accum);
299-
}
300-
301-
__nv_bfloat16 result = __hadd(accum.x, accum.y);
312+
for (size_t i = 2; i + 2 <= N; i += 2) {
313+
__nv_bfloat162 a = {left.data()[i], left.data()[i + 1]};
314+
__nv_bfloat162 b = {right.data()[i], right.data()[i + 1]};
315+
accum = __hfma2(a, b, accum);
316+
}
302317

303-
if (N % 2 != 0) {
304-
__nv_bfloat16 a = left.data()[N - 1];
305-
__nv_bfloat16 b = right.data()[N - 1];
306-
result = __hfma(a, b, result);
307-
}
318+
__nv_bfloat16 result = __hadd(accum.x, accum.y);
308319

309-
return result;
320+
if (N % 2 != 0) {
321+
__nv_bfloat16 a = left.data()[N - 1];
322+
__nv_bfloat16 b = right.data()[N - 1];
323+
result = __hfma(a, b, result);
310324
}
325+
326+
return result;
311327
}
312328
};
313329
} // namespace detail

include/kernel_float/binops.h

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,9 @@ using zip_common_type = vector<
4949
* Example
5050
* =======
5151
* ```
52-
* vec<int, 3> a = {1.0f, 2.0f, 3.0f};
52+
* vec<float, 3> a = {1.0f, 2.0f, 3.0f};
5353
* vec<int, 3> b = {4, 5, 6};
54-
* vec<int, 3> c = zip_common([](float x, float y){ return x + y; }, a, b); // returns [5.0f, 7.0f, 9.0f]
54+
* vec<float, 3> c = zip_common([](float x, float y){ return x + y; }, a, b); // returns [5.0f, 7.0f, 9.0f]
5555
* ```
5656
*/
5757
template<typename F, typename L, typename R>
@@ -62,9 +62,9 @@ KERNEL_FLOAT_INLINE zip_common_type<F, L, R> zip_common(F fun, const L& left, co
6262

6363
return detail::apply_impl<F, E::value, O, T, T>::call(
6464
fun,
65-
detail::convert_helper<vector_value_type<L>, vector_extent_type<L>, T, E>::call(
65+
detail::convert_impl<vector_value_type<L>, vector_extent_type<L>, T, E>::call(
6666
into_vector_storage(left)),
67-
detail::convert_helper<vector_value_type<R>, vector_extent_type<R>, T, E>::call(
67+
detail::convert_impl<vector_value_type<R>, vector_extent_type<R>, T, E>::call(
6868
into_vector_storage(right)));
6969
}
7070

@@ -139,7 +139,7 @@ static constexpr bool is_vector_assign_allowed =
139139
typename T, \
140140
typename E, \
141141
typename R, \
142-
typename = enabled_t<is_vector_assign_allowed<ops::NAME, T, E, R>>> \
142+
typename = enable_if_t<is_vector_assign_allowed<ops::NAME, T, E, R>>> \
143143
KERNEL_FLOAT_INLINE vector<T, E>& operator OP(vector<T, E>& lhs, const R& rhs) { \
144144
using F = ops::NAME<T>; \
145145
lhs = zip_common(F {}, lhs, rhs); \
@@ -249,7 +249,7 @@ struct bit_xor<double> {
249249

250250
namespace detail {
251251
template<typename T>
252-
struct cross_helper {
252+
struct cross_impl {
253253
KERNEL_FLOAT_INLINE
254254
static vector<T, extent<3>>
255255
call(const vector_storage<T, 3>& av, const vector_storage<T, 3>& bv) {
@@ -275,9 +275,9 @@ template<
275275
typename R,
276276
typename T = promoted_vector_value_type<L, R>,
277277
typename =
278-
enabled_t<is_vector_broadcastable<L, extent<3>> && is_vector_broadcastable<R, extent<3>>>>
278+
enable_if_t<is_vector_broadcastable<L, extent<3>> && is_vector_broadcastable<R, extent<3>>>>
279279
KERNEL_FLOAT_INLINE vector<T, extent<3>> cross(const L& left, const R& right) {
280-
return detail::cross_helper<T>::call(convert_storage<T, 3>(left), convert_storage<T, 3>(right));
280+
return detail::cross_impl<T>::call(convert_storage<T, 3>(left), convert_storage<T, 3>(right));
281281
}
282282

283283
} // namespace kernel_float

include/kernel_float/conversion.h

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ template<typename... Vs>
9999
using broadcast_vector_extent_type = broadcast_extent<vector_extent_type<Vs>...>;
100100

101101
template<typename From, typename To>
102-
static constexpr bool is_broadcastable = is_same<broadcast_extent<From, To>, To>;
102+
static constexpr bool is_broadcastable = is_same_type<broadcast_extent<From, To>, To>;
103103

104104
template<typename V, typename To>
105105
static constexpr bool is_vector_broadcastable = is_broadcastable<vector_extent_type<V>, To>;
@@ -169,8 +169,12 @@ broadcast_like(const V& input, const R& other) {
169169
}
170170

171171
namespace detail {
172+
/**
173+
* Convert vector of element type `T` and extent type `E` to vector of element type `T2` and extent type `E2`.
174+
* Specialization exist for the cases where `T==T2` and/or `E==E2`.
175+
*/
172176
template<typename T, typename E, typename T2, typename E2, RoundingMode M = RoundingMode::ANY>
173-
struct convert_helper {
177+
struct convert_impl {
174178
KERNEL_FLOAT_INLINE
175179
static vector_storage<T2, E2::value> call(vector_storage<T, E::value> input) {
176180
using F = ops::cast<T, T2, M>;
@@ -180,24 +184,27 @@ struct convert_helper {
180184
}
181185
};
182186

187+
// T == T2, E == E2
183188
template<typename T, typename E, RoundingMode M>
184-
struct convert_helper<T, E, T, E, M> {
189+
struct convert_impl<T, E, T, E, M> {
185190
KERNEL_FLOAT_INLINE
186191
static vector_storage<T, E::value> call(vector_storage<T, E::value> input) {
187192
return input;
188193
}
189194
};
190195

196+
// T == T2, E != E2
191197
template<typename T, typename E, typename E2, RoundingMode M>
192-
struct convert_helper<T, E, T, E2, M> {
198+
struct convert_impl<T, E, T, E2, M> {
193199
KERNEL_FLOAT_INLINE
194200
static vector_storage<T, E2::value> call(vector_storage<T, E::value> input) {
195201
return detail::broadcast_impl<T, E, E2>::call(input);
196202
}
197203
};
198204

205+
// T != T2, E == E2
199206
template<typename T, typename E, typename T2, RoundingMode M>
200-
struct convert_helper<T, E, T2, E, M> {
207+
struct convert_impl<T, E, T2, E, M> {
201208
KERNEL_FLOAT_INLINE
202209
static vector_storage<T2, E::value> call(vector_storage<T, E::value> input) {
203210
using F = ops::cast<T, T2, M>;
@@ -208,8 +215,8 @@ struct convert_helper<T, E, T2, E, M> {
208215

209216
template<typename R, size_t N, RoundingMode M = RoundingMode::ANY, typename V>
210217
KERNEL_FLOAT_INLINE vector_storage<R, N> convert_storage(const V& input, extent<N> new_size = {}) {
211-
return detail::convert_helper<vector_value_type<V>, vector_extent_type<V>, R, extent<N>, M>::
212-
call(into_vector_storage(input));
218+
return detail::convert_impl<vector_value_type<V>, vector_extent_type<V>, R, extent<N>, M>::call(
219+
into_vector_storage(input));
213220
}
214221

215222
/**

include/kernel_float/fp16.h

Lines changed: 37 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ struct apply_impl<F, N, __half, __half, __half> {
9090
};
9191

9292
template<typename F, size_t N>
93-
struct reduce_helper<F, N, __half, enabled_t<(N >= 2)>> {
93+
struct reduce_impl<F, N, __half, enable_if_t<(N >= 2)>> {
9494
KERNEL_FLOAT_INLINE static __half call(F fun, const vector_storage<__half, N>& input) {
9595
__half2 accum = {input.data()[0], input.data()[1]};
9696

@@ -256,37 +256,51 @@ using half = __half;
256256

257257
#if KERNEL_FLOAT_IS_DEVICE
258258
namespace detail {
259+
template<>
260+
struct dot_impl<__half, 0> {
261+
KERNEL_FLOAT_INLINE
262+
static __half
263+
call(const vector_storage<__half, 0>& left, const vector_storage<__half, 0>& right) {
264+
return __half(0);
265+
}
266+
};
267+
268+
template<>
269+
struct dot_impl<__half, 1> {
270+
KERNEL_FLOAT_INLINE
271+
static __half
272+
call(const vector_storage<__half, 1>& left, const vector_storage<__half, 1>& right) {
273+
return __hmul(left.data()[0], right.data()[0]);
274+
}
275+
};
276+
259277
template<size_t N>
260-
struct dot_helper<__half, N> {
278+
struct dot_impl<__half, N> {
279+
static_assert(N >= 2, "internal error");
280+
261281
KERNEL_FLOAT_INLINE
262282
static __half
263283
call(const vector_storage<__half, N>& left, const vector_storage<__half, N>& right) {
264-
if (N == 0) {
265-
return __half(0);
266-
} else if (N == 1) {
267-
return __hmul(left.data()[0], right.data()[0]);
268-
} else {
269-
__half2 first_a = {left.data()[0], left.data()[1]};
270-
__half2 first_b = {right.data()[0], right.data()[1]};
271-
__half2 accum = __hmul2(first_a, first_b);
284+
__half2 first_a = {left.data()[0], left.data()[1]};
285+
__half2 first_b = {right.data()[0], right.data()[1]};
286+
__half2 accum = __hmul2(first_a, first_b);
272287

273288
#pragma unroll
274-
for (size_t i = 2; i + 2 <= N; i += 2) {
275-
__half2 a = {left.data()[i], left.data()[i + 1]};
276-
__half2 b = {right.data()[i], right.data()[i + 1]};
277-
accum = __hfma2(a, b, accum);
278-
}
279-
280-
__half result = __hadd(accum.x, accum.y);
289+
for (size_t i = 2; i + 2 <= N; i += 2) {
290+
__half2 a = {left.data()[i], left.data()[i + 1]};
291+
__half2 b = {right.data()[i], right.data()[i + 1]};
292+
accum = __hfma2(a, b, accum);
293+
}
281294

282-
if (N % 2 != 0) {
283-
__half a = left.data()[N - 1];
284-
__half b = right.data()[N - 1];
285-
result = __hfma(a, b, result);
286-
}
295+
__half result = __hadd(accum.x, accum.y);
287296

288-
return result;
297+
if (N % 2 != 0) {
298+
__half a = left.data()[N - 1];
299+
__half b = right.data()[N - 1];
300+
result = __hfma(a, b, result);
289301
}
302+
303+
return result;
290304
}
291305
};
292306
} // namespace detail

0 commit comments

Comments
 (0)