Skip to content

Commit b29b608

Browse files
Add a bitset validation test for cuco::arrow_filter_policy (#633)
This PR adds a test that validates the bitset produced by inserting specific keys into a `cuco::bloom_filter` with `cuco::arrow_filter_policy` against the one generated by inserting the same keys into the implementation in Arrow. Related to #625. Part of rapidsai/cudf#17164. Reference bitset generation with Arrow here: https://godbolt.org/z/ebdddezbP --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 69817e2 commit b29b608

File tree

2 files changed

+167
-1
lines changed

2 files changed

+167
-1
lines changed

tests/CMakeLists.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,4 +142,6 @@ ConfigureTest(HYPERLOGLOG_TEST
142142
###################################################################################################
143143
# - bloom_filter ----------------------------------------------------------------------------------
144144
ConfigureTest(BLOOM_FILTER_TEST
145-
bloom_filter/unique_sequence_test.cu)
145+
bloom_filter/unique_sequence_test.cu
146+
bloom_filter/arrow_policy_test.cu
147+
)
Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
/*
2+
* Copyright (c) 2024, NVIDIA CORPORATION.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
#include <test_utils.hpp>
18+
19+
#include <cuco/bloom_filter.cuh>
20+
21+
#include <cuda/functional>
22+
#include <thrust/device_vector.h>
23+
#include <thrust/functional.h>
24+
25+
#include <catch2/catch_template_test_macros.hpp>
26+
27+
#include <random>
28+
#include <type_traits>
29+
30+
namespace {
31+
32+
template <typename Key>
33+
thrust::device_vector<uint32_t> get_arrow_filter_reference_bitset()
34+
{
35+
static std::vector<thrust::device_vector<uint32_t>> const reference_bitsets{
36+
{4294752255,
37+
928963967,
38+
4227333887,
39+
3183462382,
40+
3892030683,
41+
3481206270,
42+
3513757613,
43+
3220961761,
44+
3186616955,
45+
4026531705,
46+
4110408887,
47+
804913147,
48+
1039007726,
49+
4286569403,
50+
2675948542,
51+
3688689479}, // type = int32, blocks = 2, num_keys = 100
52+
{2290897413, 3368027184, 2432735301, 2013315170, 610406792, 35787348, 43061541,
53+
1145143906, 238486532, 2840527950, 241188878, 624061504, 759830680, 184694210,
54+
2282459916, 3232258264, 285316692, 3284142851, 2760958614, 2974341265, 38749317,
55+
2655160577, 2193666087, 261196816, 411328595, 5391621, 2308014147, 2550892738,
56+
1224755395, 1396835974, 3227911200, 307324929}, // type = int64, blocks = 4, num_keys = 50
57+
{3037098621, 1001208422, 3070541682, 3611620780, 372254302, 2869772027, 2629135999,
58+
3332804862, 2832966981, 1225184253, 1315442262, 211922492, 1020510327, 2725704195,
59+
2909038118, 2783622989, 4214109798, 535934391, 2385459605, 4109595381, 3219664733,
60+
3164400602, 1995984498, 2917029602, 3047576211, 2212973933, 1672737343, 300902378,
61+
3000318461, 1561320274, 2710202091, 3067275349, 2734901244, 2638172076, 3669981206,
62+
3719000395, 793729452, 2258222966, 4111863618, 2391109497, 240119500, 855317864,
63+
2893522276, 1103034386, 738173080, 4098968587, 1271241025, 499361504, 4174530401,
64+
3259956170, 3823469907, 578271374, 3168397042, 3890816473, 431898609, 1583427570,
65+
1835797371, 2078281027, 2741410265, 2639785266, 3422606831, 1589476610, 3972396492,
66+
3611525326} // type = float, blocks = 8, num_keys = 200
67+
};
68+
69+
if constexpr (std::is_same_v<Key, int32_t>) {
70+
return reference_bitsets[0]; // int32
71+
} else if constexpr (std::is_same_v<Key, int64_t>) {
72+
return reference_bitsets[1]; // int64
73+
} else if constexpr (std::is_same_v<Key, float>) {
74+
return reference_bitsets[2]; // float
75+
} else {
76+
throw std::invalid_argument("Reference bitsets available for int32, int64, float only.\n\n");
77+
}
78+
}
79+
80+
/**
 * @brief Returns the test configuration used for keys of type `Key`.
 *
 * The values are compile-time constants, so each branch returns its pair
 * directly rather than indexing into a run-time-initialised static container.
 *
 * @tparam Key Key type; settings exist for `int32_t`, `int64_t` and `float`
 *
 * @throws std::invalid_argument if `Key` is any other type
 *
 * @return Pair of `{number of filter blocks, number of keys to insert}`
 */
template <typename Key>
std::pair<size_t, size_t> get_arrow_filter_test_settings()
{
  if constexpr (std::is_same_v<Key, int32_t>) {
    return {2, 100};  // type = int32, blocks = 2, num_keys = 100
  } else if constexpr (std::is_same_v<Key, int64_t>) {
    return {4, 50};  // type = int64, blocks = 4, num_keys = 50
  } else if constexpr (std::is_same_v<Key, float>) {
    return {8, 200};  // type = float, blocks = 8, num_keys = 200
  } else {
    throw std::invalid_argument("Test settings available for int32, int64, float only.\n\n");
  }
}
99+
100+
/**
 * @brief Generates a reproducible sequence of pseudo-random values of type `Key`.
 *
 * A fresh `std::mt19937` seeded with a fixed constant is created on every
 * call, so the result depends only on `Key` and `size` and never on how many
 * times the function has been called before. The first call yields exactly
 * the values a shared static engine would have produced, so keys that were
 * used to generate hard-coded reference data are unchanged.
 *
 * The distribution is default-constructed: `std::bernoulli_distribution` for
 * `bool`, `std::uniform_real_distribution<Key>` for floating-point types and
 * `std::uniform_int_distribution<Key>` otherwise.
 *
 * @tparam Key Value type to generate
 *
 * @param size Number of values to generate
 *
 * @return Host vector with `size` generated values
 */
template <typename Key>
std::vector<Key> random_values(size_t size)
{
  using distribution_type =
    std::conditional_t<std::is_same_v<Key, bool>,
                       std::bernoulli_distribution,
                       std::conditional_t<std::is_floating_point_v<Key>,
                                          std::uniform_real_distribution<Key>,
                                          std::uniform_int_distribution<Key>>>;

  constexpr auto seed = 0xf00d;

  // Deliberately not static: engine state must not leak between calls
  std::mt19937 engine{seed};
  distribution_type dist{};

  std::vector<Key> values(size);
  std::generate_n(values.begin(), size, [&engine, &dist]() { return Key{dist(engine)}; });

  return values;
}
119+
120+
} // namespace
121+
122+
template <typename Filter>
123+
void test_filter_bitset(Filter& filter, size_t num_keys)
124+
{
125+
using key_type = typename Filter::key_type;
126+
using word_type = typename Filter::word_type;
127+
128+
// Generate keys
129+
auto const h_keys = random_values<key_type>(num_keys);
130+
thrust::device_vector<key_type> d_keys(h_keys.begin(), h_keys.end());
131+
132+
// Insert to the bloom filter
133+
filter.add(d_keys.begin(), d_keys.begin() + num_keys);
134+
135+
// Get reference words device_vector
136+
auto const reference_words = get_arrow_filter_reference_bitset<key_type>();
137+
138+
// Number of words in the filter
139+
auto const num_words = filter.block_extent() * filter.words_per_block;
140+
141+
// Get the bitset
142+
thrust::device_vector<word_type> filter_words(filter.data(), filter.data() + num_words);
143+
144+
REQUIRE(cuco::test::equal(
145+
filter_words.begin(),
146+
filter_words.end(),
147+
reference_words.begin(),
148+
cuda::proclaim_return_type<bool>([] __device__(auto const& filter_word, auto const& ref_word) {
149+
return filter_word == ref_word;
150+
})));
151+
}
152+
153+
// Runs the bitset comparison once per key type, using that type's test settings
TEMPLATE_TEST_CASE_SIG(
  "Arrow filter policy bitset validation", "", (class Key), (int32_t), (int64_t), (float))
{
  using policy_type = cuco::arrow_filter_policy<Key>;
  using filter_type =
    cuco::bloom_filter<Key, cuco::extent<size_t>, cuda::thread_scope_device, policy_type>;

  // Per-type configuration: {number of filter blocks, number of keys to insert}
  auto const settings   = get_arrow_filter_test_settings<Key>();
  auto const num_blocks = settings.first;
  auto const key_count  = settings.second;

  filter_type filter{num_blocks};

  test_filter_bitset(filter, key_count);
}

0 commit comments

Comments
 (0)