Skip to content

Commit 2a51e52

Browse files
committed
Add ability for KERNEL_FLOAT_TILING_FOR_IMPL to accept two or three arguments
1 parent aaf8645 commit 2a51e52

File tree

4 files changed

+49
-24
lines changed

4 files changed

+49
-24
lines changed

include/kernel_float/macros.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,4 +44,9 @@
4444
} while (0)
4545
#define KERNEL_FLOAT_UNREACHABLE __builtin_unreachable()
4646

47+
// Somet utility macros
48+
#define KERNEL_FLOAT_CONCAT_IMPL(A, B) A##B
49+
#define KERNEL_FLOAT_CONCAT(A, B) KERNEL_FLOAT_CONCAT_IMPL(A, B)
50+
#define KERNEL_FLOAT_CALL(F, ...) F(__VA_ARGS__)
51+
4752
#endif //KERNEL_FLOAT_MACROS_H

include/kernel_float/meta.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ template<typename T>
7777
using decay_t = typename detail::decay_impl<T>::type;
7878

7979
template<typename A, typename B>
80-
struct promote_type;
80+
struct promote_type {};
8181

8282
template<typename T>
8383
struct promote_type<T, T> {

include/kernel_float/tiling.h

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -504,11 +504,20 @@ template<size_t TileDim, size_t BlockDim, typename D = dist::cyclic, typename In
504504
using tiling_1d = tiling<tile_size<TileDim>, block_size<BlockDim>, distributions<D>, IndexType>;
505505

506506
// clang-format off
507-
#define KERNEL_FLOAT_TILING_FOR(TILING_VARIABLE__, INDEX_VARIABLE__, POINT_VARIABLE__) \
508-
_Pragma("unroll") \
509-
for (::std::size_t INDEX_VARIABLE__ = 0; INDEX_VARIABLE__ < TILING_VARIABLE__.size(); INDEX_VARIABLE__++) \
510-
if (typename decltype(TILING_VARIABLE__)::point_type POINT_VARIABLE__ = TILING_VARIABLE__.at(INDEX_VARIABLE__); \
511-
TILING_VARIABLE__.is_present(INDEX_VARIABLE__))
507+
#define KERNEL_FLOAT_TILING_FOR_IMPL1(ITER_VAR, TILING, POINT_VAR, _) \
508+
_Pragma("unroll") \
509+
for (size_t ITER_VAR = 0; ITER_VAR < (TILING).size(); ITER_VAR++) \
510+
if (POINT_VAR = (TILING).at(ITER_VAR); (TILING).is_present(ITER_VAR)) \
511+
512+
#define KERNEL_FLOAT_TILING_FOR_IMPL2(ITER_VAR, TILING, INDEX_VAR, POINT_VAR) \
513+
KERNEL_FLOAT_TILING_FOR_IMPL1(ITER_VAR, TILING, POINT_VAR, _) \
514+
if (INDEX_VAR = ITER_VAR; true)
515+
516+
#define KERNEL_FLOAT_TILING_FOR_IMPL(ITER_VAR, TILING, A, B, N, ...) \
517+
KERNEL_FLOAT_CALL(KERNEL_FLOAT_CONCAT(KERNEL_FLOAT_TILING_FOR_IMPL, N), ITER_VAR, TILING, A, B)
518+
519+
#define KERNEL_FLOAT_TILING_FOR(...) \
520+
KERNEL_FLOAT_TILING_FOR_IMPL(KERNEL_FLOAT_CONCAT(__tiling_index_variable__, __LINE__), __VA_ARGS__, 2, 1)
512521
// clang-format on
513522

514523
} // namespace kernel_float

tests/tiling.cu

