|
| 1 | +/* |
| 2 | + * Copyright (c) 2024, NVIDIA CORPORATION. |
| 3 | + * |
| 4 | + * Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | + * you may not use this file except in compliance with the License. |
| 6 | + * You may obtain a copy of the License at |
| 7 | + * |
| 8 | + * http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | + * |
| 10 | + * Unless required by applicable law or agreed to in writing, software |
| 11 | + * distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | + * See the License for the specific language governing permissions and |
| 14 | + * limitations under the License. |
| 15 | + */ |
| 16 | + |
| 17 | +#include <test_utils.hpp> |
| 18 | + |
| 19 | +#include <cuco/bloom_filter.cuh> |
| 20 | + |
| 21 | +#include <cuda/functional> |
| 22 | +#include <thrust/device_vector.h> |
| 23 | +#include <thrust/functional.h> |
| 24 | + |
| 25 | +#include <catch2/catch_template_test_macros.hpp> |
| 26 | + |
| 27 | +#include <random> |
| 28 | +#include <type_traits> |
| 29 | + |
| 30 | +namespace { |
| 31 | + |
| 32 | +template <typename Key> |
| 33 | +thrust::device_vector<uint32_t> get_arrow_filter_reference_bitset() |
| 34 | +{ |
| 35 | + static std::vector<thrust::device_vector<uint32_t>> const reference_bitsets{ |
| 36 | + {4294752255, |
| 37 | + 928963967, |
| 38 | + 4227333887, |
| 39 | + 3183462382, |
| 40 | + 3892030683, |
| 41 | + 3481206270, |
| 42 | + 3513757613, |
| 43 | + 3220961761, |
| 44 | + 3186616955, |
| 45 | + 4026531705, |
| 46 | + 4110408887, |
| 47 | + 804913147, |
| 48 | + 1039007726, |
| 49 | + 4286569403, |
| 50 | + 2675948542, |
| 51 | + 3688689479}, // type = int32, blocks = 2, num_keys = 100 |
| 52 | + {2290897413, 3368027184, 2432735301, 2013315170, 610406792, 35787348, 43061541, |
| 53 | + 1145143906, 238486532, 2840527950, 241188878, 624061504, 759830680, 184694210, |
| 54 | + 2282459916, 3232258264, 285316692, 3284142851, 2760958614, 2974341265, 38749317, |
| 55 | + 2655160577, 2193666087, 261196816, 411328595, 5391621, 2308014147, 2550892738, |
| 56 | + 1224755395, 1396835974, 3227911200, 307324929}, // type = int64, blocks = 4, num_keys = 50 |
| 57 | + {3037098621, 1001208422, 3070541682, 3611620780, 372254302, 2869772027, 2629135999, |
| 58 | + 3332804862, 2832966981, 1225184253, 1315442262, 211922492, 1020510327, 2725704195, |
| 59 | + 2909038118, 2783622989, 4214109798, 535934391, 2385459605, 4109595381, 3219664733, |
| 60 | + 3164400602, 1995984498, 2917029602, 3047576211, 2212973933, 1672737343, 300902378, |
| 61 | + 3000318461, 1561320274, 2710202091, 3067275349, 2734901244, 2638172076, 3669981206, |
| 62 | + 3719000395, 793729452, 2258222966, 4111863618, 2391109497, 240119500, 855317864, |
| 63 | + 2893522276, 1103034386, 738173080, 4098968587, 1271241025, 499361504, 4174530401, |
| 64 | + 3259956170, 3823469907, 578271374, 3168397042, 3890816473, 431898609, 1583427570, |
| 65 | + 1835797371, 2078281027, 2741410265, 2639785266, 3422606831, 1589476610, 3972396492, |
| 66 | + 3611525326} // type = float, blocks = 8, num_keys = 200 |
| 67 | + }; |
| 68 | + |
| 69 | + if constexpr (std::is_same_v<Key, int32_t>) { |
| 70 | + return reference_bitsets[0]; // int32 |
| 71 | + } else if constexpr (std::is_same_v<Key, int64_t>) { |
| 72 | + return reference_bitsets[1]; // int64 |
| 73 | + } else if constexpr (std::is_same_v<Key, float>) { |
| 74 | + return reference_bitsets[2]; // float |
| 75 | + } else { |
| 76 | + throw std::invalid_argument("Reference bitsets available for int32, int64, float only.\n\n"); |
| 77 | + } |
| 78 | +} |
| 79 | + |
| 80 | +template <typename Key> |
| 81 | +std::pair<size_t, size_t> get_arrow_filter_test_settings() |
| 82 | +{ |
| 83 | + static std::vector<std::pair<size_t, size_t>> const test_settings = { |
| 84 | + {2, 100}, // type = int32, blocks = 2, num_keys = 100 |
| 85 | + {4, 50}, // type = int64, blocks = 4, num_keys = 50 |
| 86 | + {8, 200} // type = float, blocks = 8, num_keys = 200 |
| 87 | + }; |
| 88 | + |
| 89 | + if constexpr (std::is_same_v<Key, int32_t>) { |
| 90 | + return test_settings[0]; // int32 |
| 91 | + } else if constexpr (std::is_same_v<Key, int64_t>) { |
| 92 | + return test_settings[1]; // int64 |
| 93 | + } else if constexpr (std::is_same_v<Key, float>) { |
| 94 | + return test_settings[2]; // float |
| 95 | + } else { |
| 96 | + throw std::invalid_argument("Test settings available for int32, int64, float only.\n\n"); |
| 97 | + } |
| 98 | +} |
| 99 | + |
| 100 | +template <typename Key> |
| 101 | +std::vector<Key> random_values(size_t size) |
| 102 | +{ |
| 103 | + std::vector<Key> values(size); |
| 104 | + |
| 105 | + using uniform_distribution = |
| 106 | + typename std::conditional_t<std::is_same_v<Key, bool>, |
| 107 | + std::bernoulli_distribution, |
| 108 | + std::conditional_t<std::is_floating_point_v<Key>, |
| 109 | + std::uniform_real_distribution<Key>, |
| 110 | + std::uniform_int_distribution<Key>>>; |
| 111 | + |
| 112 | + static constexpr auto seed = 0xf00d; |
| 113 | + static std::mt19937 engine{seed}; |
| 114 | + static uniform_distribution dist{}; |
| 115 | + std::generate_n(values.begin(), size, [&]() { return Key{dist(engine)}; }); |
| 116 | + |
| 117 | + return values; |
| 118 | +} |
| 119 | + |
| 120 | +} // namespace |
| 121 | + |
| 122 | +template <typename Filter> |
| 123 | +void test_filter_bitset(Filter& filter, size_t num_keys) |
| 124 | +{ |
| 125 | + using key_type = typename Filter::key_type; |
| 126 | + using word_type = typename Filter::word_type; |
| 127 | + |
| 128 | + // Generate keys |
| 129 | + auto const h_keys = random_values<key_type>(num_keys); |
| 130 | + thrust::device_vector<key_type> d_keys(h_keys.begin(), h_keys.end()); |
| 131 | + |
| 132 | + // Insert to the bloom filter |
| 133 | + filter.add(d_keys.begin(), d_keys.begin() + num_keys); |
| 134 | + |
| 135 | + // Get reference words device_vector |
| 136 | + auto const reference_words = get_arrow_filter_reference_bitset<key_type>(); |
| 137 | + |
| 138 | + // Number of words in the filter |
| 139 | + auto const num_words = filter.block_extent() * filter.words_per_block; |
| 140 | + |
| 141 | + // Get the bitset |
| 142 | + thrust::device_vector<word_type> filter_words(filter.data(), filter.data() + num_words); |
| 143 | + |
| 144 | + REQUIRE(cuco::test::equal( |
| 145 | + filter_words.begin(), |
| 146 | + filter_words.end(), |
| 147 | + reference_words.begin(), |
| 148 | + cuda::proclaim_return_type<bool>([] __device__(auto const& filter_word, auto const& ref_word) { |
| 149 | + return filter_word == ref_word; |
| 150 | + }))); |
| 151 | +} |
| 152 | + |
| 153 | +TEMPLATE_TEST_CASE_SIG( |
| 154 | + "Arrow filter policy bitset validation", "", (class Key), (int32_t), (int64_t), (float)) |
| 155 | +{ |
| 156 | + // Get test settings |
| 157 | + auto const [sub_filters, num_keys] = get_arrow_filter_test_settings<Key>(); |
| 158 | + |
| 159 | + using policy_type = cuco::arrow_filter_policy<Key>; |
| 160 | + cuco::bloom_filter<Key, cuco::extent<size_t>, cuda::thread_scope_device, policy_type> filter{ |
| 161 | + sub_filters}; |
| 162 | + |
| 163 | + test_filter_bitset(filter, num_keys); |
| 164 | +} |
0 commit comments