Skip to content

Commit ded657c

Browse files
authored
Merge device zip_iterator
Adds a custom device tuple implementation to support zip_iterator in device code. Related PR: #1604
2 parents caa373d + 3f5e359 commit ded657c

File tree

13 files changed

+420
-96
lines changed

13 files changed

+420
-96
lines changed

core/base/iterator_factory.hpp

Lines changed: 313 additions & 54 deletions
Large diffs are not rendered by default.

core/test/base/iterator_factory.cpp

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,7 @@ TYPED_TEST(ZipIterator, IteratorReferenceOperatorSmaller2)
156156

157157
TYPED_TEST(ZipIterator, IncreasingIterator)
158158
{
159+
using gko::get;
159160
using index_type = typename TestFixture::index_type;
160161
using value_type = typename TestFixture::value_type;
161162
std::vector<index_type> vec1{this->reversed_index};
@@ -182,8 +183,8 @@ TYPED_TEST(ZipIterator, IncreasingIterator)
182183
ASSERT_TRUE(increment_pre_2 == increment_post_2);
183184
ASSERT_TRUE(begin == increment_post_test++);
184185
ASSERT_TRUE(begin + 1 == ++increment_pre_test);
185-
ASSERT_TRUE(std::get<0>(*plus_2) == vec1[2]);
186-
ASSERT_TRUE(std::get<1>(*plus_2) == vec2[2]);
186+
ASSERT_TRUE(get<0>(*plus_2) == vec1[2]);
187+
ASSERT_TRUE(get<1>(*plus_2) == vec2[2]);
187188
// check other comparison operators and difference
188189
std::vector<gko::detail::zip_iterator<index_type*, value_type*>> its{
189190
begin,
@@ -257,6 +258,7 @@ TYPED_TEST(ZipIterator, IncompatibleIteratorDeathTest)
257258

258259
TYPED_TEST(ZipIterator, DecreasingIterator)
259260
{
261+
using gko::get;
260262
using index_type = typename TestFixture::index_type;
261263
using value_type = typename TestFixture::value_type;
262264
std::vector<index_type> vec1{this->reversed_index};
@@ -280,13 +282,14 @@ TYPED_TEST(ZipIterator, DecreasingIterator)
280282
ASSERT_TRUE(decrement_pre_2 == decrement_post_2);
281283
ASSERT_TRUE(iter == decrement_post_test--);
282284
ASSERT_TRUE(iter - 1 == --decrement_pre_test);
283-
ASSERT_TRUE(std::get<0>(*minus_2) == vec1[3]);
284-
ASSERT_TRUE(std::get<1>(*minus_2) == vec2[3]);
285+
ASSERT_TRUE(get<0>(*minus_2) == vec1[3]);
286+
ASSERT_TRUE(get<1>(*minus_2) == vec2[3]);
285287
}
286288

287289

288290
TYPED_TEST(ZipIterator, CorrectDereferencing)
289291
{
292+
using gko::get;
290293
using index_type_it = typename TestFixture::index_type;
291294
using value_type_it = typename TestFixture::value_type;
292295
std::vector<index_type_it> vec1{this->reversed_index};
@@ -299,10 +302,10 @@ TYPED_TEST(ZipIterator, CorrectDereferencing)
299302
auto to_test_ref = *(begin + element_to_test);
300303
value_type to_test_pair = to_test_ref; // Testing implicit conversion
301304

302-
ASSERT_TRUE(std::get<0>(to_test_pair) == vec1[element_to_test]);
303-
ASSERT_TRUE(std::get<0>(to_test_pair) == std::get<0>(to_test_ref));
304-
ASSERT_TRUE(std::get<1>(to_test_pair) == vec2[element_to_test]);
305-
ASSERT_TRUE(std::get<1>(to_test_pair) == std::get<1>(to_test_ref));
305+
ASSERT_TRUE(get<0>(to_test_pair) == vec1[element_to_test]);
306+
ASSERT_TRUE(get<0>(to_test_pair) == get<0>(to_test_ref));
307+
ASSERT_TRUE(get<1>(to_test_pair) == vec2[element_to_test]);
308+
ASSERT_TRUE(get<1>(to_test_pair) == get<1>(to_test_ref));
306309
}
307310

308311

omp/distributed/index_map_kernels.cpp

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -58,16 +58,15 @@ void build_mapping(
5858
auto sort_it = detail::make_zip_iterator(
5959
full_remote_part_ids.begin(), recv_connections_ptr, range_ids.begin());
6060
std::sort(sort_it, sort_it + input_size, [](const auto& a, const auto& b) {
61-
return std::tie(std::get<0>(a), std::get<1>(a)) <
62-
std::tie(std::get<0>(b), std::get<1>(b));
61+
return std::tie(get<0>(a), get<1>(a)) < std::tie(get<0>(b), get<1>(b));
6362
});
6463

6564
// get only unique connections
66-
auto unique_end = std::unique(
67-
sort_it, sort_it + input_size, [](const auto& a, const auto& b) {
68-
return std::tie(std::get<0>(a), std::get<1>(a)) ==
69-
std::tie(std::get<0>(b), std::get<1>(b));
70-
});
65+
auto unique_end = std::unique(sort_it, sort_it + input_size,
66+
[](const auto& a, const auto& b) {
67+
return std::tie(get<0>(a), get<1>(a)) ==
68+
std::tie(get<0>(b), get<1>(b));
69+
});
7170
auto unique_size = std::distance(sort_it, unique_end);
7271

7372
remote_global_idxs.resize_and_reset(unique_size);

omp/distributed/partition_helpers_kernels.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,9 @@ void sort_by_range_start(
2727
range_start_ends.get_data() + 1, [](const auto i) { return 2 * i; });
2828
auto sort_it = detail::make_zip_iterator(start_it, end_it, part_ids_d);
2929
// TODO: use TBB or parallel std with c++17
30-
std::stable_sort(sort_it, sort_it + num_parts,
31-
[](const auto& a, const auto& b) {
32-
return std::get<0>(a) < std::get<0>(b);
33-
});
30+
std::stable_sort(
31+
sort_it, sort_it + num_parts,
32+
[](const auto& a, const auto& b) { return get<0>(a) < get<0>(b); });
3433
}
3534

3635
GKO_INSTANTIATE_FOR_EACH_INDEX_TYPE(

omp/matrix/csr_kernels.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1155,9 +1155,8 @@ void sort_by_column_index(std::shared_ptr<const OmpExecutor> exec,
11551155
auto row_nnz = row_ptrs[i + 1] - start_row_idx;
11561156
auto it = detail::make_zip_iterator(col_idxs + start_row_idx,
11571157
values + start_row_idx);
1158-
std::sort(it, it + row_nnz, [](auto t1, auto t2) {
1159-
return std::get<0>(t1) < std::get<0>(t2);
1160-
});
1158+
std::sort(it, it + row_nnz,
1159+
[](auto t1, auto t2) { return get<0>(t1) < get<0>(t2); });
11611160
}
11621161
}
11631162

omp/matrix/fbcsr_kernels.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -398,9 +398,8 @@ void sort_by_column_index_impl(
398398
std::vector<IndexType> col_permute(nbnz_brow);
399399
std::iota(col_permute.begin(), col_permute.end(), 0);
400400
auto it = detail::make_zip_iterator(brow_col_idxs, col_permute.data());
401-
std::sort(it, it + nbnz_brow, [](auto a, auto b) {
402-
return std::get<0>(a) < std::get<0>(b);
403-
});
401+
std::sort(it, it + nbnz_brow,
402+
[](auto a, auto b) { return get<0>(a) < get<0>(b); });
404403

405404
std::vector<ValueType> oldvalues(nbnz_brow * bs2);
406405
std::copy(brow_vals, brow_vals + nbnz_brow * bs2, oldvalues.begin());

omp/multigrid/pgm_kernels.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,7 @@ void sort_row_major(std::shared_ptr<const DefaultExecutor> exec, size_type nnz,
4343
{
4444
auto it = detail::make_zip_iterator(row_idxs, col_idxs, vals);
4545
std::stable_sort(it, it + nnz, [](auto a, auto b) {
46-
return std::tie(std::get<0>(a), std::get<1>(a)) <
47-
std::tie(std::get<0>(b), std::get<1>(b));
46+
return std::tie(get<0>(a), get<1>(a)) < std::tie(get<0>(b), get<1>(b));
4847
});
4948
}
5049

reference/distributed/partition_helpers_kernels.cpp

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,9 @@ void sort_by_range_start(
2626
auto end_it = detail::make_permute_iterator(
2727
range_start_ends.get_data() + 1, [](const auto i) { return 2 * i; });
2828
auto sort_it = detail::make_zip_iterator(start_it, end_it, part_ids_d);
29-
std::stable_sort(sort_it, sort_it + num_parts,
30-
[](const auto& a, const auto& b) {
31-
return std::get<0>(a) < std::get<0>(b);
32-
});
29+
std::stable_sort(
30+
sort_it, sort_it + num_parts,
31+
[](const auto& a, const auto& b) { return get<0>(a) < get<0>(b); });
3332
}
3433

3534
GKO_INSTANTIATE_FOR_EACH_INDEX_TYPE(
@@ -51,9 +50,9 @@ void check_consecutive_ranges(std::shared_ptr<const DefaultExecutor> exec,
5150
auto range_it = detail::make_zip_iterator(start_it, end_it);
5251

5352
if (num_parts) {
54-
result = std::all_of(
55-
range_it, range_it + num_parts - 1,
56-
[](const auto& r) { return std::get<0>(r) == std::get<1>(r); });
53+
result =
54+
std::all_of(range_it, range_it + num_parts - 1,
55+
[](const auto& r) { return get<0>(r) == get<1>(r); });
5756
} else {
5857
result = true;
5958
}

reference/matrix/csr_kernels.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1128,9 +1128,8 @@ void sort_by_column_index(std::shared_ptr<const ReferenceExecutor> exec,
11281128
auto row_nnz = row_ptrs[i + 1] - start_row_idx;
11291129
auto it = detail::make_zip_iterator(col_idxs + start_row_idx,
11301130
values + start_row_idx);
1131-
std::sort(it, it + row_nnz, [](auto t1, auto t2) {
1132-
return std::get<0>(t1) < std::get<0>(t2);
1133-
});
1131+
std::sort(it, it + row_nnz,
1132+
[](auto t1, auto t2) { return get<0>(t1) < get<0>(t2); });
11341133
}
11351134
}
11361135

reference/matrix/fbcsr_kernels.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -418,9 +418,8 @@ void sort_by_column_index_impl(
418418
std::vector<IndexType> col_permute(nbnz_brow);
419419
std::iota(col_permute.begin(), col_permute.end(), 0);
420420
auto it = detail::make_zip_iterator(brow_col_idxs, col_permute.data());
421-
std::sort(it, it + nbnz_brow, [](auto a, auto b) {
422-
return std::get<0>(a) < std::get<0>(b);
423-
});
421+
std::sort(it, it + nbnz_brow,
422+
[](auto a, auto b) { return get<0>(a) < get<0>(b); });
424423

425424
std::vector<ValueType> oldvalues(nbnz_brow * bs2);
426425
std::copy(brow_vals, brow_vals + nbnz_brow * bs2, oldvalues.begin());

0 commit comments

Comments
 (0)