Skip to content

Commit cc846b6

Browse files
committed
Update single include
1 parent 1212e8f commit cc846b6

File tree

1 file changed

+37
-3
lines changed

1 file changed

+37
-3
lines changed

single_include/kernel_float.h

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
//================================================================================
22
// this file has been auto-generated, do not modify its contents!
3-
// date: 2023-08-24 20:18:55.064697
4-
// git hash: df42b93bfd36d8d9f1a397218cd91ebe1c13325f
3+
// date: 2023-08-28 14:29:52.760763
4+
// git hash: 31ffbb7ca20f9c4a1c43b37e06c99600a8f15b91
55
//================================================================================
66

77
#ifndef KERNEL_FLOAT_MACROS_H
@@ -1667,6 +1667,9 @@ namespace kernel_float {
16671667

16681668
template<typename T = double>
16691669
struct constant {
1670+
template<typename R>
1671+
KERNEL_FLOAT_INLINE explicit constexpr constant(const constant<R>& that) : value_(that.get()) {}
1672+
16701673
KERNEL_FLOAT_INLINE
16711674
constexpr constant(T value = {}) : value_(value) {}
16721675

@@ -1684,14 +1687,20 @@ struct constant {
16841687
T value_;
16851688
};
16861689

1690+
// Deduction guide for `constant<T>`
1691+
#if defined(__cpp_deduction_guides)
1692+
template<typename T>
1693+
constant(T&&) -> constant<decay_t<T>>;
1694+
#endif
1695+
16871696
template<typename T = double>
16881697
KERNEL_FLOAT_INLINE constexpr constant<T> make_constant(T value) {
16891698
return value;
16901699
}
16911700

16921701
template<typename L, typename R>
16931702
struct promote_type<constant<L>, constant<R>> {
1694-
using type = typename promote_type<L, R>::type;
1703+
using type = constant<typename promote_type<L, R>::type>;
16951704
};
16961705

16971706
template<typename L, typename R>
@@ -3651,8 +3660,19 @@ struct dot_helper<__nv_bfloat16, N> {
36513660

36523661
namespace kernel_float {
36533662
KERNEL_FLOAT_BF16_CAST(__half, __float2bfloat16(input), __bfloat162float(input));
3663+
3664+
template<>
3665+
struct promote_type<__nv_bfloat16, __half> {
3666+
using type = float;
3667+
}
3668+
3669+
template<>
3670+
struct promote_type<__half, __nv_bfloat16> {
3671+
using type = float;
36543672
}
36553673

3674+
} // namespace kernel_float
3675+
36563676
#endif // KERNEL_FLOAT_FP16_AVAILABLE
36573677
#endif
36583678

@@ -3663,6 +3683,8 @@ KERNEL_FLOAT_BF16_CAST(__half, __float2bfloat16(input), __bfloat162float(input))
36633683

36643684

36653685

3686+
3687+
36663688
namespace kernel_float {
36673689
namespace prelude {
36683690
namespace kf = ::kernel_float;
@@ -3753,6 +3775,18 @@ static constexpr kconstant<long long int> operator""_c(unsigned long long int v)
37533775
return static_cast<long long int>(v);
37543776
}
37553777

3778+
// Deduction guides for aliases are only supported from C++20
3779+
#if defined(__cpp_deduction_guides) && __cpp_deduction_guides >= 201907L
3780+
template<typename T>
3781+
kscalar(T&&) -> kscalar<decay_t<T>>;
3782+
3783+
template<typename... Args>
3784+
kvec(Args&&...) -> kvec<promote_t<Args...>, sizeof...(Args)>;
3785+
3786+
template<typename T>
3787+
kconstant(T&&) -> kconstant<decay_t<T>>;
3788+
#endif
3789+
37563790
} // namespace prelude
37573791
} // namespace kernel_float
37583792

0 commit comments

Comments
 (0)