Lines changed: 29 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,14 @@
44
struct basic_tiling_test {
55
template<typename T>
66
__host__ __device__ void operator()(generator<T> gen) {
7-
auto tiling = kf::tiling<
7+
using TestTiling = kf::tiling<
88
kf::tile_size<8, 8>,
99
kf::block_size<2, 4>,
10-
kf::distributions<kf::dist::cyclic, kf::dist::blocked>>(dim3(1, 2, 0));
10+
kf::distributions<kf::dist::cyclic, kf::dist::blocked> //
11+
>;
12+
auto tiling = TestTiling(dim3(1, 2, 0));
1113

12-
ASSERT_EQ(tiling.size(), size_t(8));
14+
ASSERT_EQ(TestTiling::size(), size_t(8));
1315

1416
ASSERT_EQ(
1517
tiling.local_points(),
@@ -24,7 +26,6 @@ struct basic_tiling_test {
2426
kf::make_vec(7, 6)));
2527

2628
ASSERT_EQ(tiling.local_points(0), kf::make_vec(1, 3, 5, 7, 1, 3, 5, 7));
27-
2829
ASSERT_EQ(tiling.local_points(1), kf::make_vec(2, 2, 2, 2, 6, 6, 6, 6));
2930

3031
ASSERT_EQ(tiling.at(0), kf::make_vec(1, 2));
@@ -44,6 +45,7 @@ struct basic_tiling_test {
4445
ASSERT_EQ(
4546
tiling.local_mask(),
4647
kf::make_vec(true, true, true, true, true, true, true, true));
48+
ASSERT_EQ(TestTiling::all_present(), true);
4749
ASSERT_EQ(tiling.is_present(0), true);
4850
ASSERT_EQ(tiling.is_present(1), true);
4951
ASSERT_EQ(tiling.is_present(2), true);
@@ -54,19 +56,16 @@ struct basic_tiling_test {
5456
ASSERT_EQ(tiling.thread_index(2), 0);
5557
ASSERT_EQ(tiling.thread_index(), kf::make_vec(1, 2));
5658

57-
ASSERT_EQ(tiling.block_size(0), 2);
58-
ASSERT_EQ(tiling.block_size(1), 4);
59-
ASSERT_EQ(tiling.block_size(2), 1);
60-
ASSERT_EQ(tiling.block_size(), kf::make_vec(2, 4));
61-
62-
ASSERT_EQ(tiling.tile_size(0), 8);
63-
ASSERT_EQ(tiling.tile_size(1), 8);
64-
ASSERT_EQ(tiling.tile_size(2), 1);
65-
ASSERT_EQ(tiling.tile_size(), kf::make_vec(8, 8));
59+
ASSERT_EQ(TestTiling::block_size(0), 2);
60+
ASSERT_EQ(TestTiling::block_size(1), 4);
61+
ASSERT_EQ(TestTiling::block_size(2), 1);
62+
ASSERT_EQ(TestTiling::block_size(), kf::make_vec(2, 4));
6663

67-
ASSERT_EQ(tiling.size(), size_t(8));
64+
ASSERT_EQ(TestTiling::tile_size(0), 8);
65+
ASSERT_EQ(TestTiling::tile_size(1), 8);
66+
ASSERT_EQ(TestTiling::tile_size(2), 1);
67+
ASSERT_EQ(TestTiling::tile_size(), kf::make_vec(8, 8));
6868

69-
size_t counter = 0;
7069
const int points[8][2] = {
7170
{1, 2},
7271
{3, 2},
@@ -78,14 +77,26 @@ struct basic_tiling_test {
7877
{7, 6},
7978
};
8079

81-
KERNEL_FLOAT_TILING_FOR(tiling, i, point) {
82-
ASSERT_EQ(counter, i);
80+
size_t counter = 0;
81+
KERNEL_FLOAT_TILING_FOR(tiling, auto point) {
82+
ASSERT(counter < 8);
83+
ASSERT_EQ(point[0], points[counter][0]);
84+
ASSERT_EQ(point[1], points[counter][1]);
85+
counter++;
86+
}
87+
88+
ASSERT(counter == 8);
89+
90+
counter = 0;
91+
KERNEL_FLOAT_TILING_FOR(tiling, int i, auto point) {
92+
ASSERT(counter < 8);
93+
ASSERT_EQ(counter, size_t(i));
8394
ASSERT_EQ(point[0], points[i][0]);
8495
ASSERT_EQ(point[1], points[i][1]);
8596
counter++;
8697
}
8798

88-
ASSERT_EQ(counter, size_t(8));
99+
ASSERT(counter == 8);
89100
}
90101
};
91102

0 commit comments

Comments
 (0)