Skip to content

Commit 8f31b9f

Browse files
committed
Fix bug in select function causing invalid vector size
1 parent 06e08f5 commit 8f31b9f

File tree

3 files changed

+51
-3
lines changed

3 files changed

+51
-3
lines changed

include/kernel_float/iterate.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ struct concat_impl<V, Vs...> {
225225
template<>
226226
struct concat_impl<> {
227227
using value_type = void;
228-
static constexpr size_t size = 1;
228+
static constexpr size_t size = 0;
229229

230230
template<typename U>
231231
KERNEL_FLOAT_INLINE static void call(U* output) {}

include/kernel_float/vector.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ namespace kernel_float {
2323
*/
2424
template<typename T, typename E, class S>
2525
struct vector: public S {
26+
using self_type = vector<T, E, S>;
2627
using value_type = T;
2728
using extent_type = E;
2829
using storage_type = S;
@@ -221,8 +222,8 @@ struct vector: public S {
221222
* vec<float, 4> vec2 = select(input, indices); // [0, 40, 40, 20]
222223
* ```
223224
*/
224-
template<typename V, typename... Is>
225-
KERNEL_FLOAT_INLINE select_type<V, Is...> select(const Is&... indices) {
225+
template<typename... Is>
226+
KERNEL_FLOAT_INLINE select_type<self_type, Is...> select(const Is&... indices) {
226227
return kernel_float::select(*this, indices...);
227228
}
228229

tests/iterate.cu

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
#include "common.h"
2+
3+
struct select_tests {
4+
template<typename T>
5+
__host__ __device__ void operator()(generator<T> gen) {
6+
T data[8] = {
7+
gen.next(),
8+
gen.next(),
9+
gen.next(),
10+
gen.next(),
11+
gen.next(),
12+
gen.next(),
13+
gen.next(),
14+
gen.next()};
15+
kf::vec<T, 8> x = {data[0], data[1], data[2], data[3], data[4], data[5], data[6], data[7]};
16+
17+
// Empty
18+
ASSERT_EQ(select(x), (kf::vec<int, 0>()));
19+
ASSERT_EQ(select(x, kf::vec<int, 0>()), (kf::vec<T, 0>()));
20+
21+
// One element
22+
ASSERT_EQ(select(x, 0), kf::make_vec(data[0]));
23+
ASSERT_EQ(select(x, 5), kf::make_vec(data[5]));
24+
ASSERT_EQ(select(x, 7), kf::make_vec(data[4]));
25+
26+
// Two elements
27+
ASSERT_EQ(select(x, 0, 1), kf::make_vec(data[0], data[1]));
28+
ASSERT_EQ(select(x, 5, 0), kf::make_vec(data[5], data[0]));
29+
ASSERT_EQ(select(x, 6, 7), kf::make_vec(data[6], data[7]));
30+
31+
// Two elements as array
32+
ASSERT_EQ(select(x, kf::make_vec(0, 1)), kf::make_vec(data[0], data[1]));
33+
ASSERT_EQ(select(x, kf::make_vec(5, 0)), kf::make_vec(data[5], data[0]));
34+
ASSERT_EQ(select(x, kf::make_vec(6, 7)), kf::make_vec(data[6], data[7]));
35+
36+
// Three elements
37+
ASSERT_EQ(select(x, kf::make_vec(0, 1), 2), kf::make_vec(data[0], data[1], data[2]));
38+
ASSERT_EQ(select(x, kf::make_vec(5, 0, 7)), kf::make_vec(data[5], data[0], data[7]));
39+
ASSERT_EQ(select(x, 6, kf::make_vec(7, 2)), kf::make_vec(data[6], data[7], data[2]));
40+
41+
// Method of vector
42+
ASSERT_EQ(x.select(), (kf::vec<T, 0>()));
43+
ASSERT_EQ(x.select(4), kf::make_vec(data[4]));
44+
ASSERT_EQ(x.select(4, 2), kf::make_vec(data[4], data[2]));
45+
ASSERT_EQ(x.select(4, 2, 7), kf::make_vec(data[4], data[2], data[7]));
46+
}
47+
};

0 commit comments

Comments
 (0)