Skip to content

Commit 7acff4c

Browse files
committed
Use raw pointers in apply_impl and reduce_impl
1 parent ebd0967 commit 7acff4c

File tree

8 files changed

+108
-107
lines changed

8 files changed

+108
-107
lines changed

include/kernel_float/bf16.h

Lines changed: 26 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -49,66 +49,55 @@ struct zip_bfloat16x2 {
4949

5050
template<typename F, size_t N>
5151
struct apply_impl<F, N, __nv_bfloat16, __nv_bfloat16> {
52-
KERNEL_FLOAT_INLINE static vector_storage<__nv_bfloat16, N>
53-
call(F fun, const vector_storage<__nv_bfloat16, N>& input) {
54-
vector_storage<__nv_bfloat16, N> result;
55-
52+
KERNEL_FLOAT_INLINE static void call(F fun, __nv_bfloat16* result, const __nv_bfloat16* input) {
5653
#pragma unroll
57-
for (size_t i = 0; i + 2 <= N; i += 2) {
58-
__nv_bfloat162 a = {input.data()[i], input.data()[i + 1]};
54+
for (size_t i = 0; 2 * i + 1 < N; i++) {
55+
__nv_bfloat162 a = {input[2 * i], input[2 * i + 1]};
5956
__nv_bfloat162 b = map_bfloat16x2<F>::call(fun, a);
60-
result.data()[i + 0] = b.x;
61-
result.data()[i + 1] = b.y;
57+
result[2 * i + 0] = b.x;
58+
result[2 * i + 1] = b.y;
6259
}
6360

6461
if (N % 2 != 0) {
65-
result.data()[N - 1] = fun(input.data()[N - 1]);
62+
result[N - 1] = fun(input[N - 1]);
6663
}
67-
68-
return result;
6964
}
7065
};
7166

7267
template<typename F, size_t N>
7368
struct apply_impl<F, N, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16> {
74-
KERNEL_FLOAT_INLINE static vector_storage<__nv_bfloat16, N> call(
75-
F fun,
76-
const vector_storage<__nv_bfloat16, N>& left,
77-
const vector_storage<__nv_bfloat16, N>& right) {
78-
vector_storage<__nv_bfloat16, N> result;
69+
KERNEL_FLOAT_INLINE static void
70+
call(F fun, __nv_bfloat16* result, const __nv_bfloat16* left, const __nv_bfloat16* right) {
7971
#pragma unroll
80-
for (size_t i = 0; i + 2 <= N; i += 2) {
81-
__nv_bfloat162 a = {left.data()[i], left.data()[i + 1]};
82-
__nv_bfloat162 b = {right.data()[i], right.data()[i + 1]};
72+
for (size_t i = 0; 2 * i + 1 < N; i++) {
73+
__nv_bfloat162 a = {left[2 * i], left[2 * i + 1]};
74+
__nv_bfloat162 b = {right[2 * i], right[2 * i + 1]};
8375
__nv_bfloat162 c = zip_bfloat16x2<F>::call(fun, a, b);
84-
result.data()[i + 0] = c.x;
85-
result.data()[i + 1] = c.y;
76+
result[2 * i + 0] = c.x;
77+
result[2 * i + 1] = c.y;
8678
}
8779

8880
if (N % 2 != 0) {
89-
result.data()[N - 1] = fun(left.data()[N - 1], right.data()[N - 1]);
81+
result[N - 1] = fun(left[N - 1], right[N - 1]);
9082
}
91-
92-
return result;
9383
}
9484
};
9585

9686
template<typename F, size_t N>
9787
struct reduce_impl<F, N, __nv_bfloat16, enable_if_t<(N >= 2)>> {
98-
KERNEL_FLOAT_INLINE static __nv_bfloat16
99-
call(F fun, const vector_storage<__nv_bfloat16, N>& input) {
100-
__nv_bfloat162 accum = {input.data()[0], input.data()[1]};
88+
KERNEL_FLOAT_INLINE static __nv_bfloat16 call(F fun, const __nv_bfloat16* input) {
89+
__nv_bfloat162 accum = {input[0], input[1]};
10190

10291
#pragma unroll
103-
for (size_t i = 2; i + 2 <= N; i += 2) {
104-
__nv_bfloat162 a = {input.data()[i], input.data()[i + 1]};
92+
for (size_t i = 0; 2 * i + 1 < N; i++) {
93+
__nv_bfloat162 a = {input[2 * i], input[2 * i + 1]};
10594
accum = zip_bfloat16x2<F>::call(fun, accum, a);
10695
}
10796

10897
__nv_bfloat16 result = fun(accum.x, accum.y);
10998

11099
if (N % 2 != 0) {
111-
result = fun(result, input.data()[N - 1]);
100+
result = fun(result, input[N - 1]);
112101
}
113102

114103
return result;
@@ -126,6 +115,7 @@ struct reduce_impl<F, N, __nv_bfloat16, enable_if_t<(N >= 2)>> {
126115
}; \
127116
}
128117

118+
// There operations are not implemented in half precision, so they are forward to single precision
129119
KERNEL_FLOAT_BF16_UNARY_FORWARD(tan)
130120
KERNEL_FLOAT_BF16_UNARY_FORWARD(asin)
131121
KERNEL_FLOAT_BF16_UNARY_FORWARD(acos)
@@ -243,32 +233,22 @@ KERNEL_FLOAT_BF16_BINARY_FUN(greater_equal, __hge, __hgt2)
243233
KERNEL_FLOAT_BF16_CAST(double, __double2bfloat16(input), double(__bfloat162float(input)));
244234
KERNEL_FLOAT_BF16_CAST(float, __float2bfloat16(input), __bfloat162float(input));
245235

236+
// clang-format off
246237
// there are no official char casts. Instead, cast to int and then to char
247238
KERNEL_FLOAT_BF16_CAST(char, __int2bfloat16_rn(input), (char)__bfloat162int_rz(input));
248-
KERNEL_FLOAT_BF16_CAST(
249-
signed char,
250-
__int2bfloat16_rn(input),
251-
(signed char)__bfloat162int_rz(input));
252-
KERNEL_FLOAT_BF16_CAST(
253-
unsigned char,
254-
__int2bfloat16_rn(input),
255-
(unsigned char)__bfloat162int_rz(input));
239+
KERNEL_FLOAT_BF16_CAST(signed char, __int2bfloat16_rn(input), (signed char)__bfloat162int_rz(input));
240+
KERNEL_FLOAT_BF16_CAST(unsigned char, __int2bfloat16_rn(input), (unsigned char)__bfloat162int_rz(input));
256241

257242
KERNEL_FLOAT_BF16_CAST(signed short, __bfloat162short_rz(input), __short2bfloat16_rn(input));
258243
KERNEL_FLOAT_BF16_CAST(signed int, __bfloat162int_rz(input), __int2bfloat16_rn(input));
259-
KERNEL_FLOAT_BF16_CAST(
260-
signed long,
261-
__ll2bfloat16_rn(input),
262-
(signed long)(__bfloat162ll_rz(input)));
244+
KERNEL_FLOAT_BF16_CAST(signed long, __ll2bfloat16_rn(input), (signed long)(__bfloat162ll_rz(input)));
263245
KERNEL_FLOAT_BF16_CAST(signed long long, __ll2bfloat16_rn(input), __bfloat162ll_rz(input));
264246

265247
KERNEL_FLOAT_BF16_CAST(unsigned short, __bfloat162ushort_rz(input), __ushort2bfloat16_rn(input));
266248
KERNEL_FLOAT_BF16_CAST(unsigned int, __bfloat162uint_rz(input), __uint2bfloat16_rn(input));
267-
KERNEL_FLOAT_BF16_CAST(
268-
unsigned long,
269-
__ull2bfloat16_rn(input),
270-
(unsigned long)(__bfloat162ull_rz(input)));
249+
KERNEL_FLOAT_BF16_CAST(unsigned long, __ull2bfloat16_rn(input), (unsigned long)(__bfloat162ull_rz(input)));
271250
KERNEL_FLOAT_BF16_CAST(unsigned long long, __ull2bfloat16_rn(input), __bfloat162ull_rz(input));
251+
// clang-format on
272252

273253
using bfloat16 = __nv_bfloat16;
274254
//KERNEL_FLOAT_TYPE_ALIAS(float16x, __nv_bfloat16)

include/kernel_float/binops.h

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,16 @@ KERNEL_FLOAT_INLINE zip_type<F, L, R> zip(F fun, const L& left, const R& right)
2929
using B = vector_value_type<R>;
3030
using O = result_t<F, A, B>;
3131
using E = broadcast_vector_extent_type<L, R>;
32+
vector_storage<O, E::value> result;
3233

33-
return detail::apply_impl<F, E::value, O, A, B>::call(
34+
detail::apply_impl<F, E::value, O, A, B>::call(
3435
fun,
35-
detail::broadcast_impl<A, vector_extent_type<L>, E>::call(into_vector_storage(left)),
36-
detail::broadcast_impl<B, vector_extent_type<R>, E>::call(into_vector_storage(right)));
36+
result.data(),
37+
detail::broadcast_impl<A, vector_extent_type<L>, E>::call(into_vector_storage(left)).data(),
38+
detail::broadcast_impl<B, vector_extent_type<R>, E>::call(into_vector_storage(right))
39+
.data());
40+
41+
return result;
3742
}
3843

3944
template<typename F, typename L, typename R>
@@ -60,12 +65,19 @@ KERNEL_FLOAT_INLINE zip_common_type<F, L, R> zip_common(F fun, const L& left, co
6065
using O = result_t<F, T, T>;
6166
using E = broadcast_vector_extent_type<L, R>;
6267

63-
return detail::apply_impl<F, E::value, O, T, T>::call(
68+
vector_storage<O, E::value> result;
69+
70+
detail::apply_impl<F, E::value, O, T, T>::call(
6471
fun,
72+
result.data(),
6573
detail::convert_impl<vector_value_type<L>, vector_extent_type<L>, T, E>::call(
66-
into_vector_storage(left)),
74+
into_vector_storage(left))
75+
.data(),
6776
detail::convert_impl<vector_value_type<R>, vector_extent_type<R>, T, E>::call(
68-
into_vector_storage(right)));
77+
into_vector_storage(right))
78+
.data());
79+
80+
return result;
6981
}
7082

7183
#define KERNEL_FLOAT_DEFINE_BINARY(NAME, EXPR) \

include/kernel_float/conversion.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -178,8 +178,8 @@ struct convert_impl {
178178
KERNEL_FLOAT_INLINE
179179
static vector_storage<T2, E2::value> call(vector_storage<T, E::value> input) {
180180
using F = ops::cast<T, T2, M>;
181-
vector_storage<T2, E::value> intermediate =
182-
detail::apply_impl<F, E::value, T2, T>::call(F {}, input);
181+
vector_storage<T2, E::value> intermediate;
182+
detail::apply_impl<F, E::value, T2, T>::call(F {}, intermediate.data(), input.data());
183183
return detail::broadcast_impl<T2, E, E2>::call(intermediate);
184184
}
185185
};
@@ -208,7 +208,10 @@ struct convert_impl<T, E, T2, E, M> {
208208
KERNEL_FLOAT_INLINE
209209
static vector_storage<T2, E::value> call(vector_storage<T, E::value> input) {
210210
using F = ops::cast<T, T2, M>;
211-
return detail::apply_impl<F, E::value, T2, T>::call(F {}, input);
211+
212+
vector_storage<T2, E::value> result;
213+
detail::apply_impl<F, E::value, T2, T>::call(F {}, result.data(), input.data());
214+
return result;
212215
}
213216
};
214217
} // namespace detail

include/kernel_float/fp16.h

Lines changed: 20 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -47,63 +47,55 @@ struct zip_halfx2 {
4747

4848
template<typename F, size_t N>
4949
struct apply_impl<F, N, __half, __half> {
50-
KERNEL_FLOAT_INLINE static vector_storage<__half, N>
51-
call(F fun, const vector_storage<__half, N>& input) {
52-
vector_storage<__half, N> result;
53-
50+
KERNEL_FLOAT_INLINE static void call(F fun, __half* result, const __half* input) {
5451
#pragma unroll
55-
for (size_t i = 0; i + 2 <= N; i += 2) {
56-
__half2 a = {input.data()[i], input.data()[i + 1]};
52+
for (size_t i = 0; 2 * i + 1 < N; i++) {
53+
__half2 a = {input[2 * i], input[2 * i + 1]};
5754
__half2 b = map_halfx2<F>::call(fun, a);
58-
result.data()[i + 0] = b.x;
59-
result.data()[i + 1] = b.y;
55+
result[2 * i + 0] = b.x;
56+
result[2 * i + 1] = b.y;
6057
}
6158

6259
if (N % 2 != 0) {
63-
result.data()[N - 1] = fun(input.data()[N - 1]);
60+
result[N - 1] = fun(input[N - 1]);
6461
}
65-
66-
return result;
6762
}
6863
};
6964

7065
template<typename F, size_t N>
7166
struct apply_impl<F, N, __half, __half, __half> {
72-
KERNEL_FLOAT_INLINE static vector_storage<__half, N>
73-
call(F fun, const vector_storage<__half, N>& left, const vector_storage<__half, N>& right) {
74-
vector_storage<__half, N> result;
67+
KERNEL_FLOAT_INLINE static void
68+
call(F fun, __half* result, const __half* left, const __half* right) {
7569
#pragma unroll
76-
for (size_t i = 0; i + 2 <= N; i += 2) {
77-
__half2 a = {left.data()[i], left.data()[i + 1]};
78-
__half2 b = {right.data()[i], right.data()[i + 1]};
70+
for (size_t i = 0; 2 * i + 1 < N; i++) {
71+
__half2 a = {left[2 * i], left[2 * i + 1]};
72+
__half2 b = {right[2 * i], right[2 * i + 1]};
7973
__half2 c = zip_halfx2<F>::call(fun, a, b);
80-
result.data()[i + 0] = c.x;
81-
result.data()[i + 1] = c.y;
74+
result[2 * i + 0] = c.x;
75+
result[2 * i + 1] = c.y;
8276
}
8377

8478
if (N % 2 != 0) {
85-
result.data()[N - 1] = fun(left.data()[N - 1], right.data()[N - 1]);
79+
result[N - 1] = fun(left[N - 1], right[N - 1]);
8680
}
87-
88-
return result;
8981
}
9082
};
9183

9284
template<typename F, size_t N>
9385
struct reduce_impl<F, N, __half, enable_if_t<(N >= 2)>> {
94-
KERNEL_FLOAT_INLINE static __half call(F fun, const vector_storage<__half, N>& input) {
95-
__half2 accum = {input.data()[0], input.data()[1]};
86+
KERNEL_FLOAT_INLINE static __half call(F fun, const __half* input) {
87+
__half2 accum = {input[0], input[1]};
9688

9789
#pragma unroll
98-
for (size_t i = 2; i + 2 <= N; i += 2) {
99-
__half2 a = {input.data()[i], input.data()[i + 1]};
90+
for (size_t i = 0; 2 * i + 1 < N; i++) {
91+
__half2 a = {input[2 * i], input[2 * i + 1]};
10092
accum = zip_halfx2<F>::call(fun, accum, a);
10193
}
10294

10395
__half result = fun(accum.x, accum.y);
10496

10597
if (N % 2 != 0) {
106-
result = fun(result, input.data()[N - 1]);
98+
result = fun(result, input[N - 1]);
10799
}
108100

109101
return result;
@@ -122,6 +114,7 @@ struct reduce_impl<F, N, __half, enable_if_t<(N >= 2)>> {
122114
}; \
123115
}
124116

117+
// There operations are not implemented in half precision, so they are forward to single precision
125118
KERNEL_FLOAT_FP16_UNARY_FORWARD(tan)
126119
KERNEL_FLOAT_FP16_UNARY_FORWARD(asin)
127120
KERNEL_FLOAT_FP16_UNARY_FORWARD(acos)

include/kernel_float/reduce.h

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,17 @@ namespace kernel_float {
77
namespace detail {
88
template<typename F, size_t N, typename T, typename = void>
99
struct reduce_impl {
10-
KERNEL_FLOAT_INLINE static T call(F fun, const vector_storage<T, N>& input) {
10+
KERNEL_FLOAT_INLINE static T call(F fun, const T* input) {
1111
return call(fun, input, make_index_sequence<N> {});
1212
}
1313

1414
private:
1515
template<size_t... Is>
16-
KERNEL_FLOAT_INLINE static T
17-
call(F fun, const vector_storage<T, N>& input, index_sequence<0, Is...>) {
18-
T result = input.data()[0];
16+
KERNEL_FLOAT_INLINE static T call(F fun, const T* input, index_sequence<0, Is...>) {
17+
T result = input[0];
1918
#pragma unroll
2019
for (size_t i = 1; i < N; i++) {
21-
result = fun(result, input.data()[i]);
20+
result = fun(result, input[i]);
2221
}
2322
return result;
2423
}
@@ -43,7 +42,7 @@ template<typename F, typename V>
4342
KERNEL_FLOAT_INLINE vector_value_type<V> reduce(F fun, const V& input) {
4443
return detail::reduce_impl<F, vector_extent<V>, vector_value_type<V>>::call(
4544
fun,
46-
into_vector_storage(input));
45+
into_vector_storage(input).data());
4746
}
4847

4948
/**

include/kernel_float/triops.h

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -39,15 +39,22 @@ template<
3939
typename E = broadcast_vector_extent_type<C, L, R>>
4040
KERNEL_FLOAT_INLINE vector<T, E> where(const C& cond, const L& true_values, const R& false_values) {
4141
using F = ops::conditional<T>;
42+
vector_storage<T, E::value> result;
4243

43-
return detail::apply_impl<F, E::value, T, bool, T, T>::call(
44+
detail::apply_impl<F, E::value, T, bool, T, T>::call(
4445
F {},
46+
result.data(),
4547
detail::convert_impl<vector_value_type<C>, vector_extent_type<C>, bool, E>::call(
46-
into_vector_storage(cond)),
48+
into_vector_storage(cond))
49+
.data(),
4750
detail::convert_impl<vector_value_type<L>, vector_extent_type<L>, T, E>::call(
48-
into_vector_storage(true_values)),
51+
into_vector_storage(true_values))
52+
.data(),
4953
detail::convert_impl<vector_value_type<R>, vector_extent_type<R>, T, E>::call(
50-
into_vector_storage(false_values)));
54+
into_vector_storage(false_values))
55+
.data());
56+
57+
return result;
5158
}
5259

5360
/**
@@ -117,15 +124,22 @@ template<
117124
typename E = broadcast_vector_extent_type<A, B, C>>
118125
KERNEL_FLOAT_INLINE vector<T, E> fma(const A& a, const B& b, const C& c) {
119126
using F = ops::fma<T>;
127+
vector_storage<T, E::value> result;
120128

121-
return detail::apply_impl<F, E::value, T, T, T, T>::call(
129+
detail::apply_impl<F, E::value, T, T, T, T>::call(
122130
F {},
131+
result.data(),
123132
detail::convert_impl<vector_value_type<A>, vector_extent_type<A>, T, E>::call(
124-
into_vector_storage(a)),
133+
into_vector_storage(a))
134+
.data(),
125135
detail::convert_impl<vector_value_type<B>, vector_extent_type<B>, T, E>::call(
126-
into_vector_storage(b)),
136+
into_vector_storage(b))
137+
.data(),
127138
detail::convert_impl<vector_value_type<C>, vector_extent_type<C>, T, E>::call(
128-
into_vector_storage(c)));
139+
into_vector_storage(c))
140+
.data());
141+
142+
return result;
129143
}
130144

131145
} // namespace kernel_float

0 commit comments

Comments
 (0)