Skip to content

Commit 5db1066

Browse files
committed
Improve multiset retrieve test
1 parent 235963d commit 5db1066

File tree

1 file changed

+19
-28
lines changed

1 file changed

+19
-28
lines changed

tests/static_multiset/retrieve_test.cu

Lines changed: 19 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,8 @@
2424
#include <thrust/functional.h>
2525
#include <thrust/iterator/constant_iterator.h>
2626
#include <thrust/iterator/counting_iterator.h>
27-
#include <thrust/random.h>
28-
#include <thrust/sequence.h>
29-
#include <thrust/shuffle.h>
27+
#include <thrust/iterator/transform_iterator.h>
3028
#include <thrust/sort.h>
31-
#include <thrust/transform.h>
3229

3330
#include <catch2/catch_template_test_macros.hpp>
3431

@@ -47,32 +44,28 @@ void test_multiplicity(Container& container, std::size_t num_keys, std::size_t m
4744
auto const num_actual_keys = num_unique_keys * multiplicity;
4845
REQUIRE(num_actual_keys <= num_keys);
4946

50-
thrust::device_vector<key_type> input_keys(num_actual_keys);
5147
thrust::device_vector<key_type> probed_keys(num_actual_keys);
5248
thrust::device_vector<key_type> matched_keys(num_actual_keys);
5349

54-
thrust::transform(thrust::counting_iterator<key_type>(0),
55-
thrust::counting_iterator<key_type>(num_actual_keys),
56-
input_keys.begin(),
57-
cuda::proclaim_return_type<key_type>([multiplicity] __device__(auto const& i) {
58-
return static_cast<key_type>(i / multiplicity);
59-
}));
60-
thrust::shuffle(input_keys.begin(), input_keys.end(), thrust::default_random_engine{});
50+
auto const keys_begin = thrust::make_transform_iterator(
51+
thrust::counting_iterator<key_type>(0),
52+
cuda::proclaim_return_type<key_type>([multiplicity] __device__(auto const& i) {
53+
return static_cast<key_type>(i / multiplicity);
54+
}));
6155

62-
container.insert(input_keys.begin(), input_keys.end());
56+
container.insert(keys_begin, keys_begin + num_actual_keys);
6357
REQUIRE(container.size() == num_actual_keys);
6458

6559
SECTION("All inserted keys should be contained.")
6660
{
6761
auto const [probed_end, matched_end] = container.retrieve(
68-
input_keys.begin(), input_keys.end(), probed_keys.begin(), matched_keys.begin());
69-
thrust::sort(input_keys.begin(), input_keys.end());
62+
keys_begin, keys_begin + num_actual_keys, probed_keys.begin(), matched_keys.begin());
7063
thrust::sort(probed_keys.begin(), probed_end);
7164
thrust::sort(matched_keys.begin(), matched_end);
7265
REQUIRE(cuco::test::equal(
73-
probed_keys.begin(), probed_keys.end(), input_keys.begin(), thrust::equal_to<key_type>{}));
66+
probed_keys.begin(), probed_keys.end(), keys_begin, thrust::equal_to<key_type>{}));
7467
REQUIRE(cuco::test::equal(
75-
matched_keys.begin(), matched_keys.end(), input_keys.begin(), thrust::equal_to<key_type>{}));
68+
matched_keys.begin(), matched_keys.end(), keys_begin, thrust::equal_to<key_type>{}));
7669
}
7770
}
7871

@@ -84,18 +77,16 @@ void test_outer(Container& container, std::size_t num_keys)
8477

8578
container.clear();
8679

87-
thrust::device_vector<key_type> insert_keys(num_keys);
88-
thrust::sequence(insert_keys.begin(), insert_keys.end(), 0);
89-
thrust::device_vector<key_type> query_keys(num_keys * 2ull);
90-
thrust::sequence(query_keys.begin(), query_keys.end(), 0);
80+
auto const keys_begin = thrust::counting_iterator<key_type>{0};
81+
auto const query_size = num_keys * 2ull;
9182

9283
thrust::device_vector<key_type> probed_keys(num_keys * 2ull);
9384
thrust::device_vector<key_type> matched_keys(num_keys * 2ull);
9485

9586
SECTION("Non-inserted keys should output sentinels.")
9687
{
97-
auto const [probed_end, matched_end] = container.retrieve_outer(query_keys.begin(),
98-
query_keys.end(),
88+
auto const [probed_end, matched_end] = container.retrieve_outer(keys_begin,
89+
keys_begin + query_size,
9990
container.key_eq(),
10091
container.hash_function(),
10192
probed_keys.begin(),
@@ -112,12 +103,12 @@ void test_outer(Container& container, std::size_t num_keys)
112103
})));
113104
}
114105

115-
container.insert(insert_keys.begin(), insert_keys.end());
106+
container.insert(keys_begin, keys_begin + num_keys);
116107

117108
SECTION("All inserted keys should be contained.")
118109
{
119-
auto const [probed_end, matched_end] = container.retrieve_outer(query_keys.begin(),
120-
query_keys.end(),
110+
auto const [probed_end, matched_end] = container.retrieve_outer(keys_begin,
111+
keys_begin + query_size,
121112
container.key_eq(),
122113
container.hash_function(),
123114
probed_keys.begin(),
@@ -126,10 +117,10 @@ void test_outer(Container& container, std::size_t num_keys)
126117
probed_keys.begin(), probed_end, matched_keys.begin(), thrust::less<key_type>());
127118

128119
REQUIRE(cuco::test::equal(
129-
probed_keys.begin(), probed_keys.end(), query_keys.begin(), thrust::equal_to<key_type>{}));
120+
probed_keys.begin(), probed_keys.end(), keys_begin, thrust::equal_to<key_type>{}));
130121
REQUIRE(cuco::test::equal(matched_keys.begin(),
131122
matched_keys.begin() + num_keys,
132-
insert_keys.begin(),
123+
keys_begin,
133124
thrust::equal_to<key_type>{}));
134125
REQUIRE(cuco::test::all_of(
135126
matched_keys.begin() + num_keys,

0 commit comments

Comments
 (0)