Skip to content

Commit e8c4aec

Browse files
committed
Add unit test
1 parent 65d2402 commit e8c4aec

File tree

2 files changed

+207
-1
lines changed

2 files changed

+207
-1
lines changed

tests/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,8 @@ ConfigureTest(HYPERLOGLOG_TEST
153153
ConfigureTest(BLOOM_FILTER_TEST
154154
bloom_filter/unique_sequence_test.cu
155155
bloom_filter/arrow_policy_test.cu
156-
bloom_filter/variable_cg_test.cu)
156+
bloom_filter/variable_cg_test.cu
157+
bloom_filter/merge_intersect_test.cu)
157158

158159
###################################################################################################
159160
# - roaring_bitmap ---------------------------------------------------------------------------------
Lines changed: 205 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,205 @@
1+
/*
2+
* Copyright (c) 2025, NVIDIA CORPORATION.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
#include <test_utils.hpp>
18+
19+
#include <cuco/bloom_filter.cuh>
20+
#include <cuco/utility/error.hpp>
21+
22+
#include <cuda/functional>
23+
#include <thrust/device_vector.h>
24+
#include <thrust/iterator/counting_iterator.h>
25+
26+
#include <catch2/catch_template_test_macros.hpp>
27+
#include <catch2/generators/catch_generators.hpp>
28+
29+
#include <cstdint>
30+
#include <exception>
31+
32+
using size_type = int32_t;
33+
34+
template <typename Filter>
35+
void test_merge_intersect(Filter& filter_a,
36+
Filter& filter_b,
37+
Filter const& filter_c,
38+
size_type capacity)
39+
{
40+
using Key = typename Filter::key_type;
41+
42+
size_type num_keys = capacity;
43+
size_type half_keys = capacity / 2;
44+
45+
// Set A: [0, capacity)
46+
auto keys_a_begin = thrust::counting_iterator<Key>{static_cast<Key>(0)};
47+
auto keys_a_end = keys_a_begin + num_keys;
48+
49+
// Set B: [capacity/2, capacity + capacity/2) (50% overlap with A)
50+
auto keys_b_begin = thrust::counting_iterator<Key>{static_cast<Key>(half_keys)};
51+
auto keys_b_end = keys_b_begin + num_keys;
52+
53+
// Intersection: [capacity/2, capacity)
54+
auto keys_intersection_begin = thrust::counting_iterator<Key>{static_cast<Key>(half_keys)};
55+
auto keys_intersection_end = keys_intersection_begin + half_keys;
56+
57+
// Union: [0, capacity + capacity/2)
58+
auto keys_union_begin = thrust::counting_iterator<Key>{static_cast<Key>(0)};
59+
auto keys_union_end = keys_union_begin + num_keys + half_keys;
60+
61+
// Unique A: [0, capacity/2)
62+
auto keys_unique_a_begin = thrust::counting_iterator<Key>{static_cast<Key>(0)};
63+
auto keys_unique_a_end = keys_unique_a_begin + half_keys;
64+
65+
// Unique B: [capacity, capacity + capacity/2)
66+
auto keys_unique_b_begin = thrust::counting_iterator<Key>{static_cast<Key>(num_keys)};
67+
auto keys_unique_b_end = keys_unique_b_begin + half_keys;
68+
69+
// Helper to fill filters
70+
auto refill_filters = [&]() {
71+
filter_a.clear();
72+
filter_a.add(keys_a_begin, keys_a_end);
73+
74+
filter_b.clear();
75+
filter_b.add(keys_b_begin, keys_b_end);
76+
};
77+
78+
// Reusable output vector (sized for largest query: union)
79+
thrust::device_vector<bool> contained(num_keys + half_keys);
80+
81+
SECTION("Merge B into A")
82+
{
83+
refill_filters();
84+
filter_a.merge(filter_b);
85+
86+
// Check A contains all of Union
87+
filter_a.contains(keys_union_begin, keys_union_end, contained.begin());
88+
REQUIRE(cuco::test::all_of(
89+
contained.begin(), contained.begin() + num_keys + half_keys, cuda::std::identity{}));
90+
91+
// Check B is unchanged
92+
filter_b.contains(keys_b_begin, keys_b_end, contained.begin());
93+
REQUIRE(
94+
cuco::test::all_of(contained.begin(), contained.begin() + num_keys, cuda::std::identity{}));
95+
}
96+
97+
SECTION("Intersect B into A")
98+
{
99+
refill_filters();
100+
filter_a.intersect(filter_b);
101+
102+
// Check A contains Intersection
103+
filter_a.contains(keys_intersection_begin, keys_intersection_end, contained.begin());
104+
REQUIRE(
105+
cuco::test::all_of(contained.begin(), contained.begin() + half_keys, cuda::std::identity{}));
106+
107+
// Check A does NOT contain Unique A (approximate)
108+
// We expect none_of, but due to false positives, we might get some.
109+
// However, for this test configuration, we expect 0 false positives if the filter is
110+
// reasonably sized.
111+
filter_a.contains(keys_unique_a_begin, keys_unique_a_end, contained.begin());
112+
REQUIRE(
113+
cuco::test::none_of(contained.begin(), contained.begin() + half_keys, cuda::std::identity{}));
114+
115+
// Check A does NOT contain Unique B
116+
filter_a.contains(keys_unique_b_begin, keys_unique_b_end, contained.begin());
117+
REQUIRE(
118+
cuco::test::none_of(contained.begin(), contained.begin() + half_keys, cuda::std::identity{}));
119+
}
120+
121+
SECTION("Merge empty filter into A")
122+
{
123+
filter_a.clear();
124+
filter_a.add(keys_a_begin, keys_a_end);
125+
filter_b.clear(); // B is empty
126+
127+
filter_a.merge(filter_b);
128+
129+
// A should still contain all of Set A
130+
filter_a.contains(keys_a_begin, keys_a_end, contained.begin());
131+
REQUIRE(
132+
cuco::test::all_of(contained.begin(), contained.begin() + num_keys, cuda::std::identity{}));
133+
}
134+
135+
SECTION("Intersect empty filter into A")
136+
{
137+
filter_a.clear();
138+
filter_a.add(keys_a_begin, keys_a_end);
139+
filter_b.clear(); // B is empty
140+
141+
filter_a.intersect(filter_b);
142+
143+
// A should now be empty (intersection with empty set)
144+
filter_a.contains(keys_a_begin, keys_a_end, contained.begin());
145+
REQUIRE(
146+
cuco::test::none_of(contained.begin(), contained.begin() + num_keys, cuda::std::identity{}));
147+
}
148+
149+
SECTION("Mismatched block counts")
150+
{
151+
// also test with custom stream
152+
cudaStream_t stream;
153+
cudaStreamCreate(&stream);
154+
REQUIRE_THROWS_AS(filter_a.merge(filter_c, stream), cuco::logic_error);
155+
REQUIRE_THROWS_AS(filter_a.intersect(filter_c, stream), cuco::logic_error);
156+
cudaStreamDestroy(stream);
157+
}
158+
}
159+
160+
TEMPLATE_TEST_CASE_SIG(
161+
"bloom_filter merge and intersect tests",
162+
"",
163+
((class Key, class Policy), Key, Policy),
164+
(int32_t, cuco::default_filter_policy<cuco::xxhash_64<int32_t>, uint32_t, 1>),
165+
(int32_t, cuco::default_filter_policy<cuco::xxhash_64<int32_t>, uint32_t, 8>),
166+
(int64_t, cuco::default_filter_policy<cuco::xxhash_64<int64_t>, uint64_t, 1>),
167+
(int64_t, cuco::default_filter_policy<cuco::xxhash_64<int64_t>, uint64_t, 8>))
168+
{
169+
using filter_type =
170+
cuco::bloom_filter<Key, cuco::extent<size_t>, cuda::thread_scope_device, Policy>;
171+
constexpr size_type capacity{1000};
172+
173+
uint32_t pattern_bits = Policy::words_per_block + GENERATE(0, 1);
174+
175+
// some parameter combinations might be invalid so we skip them
176+
try {
177+
[[maybe_unused]] auto policy = Policy{pattern_bits};
178+
} catch (std::exception const& e) {
179+
SKIP(e.what());
180+
}
181+
182+
auto filter_a = filter_type{capacity, {}, {pattern_bits}};
183+
auto filter_b = filter_type{capacity, {}, {pattern_bits}};
184+
auto filter_c = filter_type{static_cast<size_t>(capacity) * 2, {}, {pattern_bits}};
185+
186+
test_merge_intersect(filter_a, filter_b, filter_c, capacity);
187+
}
188+
189+
TEMPLATE_TEST_CASE_SIG("bloom_filter merge and intersect arrow tests",
190+
"",
191+
((class Key, class Policy), Key, Policy),
192+
(int32_t, cuco::arrow_filter_policy<int32_t>),
193+
(int64_t, cuco::arrow_filter_policy<int64_t>),
194+
(float, cuco::arrow_filter_policy<float>))
195+
{
196+
using filter_type =
197+
cuco::bloom_filter<Key, cuco::extent<size_t>, cuda::thread_scope_device, Policy>;
198+
constexpr size_type capacity{1000}; // Must match capacity used in helper logic
199+
200+
auto filter_a = filter_type{capacity};
201+
auto filter_b = filter_type{capacity};
202+
auto filter_c = filter_type{static_cast<size_t>(capacity) * 2};
203+
204+
test_merge_intersect(filter_a, filter_b, filter_c, capacity);
205+
}

0 commit comments

Comments
 (0)