2424#include < thrust/functional.h>
2525#include < thrust/iterator/constant_iterator.h>
2626#include < thrust/iterator/counting_iterator.h>
27- #include < thrust/random.h>
28- #include < thrust/sequence.h>
29- #include < thrust/shuffle.h>
27+ #include < thrust/iterator/transform_iterator.h>
3028#include < thrust/sort.h>
31- #include < thrust/transform.h>
3229
3330#include < catch2/catch_template_test_macros.hpp>
3431
@@ -47,32 +44,28 @@ void test_multiplicity(Container& container, std::size_t num_keys, std::size_t m
4744 auto const num_actual_keys = num_unique_keys * multiplicity;
4845 REQUIRE (num_actual_keys <= num_keys);
4946
50- thrust::device_vector<key_type> input_keys (num_actual_keys);
5147 thrust::device_vector<key_type> probed_keys (num_actual_keys);
5248 thrust::device_vector<key_type> matched_keys (num_actual_keys);
5349
54- thrust::transform (thrust::counting_iterator<key_type>(0 ),
55- thrust::counting_iterator<key_type>(num_actual_keys),
56- input_keys.begin (),
57- cuda::proclaim_return_type<key_type>([multiplicity] __device__ (auto const & i) {
58- return static_cast <key_type>(i / multiplicity);
59- }));
60- thrust::shuffle (input_keys.begin (), input_keys.end (), thrust::default_random_engine{});
50+ auto const keys_begin = thrust::make_transform_iterator (
51+ thrust::counting_iterator<key_type>(0 ),
52+ cuda::proclaim_return_type<key_type>([multiplicity] __device__ (auto const & i) {
53+ return static_cast <key_type>(i / multiplicity);
54+ }));
6155
62- container.insert (input_keys. begin (), input_keys. end () );
56+ container.insert (keys_begin, keys_begin + num_actual_keys );
6357 REQUIRE (container.size () == num_actual_keys);
6458
6559 SECTION (" All inserted keys should be contained." )
6660 {
6761 auto const [probed_end, matched_end] = container.retrieve (
68- input_keys.begin (), input_keys.end (), probed_keys.begin (), matched_keys.begin ());
69- thrust::sort (input_keys.begin (), input_keys.end ());
62+ keys_begin, keys_begin + num_actual_keys, probed_keys.begin (), matched_keys.begin ());
7063 thrust::sort (probed_keys.begin (), probed_end);
7164 thrust::sort (matched_keys.begin (), matched_end);
7265 REQUIRE (cuco::test::equal (
73- probed_keys.begin (), probed_keys.end (), input_keys. begin () , thrust::equal_to<key_type>{}));
66+ probed_keys.begin (), probed_keys.end (), keys_begin , thrust::equal_to<key_type>{}));
7467 REQUIRE (cuco::test::equal (
75- matched_keys.begin (), matched_keys.end (), input_keys. begin () , thrust::equal_to<key_type>{}));
68+ matched_keys.begin (), matched_keys.end (), keys_begin , thrust::equal_to<key_type>{}));
7669 }
7770}
7871
@@ -84,18 +77,16 @@ void test_outer(Container& container, std::size_t num_keys)
8477
8578 container.clear ();
8679
87- thrust::device_vector<key_type> insert_keys (num_keys);
88- thrust::sequence (insert_keys.begin (), insert_keys.end (), 0 );
89- thrust::device_vector<key_type> query_keys (num_keys * 2ull );
90- thrust::sequence (query_keys.begin (), query_keys.end (), 0 );
80+ auto const keys_begin = thrust::counting_iterator<key_type>{0 };
81+ auto const query_size = num_keys * 2ull ;
9182
9283 thrust::device_vector<key_type> probed_keys (num_keys * 2ull );
9384 thrust::device_vector<key_type> matched_keys (num_keys * 2ull );
9485
9586 SECTION (" Non-inserted keys should output sentinels." )
9687 {
97- auto const [probed_end, matched_end] = container.retrieve_outer (query_keys. begin () ,
98- query_keys. end () ,
88+ auto const [probed_end, matched_end] = container.retrieve_outer (keys_begin ,
89+ keys_begin + query_size ,
9990 container.key_eq (),
10091 container.hash_function (),
10192 probed_keys.begin (),
@@ -112,12 +103,12 @@ void test_outer(Container& container, std::size_t num_keys)
112103 })));
113104 }
114105
115- container.insert (insert_keys. begin (), insert_keys. end () );
106+ container.insert (keys_begin, keys_begin + num_keys );
116107
117108 SECTION (" All inserted keys should be contained." )
118109 {
119- auto const [probed_end, matched_end] = container.retrieve_outer (query_keys. begin () ,
120- query_keys. end () ,
110+ auto const [probed_end, matched_end] = container.retrieve_outer (keys_begin ,
111+ keys_begin + query_size ,
121112 container.key_eq (),
122113 container.hash_function (),
123114 probed_keys.begin (),
@@ -126,10 +117,10 @@ void test_outer(Container& container, std::size_t num_keys)
126117 probed_keys.begin (), probed_end, matched_keys.begin (), thrust::less<key_type>());
127118
128119 REQUIRE (cuco::test::equal (
129- probed_keys.begin (), probed_keys.end (), query_keys. begin () , thrust::equal_to<key_type>{}));
120+ probed_keys.begin (), probed_keys.end (), keys_begin , thrust::equal_to<key_type>{}));
130121 REQUIRE (cuco::test::equal (matched_keys.begin (),
131122 matched_keys.begin () + num_keys,
132- insert_keys. begin () ,
123+ keys_begin ,
133124 thrust::equal_to<key_type>{}));
134125 REQUIRE (cuco::test::all_of (
135126 matched_keys.begin () + num_keys,
0 commit comments