Skip to content

Commit 90372b2

Browse files
committed
Small bug fixes
1 parent 64f2190 commit 90372b2

File tree

4 files changed

+34
-12
lines changed

4 files changed

+34
-12
lines changed

include/kernel_float/base.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,17 @@ struct into_vector_traits<aligned_array<T, N, A>> {
226226

227227
#define KERNEL_FLOAT_DEFINE_VECTOR_TYPE(T, T1, T2, T3, T4) \
228228
template<> \
229+
struct into_vector_traits<::T1> { \
230+
using value_type = T; \
231+
using extent_type = extent<1>; \
232+
\
233+
KERNEL_FLOAT_INLINE \
234+
static vector_storage<T, 1> call(::T1 v) { \
235+
return {v.x}; \
236+
} \
237+
}; \
238+
\
239+
template<> \
229240
struct into_vector_traits<::T2> { \
230241
using value_type = T; \
231242
using extent_type = extent<2>; \

include/kernel_float/bf16.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ KERNEL_FLOAT_BF16_UNARY_FORWARD(expm1)
144144
} \
145145
namespace detail { \
146146
template<> \
147-
struct map_halfx2<ops::NAME<__nv_bfloat16>> { \
147+
struct map_bfloat16x2<ops::NAME<__nv_bfloat16>> { \
148148
KERNEL_FLOAT_INLINE static __nv_bfloat162 \
149149
call(ops::NAME<__nv_bfloat16>, __nv_bfloat162 input) { \
150150
return FUN2(input); \
@@ -324,12 +324,12 @@ KERNEL_FLOAT_BF16_CAST(__half, __float2bfloat16(input), __bfloat162float(input))
324324
template<>
325325
struct promote_type<__nv_bfloat16, __half> {
326326
using type = float;
327-
}
327+
};
328328

329329
template<>
330330
struct promote_type<__half, __nv_bfloat16> {
331331
using type = float;
332-
}
332+
};
333333

334334
} // namespace kernel_float
335335

include/kernel_float/vector.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ struct vector: public S {
5454
typename... Rest,
5555
typename = enabled_t<sizeof...(Rest) + 2 == E::size>>
5656
KERNEL_FLOAT_INLINE vector(const A& a, const B& b, const Rest&... rest) :
57-
storage_type {a, b, rest...} {}
57+
storage_type {T(a), T(b), T(rest)...} {}
5858

5959
/**
6060
* Returns the number of elements in this vector.
@@ -316,7 +316,7 @@ template<typename T> using vec8 = vec<T, 8>;
316316
template<typename... Args>
317317
KERNEL_FLOAT_INLINE vec<promote_t<Args...>, sizeof...(Args)> make_vec(Args&&... args) {
318318
using T = promote_t<Args...>;
319-
return vector_storage<T, sizeof...(Args)> {T {args}...};
319+
return vector_storage<T, sizeof...(Args)> {T(args)...};
320320
};
321321

322322
#if defined(__cpp_deduction_guides)

single_include/kernel_float.h

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
//================================================================================
22
// this file has been auto-generated, do not modify its contents!
3-
// date: 2023-08-28 14:29:52.760763
4-
// git hash: 31ffbb7ca20f9c4a1c43b37e06c99600a8f15b91
3+
// date: 2023-09-18 17:41:12.641561
4+
// git hash: 64f21903e8049e4a46c53897a167f31174e1a231
55
//================================================================================
66

77
#ifndef KERNEL_FLOAT_MACROS_H
@@ -550,6 +550,17 @@ struct into_vector_traits<aligned_array<T, N, A>> {
550550

551551
#define KERNEL_FLOAT_DEFINE_VECTOR_TYPE(T, T1, T2, T3, T4) \
552552
template<> \
553+
struct into_vector_traits<::T1> { \
554+
using value_type = T; \
555+
using extent_type = extent<1>; \
556+
\
557+
KERNEL_FLOAT_INLINE \
558+
static vector_storage<T, 1> call(::T1 v) { \
559+
return {v.x}; \
560+
} \
561+
}; \
562+
\
563+
template<> \
553564
struct into_vector_traits<::T2> { \
554565
using value_type = T; \
555566
using extent_type = extent<2>; \
@@ -2759,7 +2770,7 @@ struct vector: public S {
27592770
typename... Rest,
27602771
typename = enabled_t<sizeof...(Rest) + 2 == E::size>>
27612772
KERNEL_FLOAT_INLINE vector(const A& a, const B& b, const Rest&... rest) :
2762-
storage_type {a, b, rest...} {}
2773+
storage_type {T(a), T(b), T(rest)...} {}
27632774

27642775
/**
27652776
* Returns the number of elements in this vector.
@@ -3021,7 +3032,7 @@ template<typename T> using vec8 = vec<T, 8>;
30213032
template<typename... Args>
30223033
KERNEL_FLOAT_INLINE vec<promote_t<Args...>, sizeof...(Args)> make_vec(Args&&... args) {
30233034
using T = promote_t<Args...>;
3024-
return vector_storage<T, sizeof...(Args)> {T {args}...};
3035+
return vector_storage<T, sizeof...(Args)> {T(args)...};
30253036
};
30263037

30273038
#if defined(__cpp_deduction_guides)
@@ -3484,7 +3495,7 @@ KERNEL_FLOAT_BF16_UNARY_FORWARD(expm1)
34843495
} \
34853496
namespace detail { \
34863497
template<> \
3487-
struct map_halfx2<ops::NAME<__nv_bfloat16>> { \
3498+
struct map_bfloat16x2<ops::NAME<__nv_bfloat16>> { \
34883499
KERNEL_FLOAT_INLINE static __nv_bfloat162 \
34893500
call(ops::NAME<__nv_bfloat16>, __nv_bfloat162 input) { \
34903501
return FUN2(input); \
@@ -3664,12 +3675,12 @@ KERNEL_FLOAT_BF16_CAST(__half, __float2bfloat16(input), __bfloat162float(input))
36643675
template<>
36653676
struct promote_type<__nv_bfloat16, __half> {
36663677
using type = float;
3667-
}
3678+
};
36683679

36693680
template<>
36703681
struct promote_type<__half, __nv_bfloat16> {
36713682
using type = float;
3672-
}
3683+
};
36733684

36743685
} // namespace kernel_float
36753686

0 commit comments

Comments
 (0)