|
| 1 | +//===-- llvm/ADT/RadixTree.h - Radix Tree implementation --------*- C++ -*-===// |
| 2 | +// |
| 3 | +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | +// See https://llvm.org/LICENSE.txt for license information. |
| 5 | +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | +//===----------------------------------------------------------------------===// |
| 7 | +// |
| 8 | +// This file implements a Radix Tree. |
| 9 | +// |
| 10 | +//===----------------------------------------------------------------------===// |
| 11 | + |
| 12 | +#ifndef LLVM_ADT_RADIXTREE_H |
| 13 | +#define LLVM_ADT_RADIXTREE_H |
| 14 | + |
| 15 | +#include "llvm/ADT/ADL.h" |
| 16 | +#include "llvm/ADT/STLExtras.h" |
| 17 | +#include "llvm/ADT/iterator.h" |
| 18 | +#include "llvm/ADT/iterator_range.h" |
| 19 | +#include <cassert> |
| 20 | +#include <cstddef> |
| 21 | +#include <iterator> |
| 22 | +#include <limits> |
| 23 | +#include <list> |
| 24 | +#include <utility> |
| 25 | + |
| 26 | +namespace llvm { |
| 27 | + |
| 28 | +/// \brief A Radix Tree implementation. |
| 29 | +/// |
| 30 | +/// A Radix Tree (also known as a compact prefix tree or radix trie) is a |
| 31 | +/// data structure that stores a dynamic set or associative array where keys |
| 32 | +/// are strings and values are associated with these keys. Unlike a regular |
| 33 | +/// trie, the edges of a radix tree can be labeled with sequences of characters |
| 34 | +/// as well as single characters. This makes radix trees more efficient for |
| 35 | +/// storing sparse data sets, where many nodes in a regular trie would have |
| 36 | +/// only one child. |
| 37 | +/// |
| 38 | +/// This implementation supports arbitrary key types that can be iterated over |
| 39 | +/// (e.g., `std::string`, `std::vector<char>`, `ArrayRef<char>`). The key type |
| 40 | +/// must provide `begin()` and `end()` for iteration. |
| 41 | +/// |
| 42 | +/// The tree stores `std::pair<const KeyType, T>` as its value type. |
| 43 | +/// |
| 44 | +/// Example usage: |
| 45 | +/// \code |
| 46 | +/// llvm::RadixTree<StringRef, int> Tree; |
| 47 | +/// Tree.emplace("apple", 1); |
| 48 | +/// Tree.emplace("grapefruit", 2); |
| 49 | +/// Tree.emplace("grape", 3); |
| 50 | +/// |
| 51 | +/// // Find prefixes |
| 52 | +/// for (const auto &[Key, Value] : Tree.find_prefixes("grapefruit juice")) { |
| 53 | +/// // pair will be {"grape", 3} |
| 54 | +/// // pair will be {"grapefruit", 2} |
| 55 | +/// llvm::outs() << Key << ": " << Value << "\n"; |
| 56 | +/// } |
| 57 | +/// |
| 58 | +/// // Iterate over all elements |
| 59 | +/// for (const auto &[Key, Value] : Tree) |
| 60 | +/// llvm::outs() << Key << ": " << Value << "\n"; |
| 61 | +/// \endcode |
| 62 | +/// |
| 63 | +/// \note |
| 64 | +/// The `RadixTree` takes ownership of the `KeyType` and `T` objects |
| 65 | +/// inserted into it. When an element is removed or the tree is destroyed, |
| 66 | +/// these objects will be destructed. |
| 67 | +/// However, if `KeyType` is a reference-like type, e.g., StringRef or range, |
| 68 | +/// the user must guarantee that the referenced data has a lifetime longer than |
| 69 | +/// the tree. |
| 70 | +template <typename KeyType, typename T> class RadixTree { |
| 71 | +public: |
| 72 | + using key_type = KeyType; |
| 73 | + using mapped_type = T; |
| 74 | + using value_type = std::pair<const KeyType, mapped_type>; |
| 75 | + |
| 76 | +private: |
| 77 | + using KeyConstIteratorType = |
| 78 | + decltype(adl_begin(std::declval<const key_type &>())); |
| 79 | + using KeyConstIteratorRangeType = iterator_range<KeyConstIteratorType>; |
| 80 | + using KeyValueType = |
| 81 | + remove_cvref_t<decltype(*adl_begin(std::declval<key_type &>()))>; |
| 82 | + using ContainerType = std::list<value_type>; |
| 83 | + |
| 84 | + /// Represents an internal node in the Radix Tree. |
| 85 | + struct Node { |
| 86 | + KeyConstIteratorRangeType Key{KeyConstIteratorType{}, |
| 87 | + KeyConstIteratorType{}}; |
| 88 | + std::vector<Node> Children; |
| 89 | + |
| 90 | + /// An iterator to the value associated with this node. |
| 91 | + /// |
| 92 | + /// If this node does not have a value (i.e., it's an internal node that |
| 93 | + /// only serves as a path to other values), this iterator will be equal |
| 94 | + /// to default constructed `ContainerType::iterator()`. |
| 95 | + typename ContainerType::iterator Value; |
| 96 | + |
| 97 | + /// The first character of the Key. Used for fast child lookup. |
| 98 | + KeyValueType KeyFront; |
| 99 | + |
| 100 | + Node() = default; |
| 101 | + Node(const KeyConstIteratorRangeType &Key) |
| 102 | + : Key(Key), KeyFront(*Key.begin()) { |
| 103 | + assert(!Key.empty()); |
| 104 | + } |
| 105 | + |
| 106 | + Node(Node &&) = default; |
| 107 | + Node &operator=(Node &&) = default; |
| 108 | + |
| 109 | + Node(const Node &) = delete; |
| 110 | + Node &operator=(const Node &) = delete; |
| 111 | + |
| 112 | + const Node *findChild(const KeyConstIteratorRangeType &Key) const { |
| 113 | + if (Key.empty()) |
| 114 | + return nullptr; |
| 115 | + for (const Node &Child : Children) { |
| 116 | + assert(!Child.Key.empty()); // Only root can be empty. |
| 117 | + if (Child.KeyFront == *Key.begin()) |
| 118 | + return &Child; |
| 119 | + } |
| 120 | + return nullptr; |
| 121 | + } |
| 122 | + |
| 123 | + Node *findChild(const KeyConstIteratorRangeType &Query) { |
| 124 | + const Node *This = this; |
| 125 | + return const_cast<Node *>(This->findChild(Query)); |
| 126 | + } |
| 127 | + |
| 128 | + size_t countNodes() const { |
| 129 | + size_t R = 1; |
| 130 | + for (const Node &C : Children) |
| 131 | + R += C.countNodes(); |
| 132 | + return R; |
| 133 | + } |
| 134 | + |
| 135 | + /// |
| 136 | + /// Splits the current node into two. |
| 137 | + /// |
| 138 | + /// This function is used when a new key needs to be inserted that shares |
| 139 | + /// a common prefix with the current node's key, but then diverges. |
| 140 | + /// The current `Key` is truncated to the common prefix, and a new child |
| 141 | + /// node is created for the remainder of the original node's `Key`. |
| 142 | + /// |
| 143 | + /// \param SplitPoint An iterator pointing to the character in the current |
| 144 | + /// `Key` where the split should occur. |
| 145 | + void split(KeyConstIteratorType SplitPoint) { |
| 146 | + Node Child(make_range(SplitPoint, Key.end())); |
| 147 | + Key = make_range(Key.begin(), SplitPoint); |
| 148 | + |
| 149 | + Children.swap(Child.Children); |
| 150 | + std::swap(Value, Child.Value); |
| 151 | + |
| 152 | + Children.emplace_back(std::move(Child)); |
| 153 | + } |
| 154 | + }; |
| 155 | + |
| 156 | + /// Root always corresponds to the empty key, which is the shortest possible |
| 157 | + /// prefix for everything. |
| 158 | + Node Root; |
| 159 | + ContainerType KeyValuePairs; |
| 160 | + |
| 161 | + /// Finds or creates a new tail or leaf node corresponding to the `Key`. |
| 162 | + Node &findOrCreate(KeyConstIteratorRangeType Key) { |
| 163 | + Node *Curr = &Root; |
| 164 | + if (Key.empty()) |
| 165 | + return *Curr; |
| 166 | + |
| 167 | + for (;;) { |
| 168 | + auto [I1, I2] = llvm::mismatch(Key, Curr->Key); |
| 169 | + Key = make_range(I1, Key.end()); |
| 170 | + |
| 171 | + if (I2 != Curr->Key.end()) { |
| 172 | + // Match is partial. Either query is too short, or there is mismatching |
| 173 | + // character. Split either way, and put new node in between of the |
| 174 | + // current and its children. |
| 175 | + Curr->split(I2); |
| 176 | + |
| 177 | + // Split was caused by mismatch, so `findChild` would fail. |
| 178 | + break; |
| 179 | + } |
| 180 | + |
| 181 | + Node *Child = Curr->findChild(Key); |
| 182 | + if (!Child) |
| 183 | + break; |
| 184 | + |
| 185 | + // Move to child with the same first character. |
| 186 | + Curr = Child; |
| 187 | + } |
| 188 | + |
| 189 | + if (Key.empty()) { |
| 190 | + // The current node completely matches the key, return it. |
| 191 | + return *Curr; |
| 192 | + } |
| 193 | + |
| 194 | + // `Key` is a suffix of original `Key` unmatched by path from the `Root` to |
| 195 | + // the `Curr`, and we have no candidate in the children to match more. |
| 196 | + // Create a new one. |
| 197 | + return Curr->Children.emplace_back(Key); |
| 198 | + } |
| 199 | + |
| 200 | + /// |
| 201 | + /// An iterator for traversing prefixes search results. |
| 202 | + /// |
| 203 | + /// This iterator is used by `find_prefixes` to traverse the tree and find |
| 204 | + /// elements that are prefixes to the given key. It's a forward iterator. |
| 205 | + /// |
| 206 | + /// \tparam MappedType The type of the value pointed to by the iterator. |
| 207 | + /// This will be `value_type` for non-const iterators |
| 208 | + /// and `const value_type` for const iterators. |
| 209 | + template <typename MappedType> |
| 210 | + class IteratorImpl |
| 211 | + : public iterator_facade_base<IteratorImpl<MappedType>, |
| 212 | + std::forward_iterator_tag, MappedType> { |
| 213 | + const Node *Curr = nullptr; |
| 214 | + KeyConstIteratorRangeType Query{KeyConstIteratorType{}, |
| 215 | + KeyConstIteratorType{}}; |
| 216 | + |
| 217 | + void findNextValid() { |
| 218 | + while (Curr && Curr->Value == typename ContainerType::iterator()) |
| 219 | + advance(); |
| 220 | + } |
| 221 | + |
| 222 | + void advance() { |
| 223 | + assert(Curr); |
| 224 | + if (Query.empty()) { |
| 225 | + Curr = nullptr; |
| 226 | + return; |
| 227 | + } |
| 228 | + |
| 229 | + Curr = Curr->findChild(Query); |
| 230 | + if (!Curr) { |
| 231 | + Curr = nullptr; |
| 232 | + return; |
| 233 | + } |
| 234 | + |
| 235 | + auto [I1, I2] = llvm::mismatch(Query, Curr->Key); |
| 236 | + if (I2 != Curr->Key.end()) { |
| 237 | + Curr = nullptr; |
| 238 | + return; |
| 239 | + } |
| 240 | + Query = make_range(I1, Query.end()); |
| 241 | + } |
| 242 | + |
| 243 | + friend class RadixTree; |
| 244 | + IteratorImpl(const Node *C, const KeyConstIteratorRangeType &Q) |
| 245 | + : Curr(C), Query(Q) { |
| 246 | + findNextValid(); |
| 247 | + } |
| 248 | + |
| 249 | + public: |
| 250 | + IteratorImpl() = default; |
| 251 | + |
| 252 | + MappedType &operator*() const { return *Curr->Value; } |
| 253 | + |
| 254 | + IteratorImpl &operator++() { |
| 255 | + advance(); |
| 256 | + findNextValid(); |
| 257 | + return *this; |
| 258 | + } |
| 259 | + |
| 260 | + bool operator==(const IteratorImpl &Other) const { |
| 261 | + return Curr == Other.Curr; |
| 262 | + } |
| 263 | + }; |
| 264 | + |
| 265 | +public: |
| 266 | + RadixTree() = default; |
| 267 | + RadixTree(RadixTree &&) = default; |
| 268 | + RadixTree &operator=(RadixTree &&) = default; |
| 269 | + |
| 270 | + using prefix_iterator = IteratorImpl<value_type>; |
| 271 | + using const_prefix_iterator = IteratorImpl<const value_type>; |
| 272 | + |
| 273 | + using iterator = typename ContainerType::iterator; |
| 274 | + using const_iterator = typename ContainerType::const_iterator; |
| 275 | + |
| 276 | + /// Returns true if the tree is empty. |
| 277 | + bool empty() const { return KeyValuePairs.empty(); } |
| 278 | + |
| 279 | + /// Returns the number of elements in the tree. |
| 280 | + size_t size() const { return KeyValuePairs.size(); } |
| 281 | + |
| 282 | + /// Returns the number of nodes in the tree. |
| 283 | + /// |
| 284 | + /// This function counts all internal nodes in the tree. It can be useful for |
| 285 | + /// understanding the memory footprint or complexity of the tree structure. |
| 286 | + size_t countNodes() const { return Root.countNodes(); } |
| 287 | + |
| 288 | + /// Returns an iterator to the first element. |
| 289 | + iterator begin() { return KeyValuePairs.begin(); } |
| 290 | + const_iterator begin() const { return KeyValuePairs.begin(); } |
| 291 | + |
| 292 | + /// Returns an iterator to the end of the tree. |
| 293 | + iterator end() { return KeyValuePairs.end(); } |
| 294 | + const_iterator end() const { return KeyValuePairs.end(); } |
| 295 | + |
| 296 | + /// Constructs and inserts a new element into the tree. |
| 297 | + /// |
| 298 | + /// This function constructs an element in place within the tree. If an |
| 299 | + /// element with the same key already exists, the insertion fails and the |
| 300 | + /// function returns an iterator to the existing element along with `false`. |
| 301 | + /// Otherwise, the new element is inserted and the function returns an |
| 302 | + /// iterator to the new element along with `true`. |
| 303 | + /// |
| 304 | + /// \param Key The key of the element to construct. |
| 305 | + /// \param Args Arguments to forward to the constructor of the mapped_type. |
| 306 | + /// \return A pair consisting of an iterator to the inserted element (or to |
| 307 | + /// the element that prevented insertion) and a boolean value |
| 308 | + /// indicating whether the insertion took place. |
| 309 | + template <typename... Ts> |
| 310 | + std::pair<iterator, bool> emplace(key_type &&Key, Ts &&...Args) { |
| 311 | + // We want to make new `Node` to refer key in the container, not the one |
| 312 | + // from the argument. |
| 313 | + // FIXME: Determine that we need a new node, before expanding |
| 314 | + // `KeyValuePairs`. |
| 315 | + const value_type &NewValue = KeyValuePairs.emplace_front( |
| 316 | + std::move(Key), T(std::forward<Ts>(Args)...)); |
| 317 | + Node &Node = findOrCreate(NewValue.first); |
| 318 | + bool HasValue = Node.Value != typename ContainerType::iterator(); |
| 319 | + if (!HasValue) |
| 320 | + Node.Value = KeyValuePairs.begin(); |
| 321 | + else |
| 322 | + KeyValuePairs.pop_front(); |
| 323 | + return {Node.Value, !HasValue}; |
| 324 | + } |
| 325 | + |
| 326 | + /// |
| 327 | + /// Finds all elements whose keys are prefixes of the given `Key`. |
| 328 | + /// |
| 329 | + /// This function returns an iterator range over all elements in the tree |
| 330 | + /// whose keys are prefixes of the provided `Key`. For example, if the tree |
| 331 | + /// contains "abcde", "abc", "abcdefgh", and `Key` is "abcde", this function |
| 332 | + /// would return iterators to "abcde" and "abc". |
| 333 | + /// |
| 334 | + /// \param Key The key to search for prefixes of. |
| 335 | + /// \return An `iterator_range` of `const_prefix_iterator`s, allowing |
| 336 | + /// iteration over the found prefix elements. |
| 337 | + /// \note The returned iterators reference the `Key` provided by the caller. |
| 338 | + /// The caller must ensure that `Key` remains valid for the lifetime |
| 339 | + /// of the iterators. |
| 340 | + iterator_range<const_prefix_iterator> |
| 341 | + find_prefixes(const key_type &Key) const { |
| 342 | + return iterator_range<const_prefix_iterator>{ |
| 343 | + const_prefix_iterator(&Root, KeyConstIteratorRangeType(Key)), |
| 344 | + const_prefix_iterator{}}; |
| 345 | + } |
| 346 | +}; |
| 347 | + |
| 348 | +} // namespace llvm |
| 349 | + |
| 350 | +#endif // LLVM_ADT_RADIXTREE_H |
0 commit comments