Skip to content

Commit 0d24d71

Browse files
committed
Attempt to improve compilation times
1 parent de4a701 commit 0d24d71

File tree

22 files changed

+2995
-3164
lines changed

22 files changed

+2995
-3164
lines changed

include/kernel_float.h

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,7 @@
33

44
#include "kernel_float/bf16.h"
55
#include "kernel_float/binops.h"
6+
#include "kernel_float/cast.h"
67
#include "kernel_float/fp16.h"
78
#include "kernel_float/fp8.h"
89
#include "kernel_float/interface.h"

include/kernel_float/bf16.h

Lines changed: 68 additions & 88 deletions
Original file line number | Diff line number | Diff line change
@@ -6,94 +6,95 @@
66
#if KERNEL_FLOAT_BF16_AVAILABLE
77
#include <cuda_bf16.h>
88

9+
#include "binops.h"
10+
#include "cast.h"
911
#include "interface.h"
12+
#include "storage.h"
13+
#include "unops.h"
1014

1115
namespace kernel_float {
16+
KERNEL_FLOAT_DEFINE_COMMON_TYPE(__nv_bfloat16, bool)
1217
KERNEL_FLOAT_DEFINE_COMMON_TYPE(float, __nv_bfloat16)
1318
KERNEL_FLOAT_DEFINE_COMMON_TYPE(double, __nv_bfloat16)
1419

15-
struct vector_bfloat16x2 {
16-
static_assert(sizeof(__nv_bfloat16) * 2 == sizeof(__nv_bfloat162), "invalid size");
17-
static_assert(alignof(__nv_bfloat16) <= alignof(__nv_bfloat162), "invalid alignment");
18-
19-
KERNEL_FLOAT_INLINE vector_bfloat16x2(__nv_bfloat16 v = {}) noexcept : vector_ {v, v} {}
20-
KERNEL_FLOAT_INLINE vector_bfloat16x2(__nv_bfloat16 x, __nv_bfloat16 y) noexcept :
21-
vector_ {x, y} {}
22-
KERNEL_FLOAT_INLINE vector_bfloat16x2(__nv_bfloat162 xy) noexcept : vector_ {xy} {}
23-
24-
KERNEL_FLOAT_INLINE operator __nv_bfloat162() const noexcept {
25-
return vector_;
26-
}
27-
28-
KERNEL_FLOAT_INLINE __nv_bfloat16 get(const_index<0>) const {
29-
return vector_.x;
30-
}
31-
32-
KERNEL_FLOAT_INLINE __nv_bfloat16 get(const_index<1>) const {
33-
return vector_.y;
34-
}
20+
template<>
21+
struct vector_traits<__nv_bfloat162> {
22+
using value_type = __nv_bfloat16;
23+
static constexpr size_t size = 2;
3524

36-
KERNEL_FLOAT_INLINE void set(const_index<0>, __nv_bfloat16 v) {
37-
*this = vector_bfloat16x2(v, get(const_index<1> {}));
25+
KERNEL_FLOAT_INLINE
26+
static __nv_bfloat162 fill(__nv_bfloat16 value) {
27+
#if KERNEL_FLOAT_ON_DEVICE
28+
return __bfloat162bfloat162(value);
29+
#else
30+
return {value, value};
31+
#endif
3832
}
3933

40-
KERNEL_FLOAT_INLINE void set(const_index<1>, __nv_bfloat16 v) {
41-
*this = vector_bfloat16x2(get(const_index<0> {}), v);
34+
KERNEL_FLOAT_INLINE
35+
static __nv_bfloat162 create(__nv_bfloat16 low, __nv_bfloat16 high) {
36+
#if KERNEL_FLOAT_ON_DEVICE
37+
return __halves2bfloat162(low, high);
38+
#else
39+
return {low, high};
40+
#endif
4241
}
4342

44-
KERNEL_FLOAT_INLINE __nv_bfloat16 get(size_t index) const {
43+
KERNEL_FLOAT_INLINE
44+
static __nv_bfloat16 get(__nv_bfloat162 self, size_t index) {
45+
#if KERNEL_FLOAT_ON_DEVICE
4546
if (index == 0) {
46-
return get(const_index<0> {});
47+
return __low2bfloat16(self);
4748
} else {
48-
return get(const_index<1> {});
49+
return __high2bfloat16(self);
4950
}
51+
#else
52+
if (index == 0) {
53+
return self.x;
54+
} else {
55+
return self.y;
56+
}
57+
#endif
5058
}
5159

52-
KERNEL_FLOAT_INLINE void set(size_t index, __nv_bfloat16 value) const {
60+
KERNEL_FLOAT_INLINE
61+
static void set(__nv_bfloat162& self, size_t index, __nv_bfloat16 value) {
5362
if (index == 0) {
54-
set(const_index<0> {}, value);
63+
self.x = value;
5564
} else {
56-
set(const_index<1> {}, value);
65+
self.y = value;
5766
}
5867
}
59-
60-
private:
61-
__nv_bfloat162 vector_;
6268
};
6369

64-
template<>
65-
struct vector_traits<vector_bfloat16x2>:
66-
default_vector_traits<vector_bfloat16x2, __nv_bfloat16, 2> {};
67-
68-
template<>
69-
struct vector_traits<__nv_bfloat16>: vector_traits<vector_scalar<__nv_bfloat16>> {};
70-
71-
template<>
72-
struct vector_traits<__nv_bfloat162>: vector_traits<vector_bfloat16x2> {};
70+
template<size_t N>
71+
struct default_storage<__nv_bfloat16, N, Alignment::Maximum, enabled_t<(N >= 2)>> {
72+
using type = nested_array<__nv_bfloat162, N>;
73+
};
7374

74-
template<>
75-
struct default_vector_storage<__nv_bfloat16, 2> {
76-
using type = vector_bfloat16x2;
75+
template<size_t N>
76+
struct default_storage<__nv_bfloat16, N, Alignment::Packed, enabled_t<(N >= 2 && N % 2 == 0)>> {
77+
using type = nested_array<__nv_bfloat162, N>;
7778
};
7879

7980
#if KERNEL_FLOAT_ON_DEVICE
80-
#define KERNEL_FLOAT_BF16_UNARY_FUN(NAME, FUN1, FUN2) \
81-
namespace ops { \
82-
template<> \
83-
struct NAME<__nv_bfloat16> { \
84-
KERNEL_FLOAT_INLINE __nv_bfloat16 operator()(__nv_bfloat16 input) { \
85-
return FUN1(input); \
86-
} \
87-
}; \
88-
} \
89-
namespace detail { \
90-
template<> \
91-
struct map_helper<ops::NAME<__nv_bfloat16>, vector_bfloat16x2, vector_bfloat16x2> { \
92-
KERNEL_FLOAT_INLINE static __nv_bfloat162 \
93-
call(ops::NAME<__nv_bfloat16>, const __nv_bfloat162& input) { \
94-
return FUN2(input); \
95-
} \
96-
}; \
81+
#define KERNEL_FLOAT_BF16_UNARY_FUN(NAME, FUN1, FUN2) \
82+
namespace ops { \
83+
template<> \
84+
struct NAME<__nv_bfloat16> { \
85+
KERNEL_FLOAT_INLINE __nv_bfloat16 operator()(__nv_bfloat16 input) { \
86+
return FUN1(input); \
87+
} \
88+
}; \
89+
} \
90+
namespace detail { \
91+
template<> \
92+
struct map_helper<ops::NAME<__nv_bfloat16>, __nv_bfloat162, __nv_bfloat162> { \
93+
KERNEL_FLOAT_INLINE static __nv_bfloat162 \
94+
call(ops::NAME<__nv_bfloat16>, __nv_bfloat162 input) { \
95+
return FUN2(input); \
96+
} \
97+
}; \
9798
}
9899

99100
KERNEL_FLOAT_BF16_UNARY_FUN(abs, ::__habs, ::__habs2);
@@ -123,13 +124,9 @@ KERNEL_FLOAT_BF16_UNARY_FUN(trunc, ::htrunc, ::h2trunc);
123124
} \
124125
namespace detail { \
125126
template<> \
126-
struct zip_helper< \
127-
ops::NAME<__nv_bfloat16>, \
128-
vector_bfloat16x2, \
129-
vector_bfloat16x2, \
130-
vector_bfloat16x2> { \
127+
struct zip_helper<ops::NAME<__nv_bfloat16>, __nv_bfloat162, __nv_bfloat162, __nv_bfloat162> { \
131128
KERNEL_FLOAT_INLINE static __nv_bfloat162 \
132-
call(ops::NAME<__nv_bfloat16>, const __nv_bfloat162& left, const __nv_bfloat162& right) { \
129+
call(ops::NAME<__nv_bfloat16>, __nv_bfloat162 left, __nv_bfloat162 right) { \
133130
return FUN2(left, right); \
134131
} \
135132
}; \
@@ -197,27 +194,10 @@ KERNEL_FLOAT_BF16_CAST(
197194
(unsigned long)(__bfloat162ull_rz(input)));
198195
KERNEL_FLOAT_BF16_CAST(unsigned long long, __ull2bfloat16_rn(input), __bfloat162ull_rz(input));
199196

200-
namespace detail {
201-
template<>
202-
struct map_helper<ops::cast<__nv_bfloat16, float>, vector_storage<float, 2>, vector_bfloat16x2> {
203-
KERNEL_FLOAT_INLINE static vector_storage<float, 2>
204-
call(ops::cast<__nv_bfloat16, float>, __nv_bfloat162 input) noexcept {
205-
return __bfloat1622float2(input);
206-
}
207-
};
208-
209-
template<>
210-
struct map_helper<ops::cast<float, __nv_bfloat16>, vector_bfloat16x2, vector_storage<float, 2>> {
211-
KERNEL_FLOAT_INLINE static vector_bfloat16x2
212-
call(ops::cast<float, __nv_bfloat16>, const vector_storage<float, 2>& input) noexcept {
213-
return __float22bfloat162_rn(input);
214-
}
215-
};
216-
} // namespace detail
217-
218197
using bfloat16 = __nv_bfloat16;
219-
KERNEL_FLOAT_TYPE_ALIAS(bf16x, __nv_bfloat16)
220-
KERNEL_FLOAT_TYPE_ALIAS(bfloat16x, __nv_bfloat16)
198+
//KERNEL_FLOAT_TYPE_ALIAS(half, __nv_bfloat16)
199+
//KERNEL_FLOAT_TYPE_ALIAS(float16x, __nv_bfloat16)
200+
//KERNEL_FLOAT_TYPE_ALIAS(f16x, __nv_bfloat16)
221201

222202
} // namespace kernel_float
223203

include/kernel_float/binops.h

Lines changed: 63 additions & 52 deletions
Original file line number | Diff line number | Diff line change
@@ -15,28 +15,25 @@ struct zip_helper {
1515
template<size_t... Is>
1616
KERNEL_FLOAT_INLINE static Output
1717
call_with_indices(F fun, const Left& left, const Right& right, index_sequence<Is...> = {}) {
18-
return Output {fun(left.get(const_index<Is> {}), right.get(const_index<Is> {}))...};
18+
return vector_traits<Output>::create(fun(vector_get<Is>(left), vector_get<Is>(right))...);
1919
}
2020
};
2121

22-
template<typename F, typename T, typename L, typename R, size_t N>
23-
struct zip_helper<F, vector_compound<T, N>, vector_compound<L, N>, vector_compound<R, N>> {
24-
KERNEL_FLOAT_INLINE static vector_compound<T, N>
25-
call(F fun, const vector_compound<L, N>& left, const vector_compound<R, N>& right) {
26-
static constexpr size_t low_size = vector_compound<T, N>::low_size;
27-
static constexpr size_t high_size = vector_compound<T, N>::high_size;
28-
29-
return {
30-
zip_helper<
31-
F,
32-
vector_storage<T, low_size>,
33-
vector_storage<L, low_size>,
34-
vector_storage<R, low_size>>::call(fun, left.low(), right.low()),
35-
zip_helper<
36-
F,
37-
vector_storage<T, high_size>,
38-
vector_storage<L, high_size>,
39-
vector_storage<R, high_size>>::call(fun, left.high(), right.high())};
22+
template<typename F, typename V, size_t N>
23+
struct zip_helper<F, nested_array<V, N>, nested_array<V, N>, nested_array<V, N>> {
24+
KERNEL_FLOAT_INLINE static nested_array<V, N>
25+
call(F fun, const nested_array<V, N>& left, const nested_array<V, N>& right) {
26+
return call(fun, left, right, make_index_sequence<nested_array<V, N>::num_packets> {});
27+
}
28+
29+
private:
30+
template<size_t... Is>
31+
KERNEL_FLOAT_INLINE static nested_array<V, N> call(
32+
F fun,
33+
const nested_array<V, N>& left,
34+
const nested_array<V, N>& right,
35+
index_sequence<Is...>) {
36+
return {zip_helper<F, V, V, V>::call(fun, left[Is], right[Is])...};
4037
}
4138
};
4239
}; // namespace detail
@@ -48,7 +45,7 @@ template<typename... Ts>
4845
static constexpr size_t common_vector_size = common_size<vector_size<Ts>...>;
4946

5047
template<typename F, typename L, typename R>
51-
using zip_type = vector_storage<
48+
using zip_type = default_storage_type<
5249
result_t<F, vector_value_type<L>, vector_value_type<R>>,
5350
common_vector_size<L, R>>;
5451

@@ -63,16 +60,19 @@ using zip_type = vector_storage<
6360
* ``zip_common`` for that functionality.
6461
*/
6562
template<typename F, typename Left, typename Right, typename Output = zip_type<F, Left, Right>>
66-
KERNEL_FLOAT_INLINE Output zip(F fun, Left&& left, Right&& right) {
63+
KERNEL_FLOAT_INLINE vector<Output> zip(F fun, Left&& left, Right&& right) {
6764
static constexpr size_t N = vector_size<Output>;
68-
return detail::zip_helper<F, Output, into_vector_type<Left>, into_vector_type<Right>>::call(
65+
using LeftInput = default_storage_type<vector_value_type<Left>, N>;
66+
using RightInput = default_storage_type<vector_value_type<Right>, N>;
67+
68+
return detail::zip_helper<F, Output, LeftInput, RightInput>::call(
6969
fun,
70-
broadcast<N>(std::forward<Left>(left)),
71-
broadcast<N>(std::forward<Right>(right)));
70+
broadcast<LeftInput, Left>(std::forward<Left>(left)),
71+
broadcast<RightInput, Right>(std::forward<Right>(right)));
7272
}
7373

7474
template<typename F, typename L, typename R>
75-
using zip_common_type = vector_storage<
75+
using zip_common_type = default_storage_type<
7676
result_t<F, common_vector_value_type<L, R>, common_vector_value_type<L, R>>,
7777
common_vector_size<L, R>>;
7878

@@ -99,38 +99,50 @@ template<
9999
typename Left,
100100
typename Right,
101101
typename Output = zip_common_type<F, Left, Right>>
102-
KERNEL_FLOAT_INLINE Output zip_common(F fun, Left&& left, Right&& right) {
102+
KERNEL_FLOAT_INLINE vector<Output> zip_common(F fun, Left&& left, Right&& right) {
103103
static constexpr size_t N = vector_size<Output>;
104104
using C = common_t<vector_value_type<Left>, vector_value_type<Right>>;
105+
using Input = default_storage_type<C, N>;
105106

106-
return detail::zip_helper<F, Output, vector_storage<C, N>, vector_storage<C, N>>::call(
107+
return detail::zip_helper<F, Output, Input, Input>::call(
107108
fun,
108-
broadcast<C, N>(std::forward<Left>(left)),
109-
broadcast<C, N>(std::forward<Right>(right)));
109+
broadcast<Input, Left>(std::forward<Left>(left)),
110+
broadcast<Input, Right>(std::forward<Right>(right)));
110111
}
111112

112-
#define KERNEL_FLOAT_DEFINE_BINARY(NAME, EXPR) \
113-
namespace ops { \
114-
template<typename T> \
115-
struct NAME { \
116-
KERNEL_FLOAT_INLINE T operator()(T left, T right) { \
117-
return T(EXPR); \
118-
} \
119-
}; \
120-
} \
121-
template<typename L, typename R, typename C = common_vector_value_type<L, R>> \
122-
KERNEL_FLOAT_INLINE zip_common_type<ops::NAME<C>, L, R> NAME(L&& left, R&& right) { \
123-
return zip_common(ops::NAME<C> {}, std::forward<L>(left), std::forward<R>(right)); \
113+
#define KERNEL_FLOAT_DEFINE_BINARY(NAME, EXPR) \
114+
namespace ops { \
115+
template<typename T> \
116+
struct NAME { \
117+
KERNEL_FLOAT_INLINE T operator()(T left, T right) { \
118+
return T(EXPR); \
119+
} \
120+
}; \
121+
} \
122+
template<typename L, typename R, typename C = common_vector_value_type<L, R>> \
123+
KERNEL_FLOAT_INLINE vector<zip_common_type<ops::NAME<C>, L, R>> NAME(L&& left, R&& right) { \
124+
return zip_common(ops::NAME<C> {}, std::forward<L>(left), std::forward<R>(right)); \
124125
}
125126

126-
#define KERNEL_FLOAT_DEFINE_BINARY_OP(NAME, OP) \
127-
KERNEL_FLOAT_DEFINE_BINARY(NAME, left OP right) \
128-
template< \
129-
typename L, \
130-
typename R, \
131-
typename C = enabled_t<is_vector<L> || is_vector<R>, common_vector_value_type<L, R>>> \
132-
KERNEL_FLOAT_INLINE zip_common_type<ops::NAME<C>, L, R> operator OP(L&& left, R&& right) { \
133-
return zip_common(ops::NAME<C> {}, std::forward<L>(left), std::forward<R>(right)); \
127+
#define KERNEL_FLOAT_DEFINE_BINARY_OP(NAME, OP) \
128+
KERNEL_FLOAT_DEFINE_BINARY(NAME, left OP right) \
129+
template<typename L, typename R, typename C = common_vector_value_type<L, R>> \
130+
KERNEL_FLOAT_INLINE vector<zip_common_type<ops::NAME<C>, L, R>> operator OP( \
131+
const vector<L>& left, \
132+
const vector<R>& right) { \
133+
return zip_common(ops::NAME<C> {}, left, right); \
134+
} \
135+
template<typename L, typename R, typename C = common_vector_value_type<L, R>> \
136+
KERNEL_FLOAT_INLINE vector<zip_common_type<ops::NAME<C>, L, R>> operator OP( \
137+
const vector<L>& left, \
138+
const R& right) { \
139+
return zip_common(ops::NAME<C> {}, left, right); \
140+
} \
141+
template<typename L, typename R, typename C = common_vector_value_type<L, R>> \
142+
KERNEL_FLOAT_INLINE vector<zip_common_type<ops::NAME<C>, L, R>> operator OP( \
143+
const L& left, \
144+
const vector<R>& right) { \
145+
return zip_common(ops::NAME<C> {}, left, right); \
134146
}
135147

136148
KERNEL_FLOAT_DEFINE_BINARY_OP(add, +)
@@ -153,7 +165,6 @@ KERNEL_FLOAT_DEFINE_BINARY_OP(bit_xor, ^)
153165
// clang-format off
154166
template<template<typename T> typename F, typename L, typename R>
155167
static constexpr bool vector_assign_allowed =
156-
is_vector<L> &&
157168
common_vector_size<L, R> == vector_size<L> &&
158169
is_implicit_convertible<
159170
result_t<
@@ -170,9 +181,9 @@ static constexpr bool vector_assign_allowed =
170181
typename L, \
171182
typename R, \
172183
typename T = enabled_t<vector_assign_allowed<ops::NAME, L, R>, vector_value_type<L>>> \
173-
KERNEL_FLOAT_INLINE L& operator OP(L& lhs, R&& rhs) { \
184+
KERNEL_FLOAT_INLINE vector<L>& operator OP(vector<L>& lhs, const R& rhs) { \
174185
using F = ops::NAME<T>; \
175-
lhs = zip_common<F, L&, R, into_vector_type<L>>(F {}, lhs, std::forward<R>(rhs)); \
186+
lhs = zip_common<F, const L&, const R&, L>(F {}, lhs.storage(), rhs); \
176187
return lhs; \
177188
}
178189

0 commit comments

Comments
 (0)