diff --git a/llvm/include/llvm/ADT/TrieHashIndexGenerator.h b/llvm/include/llvm/ADT/TrieHashIndexGenerator.h new file mode 100644 index 0000000000000..6f7e53b6b11b5 --- /dev/null +++ b/llvm/include/llvm/ADT/TrieHashIndexGenerator.h @@ -0,0 +1,122 @@ +//===- TrieHashIndexGenerator.h ---------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_ADT_TRIEHASHINDEXGENERATOR_H +#define LLVM_ADT_TRIEHASHINDEXGENERATOR_H + +#include "llvm/ADT/ArrayRef.h" +#include + +namespace llvm { + +/// The utility class that helps computing the index of the object inside trie +/// from its hash. The generator can be configured with the number of bits +/// used for each level of trie structure with \c NumRootsBits and \c +/// NumSubtrieBits. +/// For example, try computing indexes for a 16-bit hash 0x1234 with 8-bit root +/// and 4-bit sub-trie: +/// +/// IndexGenerator IndexGen{8, 4, Hash}; +/// size_t index1 = IndexGen.next(); // index 18 in root node. +/// size_t index2 = IndexGen.next(); // index 3 in sub-trie level 1. +/// size_t index3 = IndexGen.next(); // index 4 in sub-tire level 2. +/// +/// This is used by different trie implementation to figure out where to +/// insert/find the object in the data structure. +struct TrieHashIndexGenerator { + size_t NumRootBits; + size_t NumSubtrieBits; + ArrayRef Bytes; + std::optional StartBit = std::nullopt; + + // Get the number of bits used to generate current index. + size_t getNumBits() const { + assert(StartBit); + size_t TotalNumBits = Bytes.size() * 8; + assert(*StartBit <= TotalNumBits); + return std::min(*StartBit ? NumSubtrieBits : NumRootBits, + TotalNumBits - *StartBit); + } + + // Get the index of the object in the next level of trie. + size_t next() { + if (!StartBit) { + // Compute index for root when StartBit is not set. + StartBit = 0; + return getIndex(Bytes, *StartBit, NumRootBits); + } + if (*StartBit < Bytes.size() * 8) { + // Compute index for sub-trie. + *StartBit += *StartBit ? NumSubtrieBits : NumRootBits; + assert((*StartBit - NumRootBits) % NumSubtrieBits == 0); + return getIndex(Bytes, *StartBit, NumSubtrieBits); + } + // All the bits are consumed. + return end(); + } + + // Provide a hint to speed up the index generation by providing the + // information of the hash in current level. For example, if the object is + // known to have \c Index on a level that already consumes first n \c Bits of + // the hash, it can start index generation from this level by calling \c hint + // function. + size_t hint(unsigned Index, unsigned Bit) { + assert(Bit < Bytes.size() * 8); + assert(Bit == 0 || (Bit - NumRootBits) % NumSubtrieBits == 0); + StartBit = Bit; + return Index; + } + + // Utility function for looking up the index in the trie for an object that + // has colliding hash bits in the front as the hash of the object that is + // currently being computed. + size_t getCollidingBits(ArrayRef CollidingBits) const { + assert(StartBit); + return getIndex(CollidingBits, *StartBit, NumSubtrieBits); + } + + size_t end() const { return SIZE_MAX; } + + // Compute the index for the object from its hash, current start bits, and + // the number of bits used for current level. + static size_t getIndex(ArrayRef Bytes, size_t StartBit, + size_t NumBits) { + assert(StartBit < Bytes.size() * 8); + // Drop all the bits before StartBit. + Bytes = Bytes.drop_front(StartBit / 8u); + StartBit %= 8u; + size_t Index = 0; + // Compute the index using the bits in range [StartBit, StartBit + NumBits), + // note the range can spread across few `uint8_t` in the array. + for (uint8_t Byte : Bytes) { + size_t ByteStart = 0, ByteEnd = 8; + if (StartBit) { + ByteStart = StartBit; + Byte &= (1u << (8 - StartBit)) - 1u; + StartBit = 0; + } + size_t CurrentNumBits = ByteEnd - ByteStart; + if (CurrentNumBits > NumBits) { + Byte >>= CurrentNumBits - NumBits; + CurrentNumBits = NumBits; + } + Index <<= CurrentNumBits; + Index |= Byte & ((1u << CurrentNumBits) - 1u); + + assert(NumBits >= CurrentNumBits); + NumBits -= CurrentNumBits; + if (!NumBits) + break; + } + return Index; + } +}; + +} // namespace llvm + +#endif // LLVM_ADT_TRIEHASHINDEXGENERATOR_H diff --git a/llvm/include/llvm/ADT/TrieRawHashMap.h b/llvm/include/llvm/ADT/TrieRawHashMap.h new file mode 100644 index 0000000000000..5bfe5c9e6a0f4 --- /dev/null +++ b/llvm/include/llvm/ADT/TrieRawHashMap.h @@ -0,0 +1,377 @@ +//===- TrieRawHashMap.h -----------------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_ADT_TRIERAWHASHMAP_H +#define LLVM_ADT_TRIERAWHASHMAP_H + +#include "llvm/ADT/ArrayRef.h" +#include +#include + +namespace llvm { + +class raw_ostream; + +/// TrieRawHashMap - is a lock-free thread-safe trie that is can be used to +/// store/index data based on a hash value. It can be customized to work with +/// any hash algorithm or store any data. +/// +/// Data structure: +/// Data node stored in the Trie contains both hash and data: +/// struct { +/// HashT Hash; +/// DataT Data; +/// }; +/// +/// Data is stored/indexed via a prefix tree, where each node in the tree can be +/// either the root, a sub-trie or a data node. Assuming a 4-bit hash and two +/// data objects {0001, A} and {0100, B}, it can be stored in a trie +/// (assuming Root has 2 bits, SubTrie has 1 bit): +/// +--------+ +/// |Root[00]| -> {0001, A} +/// | [01]| -> {0100, B} +/// | [10]| (empty) +/// | [11]| (empty) +/// +--------+ +/// +/// Inserting a new object {0010, C} will result in: +/// +--------+ +----------+ +/// |Root[00]| -> |SubTrie[0]| -> {0001, A} +/// | | | [1]| -> {0010, C} +/// | | +----------+ +/// | [01]| -> {0100, B} +/// | [10]| (empty) +/// | [11]| (empty) +/// +--------+ +/// Note object A is sunk down to a sub-trie during the insertion. All the +/// nodes are inserted through compare-exchange to ensure thread-safe and +/// lock-free. +/// +/// To find an object in the trie, walk the tree with prefix of the hash until +/// the data node is found. Then the hash is compared with the hash stored in +/// the data node to see if the is the same object. +/// +/// Hash collision is not allowed so it is recommended to use trie with a +/// "strong" hashing algorithm. A well-distributed hash can also result in +/// better performance and memory usage. +/// +/// It currently does not support iteration and deletion. + +/// Base class for a lock-free thread-safe hash-mapped trie. +class ThreadSafeTrieRawHashMapBase { +public: + static constexpr size_t TrieContentBaseSize = 4; + static constexpr size_t DefaultNumRootBits = 6; + static constexpr size_t DefaultNumSubtrieBits = 4; + +private: + template struct AllocValueType { + char Base[TrieContentBaseSize]; + std::aligned_union_t Content; + }; + +protected: + template + static constexpr size_t DefaultContentAllocSize = sizeof(AllocValueType); + + template + static constexpr size_t DefaultContentAllocAlign = alignof(AllocValueType); + + template + static constexpr size_t DefaultContentOffset = + offsetof(AllocValueType, Content); + +public: + static void *operator new(size_t Size) { return ::operator new(Size); } + void operator delete(void *Ptr) { ::operator delete(Ptr); } + + LLVM_DUMP_METHOD void dump() const; + void print(raw_ostream &OS) const; + +protected: + /// Result of a lookup. Suitable for an insertion hint. Maybe could be + /// expanded into an iterator of sorts, but likely not useful (visiting + /// everything in the trie should probably be done some way other than + /// through an iterator pattern). + class PointerBase { + protected: + void *get() const { return I == -2u ? P : nullptr; } + + public: + PointerBase() noexcept = default; + + private: + friend class ThreadSafeTrieRawHashMapBase; + explicit PointerBase(void *Content) : P(Content), I(-2u) {} + PointerBase(void *P, unsigned I, unsigned B) : P(P), I(I), B(B) {} + + bool isHint() const { return I != -1u && I != -2u; } + + void *P = nullptr; + unsigned I = -1u; + unsigned B = 0; + }; + + /// Find the stored content with hash. + PointerBase find(ArrayRef Hash) const; + + /// Insert and return the stored content. + PointerBase + insert(PointerBase Hint, ArrayRef Hash, + function_ref Hash)> + Constructor); + + ThreadSafeTrieRawHashMapBase() = delete; + + ThreadSafeTrieRawHashMapBase( + size_t ContentAllocSize, size_t ContentAllocAlign, size_t ContentOffset, + std::optional NumRootBits = std::nullopt, + std::optional NumSubtrieBits = std::nullopt); + + /// Destructor, which asserts if there's anything to do. Subclasses should + /// call \a destroyImpl(). + /// + /// \pre \a destroyImpl() was already called. + ~ThreadSafeTrieRawHashMapBase(); + void destroyImpl(function_ref Destructor); + + ThreadSafeTrieRawHashMapBase(ThreadSafeTrieRawHashMapBase &&RHS); + + // Move assignment is not supported as it is not thread-safe. + ThreadSafeTrieRawHashMapBase & + operator=(ThreadSafeTrieRawHashMapBase &&RHS) = delete; + + // No copy. + ThreadSafeTrieRawHashMapBase(const ThreadSafeTrieRawHashMapBase &) = delete; + ThreadSafeTrieRawHashMapBase & + operator=(const ThreadSafeTrieRawHashMapBase &) = delete; + + // Debug functions. Implementation details and not guaranteed to be + // thread-safe. + PointerBase getRoot() const; + unsigned getStartBit(PointerBase P) const; + unsigned getNumBits(PointerBase P) const; + unsigned getNumSlotUsed(PointerBase P) const; + std::string getTriePrefixAsString(PointerBase P) const; + unsigned getNumTries() const; + // Visit next trie in the allocation chain. + PointerBase getNextTrie(PointerBase P) const; + +private: + friend class TrieRawHashMapTestHelper; + const unsigned short ContentAllocSize; + const unsigned short ContentAllocAlign; + const unsigned short ContentOffset; + unsigned short NumRootBits; + unsigned short NumSubtrieBits; + class ImplType; + // ImplPtr is owned by ThreadSafeTrieRawHashMapBase and needs to be freed in + // destroyImpl. + std::atomic ImplPtr; + ImplType &getOrCreateImpl(); + ImplType *getImpl() const; +}; + +/// Lock-free thread-safe hash-mapped trie. +template +class ThreadSafeTrieRawHashMap : public ThreadSafeTrieRawHashMapBase { +public: + using HashT = std::array; + + class LazyValueConstructor; + struct value_type { + const HashT Hash; + T Data; + + value_type(value_type &&) = default; + value_type(const value_type &) = default; + + value_type(ArrayRef Hash, const T &Data) + : Hash(makeHash(Hash)), Data(Data) {} + value_type(ArrayRef Hash, T &&Data) + : Hash(makeHash(Hash)), Data(std::move(Data)) {} + + private: + friend class LazyValueConstructor; + + struct EmplaceTag {}; + template + value_type(ArrayRef Hash, EmplaceTag, ArgsT &&...Args) + : Hash(makeHash(Hash)), Data(std::forward(Args)...) {} + + static HashT makeHash(ArrayRef HashRef) { + HashT Hash; + std::copy(HashRef.begin(), HashRef.end(), Hash.data()); + return Hash; + } + }; + + using ThreadSafeTrieRawHashMapBase::operator delete; + using HashType = HashT; + + using ThreadSafeTrieRawHashMapBase::dump; + using ThreadSafeTrieRawHashMapBase::print; + +private: + template class PointerImpl : PointerBase { + friend class ThreadSafeTrieRawHashMap; + + ValueT *get() const { + return reinterpret_cast(PointerBase::get()); + } + + public: + ValueT &operator*() const { + assert(get()); + return *get(); + } + ValueT *operator->() const { + assert(get()); + return get(); + } + explicit operator bool() const { return get(); } + + PointerImpl() = default; + + protected: + PointerImpl(PointerBase Result) : PointerBase(Result) {} + }; + +public: + class pointer; + class const_pointer; + class pointer : public PointerImpl { + friend class ThreadSafeTrieRawHashMap; + friend class const_pointer; + + public: + pointer() = default; + + private: + pointer(PointerBase Result) : pointer::PointerImpl(Result) {} + }; + + class const_pointer : public PointerImpl { + friend class ThreadSafeTrieRawHashMap; + + public: + const_pointer() = default; + const_pointer(const pointer &P) : const_pointer::PointerImpl(P) {} + + private: + const_pointer(PointerBase Result) : const_pointer::PointerImpl(Result) {} + }; + + class LazyValueConstructor { + public: + value_type &operator()(T &&RHS) { + assert(Mem && "Constructor already called, or moved away"); + return assign(::new (Mem) value_type(Hash, std::move(RHS))); + } + value_type &operator()(const T &RHS) { + assert(Mem && "Constructor already called, or moved away"); + return assign(::new (Mem) value_type(Hash, RHS)); + } + template value_type &emplace(ArgsT &&...Args) { + assert(Mem && "Constructor already called, or moved away"); + return assign(::new (Mem) + value_type(Hash, typename value_type::EmplaceTag{}, + std::forward(Args)...)); + } + + LazyValueConstructor(LazyValueConstructor &&RHS) + : Mem(RHS.Mem), Result(RHS.Result), Hash(RHS.Hash) { + RHS.Mem = nullptr; // Moved away, cannot call. + } + ~LazyValueConstructor() { assert(!Mem && "Constructor never called!"); } + + private: + value_type &assign(value_type *V) { + Mem = nullptr; + Result = V; + return *V; + } + friend class ThreadSafeTrieRawHashMap; + LazyValueConstructor() = delete; + LazyValueConstructor(void *Mem, value_type *&Result, ArrayRef Hash) + : Mem(Mem), Result(Result), Hash(Hash) { + assert(Hash.size() == sizeof(HashT) && "Invalid hash"); + assert(Mem && "Invalid memory for construction"); + } + void *Mem; + value_type *&Result; + ArrayRef Hash; + }; + + /// Insert with a hint. Default-constructed hint will work, but it's + /// recommended to start with a lookup to avoid overhead in object creation + /// if it already exists. + pointer insertLazy(const_pointer Hint, ArrayRef Hash, + function_ref OnConstruct) { + return pointer(ThreadSafeTrieRawHashMapBase::insert( + Hint, Hash, [&](void *Mem, ArrayRef Hash) { + value_type *Result = nullptr; + OnConstruct(LazyValueConstructor(Mem, Result, Hash)); + return Result->Hash.data(); + })); + } + + pointer insertLazy(ArrayRef Hash, + function_ref OnConstruct) { + return insertLazy(const_pointer(), Hash, OnConstruct); + } + + pointer insert(const_pointer Hint, value_type &&HashedData) { + return insertLazy(Hint, HashedData.Hash, [&](LazyValueConstructor C) { + C(std::move(HashedData.Data)); + }); + } + + pointer insert(const_pointer Hint, const value_type &HashedData) { + return insertLazy(Hint, HashedData.Hash, + [&](LazyValueConstructor C) { C(HashedData.Data); }); + } + + pointer find(ArrayRef Hash) { + assert(Hash.size() == std::tuple_size::value); + return ThreadSafeTrieRawHashMapBase::find(Hash); + } + + const_pointer find(ArrayRef Hash) const { + assert(Hash.size() == std::tuple_size::value); + return ThreadSafeTrieRawHashMapBase::find(Hash); + } + + ThreadSafeTrieRawHashMap(std::optional NumRootBits = std::nullopt, + std::optional NumSubtrieBits = std::nullopt) + : ThreadSafeTrieRawHashMapBase(DefaultContentAllocSize, + DefaultContentAllocAlign, + DefaultContentOffset, + NumRootBits, NumSubtrieBits) {} + + ~ThreadSafeTrieRawHashMap() { + if constexpr (std::is_trivially_destructible::value) + this->destroyImpl(nullptr); + else + this->destroyImpl( + [](void *P) { static_cast(P)->~value_type(); }); + } + + // Move constructor okay. + ThreadSafeTrieRawHashMap(ThreadSafeTrieRawHashMap &&) = default; + + // No move assignment or any copy. + ThreadSafeTrieRawHashMap &operator=(ThreadSafeTrieRawHashMap &&) = delete; + ThreadSafeTrieRawHashMap(const ThreadSafeTrieRawHashMap &) = delete; + ThreadSafeTrieRawHashMap & + operator=(const ThreadSafeTrieRawHashMap &) = delete; +}; + +} // namespace llvm + +#endif // LLVM_ADT_TRIERAWHASHMAP_H diff --git a/llvm/lib/Support/CMakeLists.txt b/llvm/lib/Support/CMakeLists.txt index 531bdeaca1261..2ecaea4b02bf6 100644 --- a/llvm/lib/Support/CMakeLists.txt +++ b/llvm/lib/Support/CMakeLists.txt @@ -256,6 +256,7 @@ add_llvm_component_library(LLVMSupport TimeProfiler.cpp Timer.cpp ToolOutputFile.cpp + TrieRawHashMap.cpp Twine.cpp TypeSize.cpp Unicode.cpp diff --git a/llvm/lib/Support/TrieRawHashMap.cpp b/llvm/lib/Support/TrieRawHashMap.cpp new file mode 100644 index 0000000000000..9eeac0bbc5c2c --- /dev/null +++ b/llvm/lib/Support/TrieRawHashMap.cpp @@ -0,0 +1,515 @@ +//===- TrieRawHashMap.cpp -------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/ADT/TrieRawHashMap.h" +#include "llvm/ADT/LazyAtomicPointer.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/TrieHashIndexGenerator.h" +#include "llvm/Support/Allocator.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/ThreadSafeAllocator.h" +#include "llvm/Support/TrailingObjects.h" +#include "llvm/Support/raw_ostream.h" +#include + +using namespace llvm; + +namespace { +struct TrieNode { + const bool IsSubtrie = false; + + TrieNode(bool IsSubtrie) : IsSubtrie(IsSubtrie) {} + + static void *operator new(size_t Size) { return ::operator new(Size); } + void operator delete(void *Ptr) { ::operator delete(Ptr); } +}; + +struct TrieContent final : public TrieNode { + const uint8_t ContentOffset; + const uint8_t HashSize; + const uint8_t HashOffset; + + void *getValuePointer() const { + auto *Content = reinterpret_cast(this) + ContentOffset; + return const_cast(Content); + } + + ArrayRef getHash() const { + auto *Begin = reinterpret_cast(this) + HashOffset; + return ArrayRef(Begin, Begin + HashSize); + } + + TrieContent(size_t ContentOffset, size_t HashSize, size_t HashOffset) + : TrieNode(/*IsSubtrie=*/false), ContentOffset(ContentOffset), + HashSize(HashSize), HashOffset(HashOffset) {} + + static bool classof(const TrieNode *TN) { return !TN->IsSubtrie; } +}; + +static_assert(sizeof(TrieContent) == + ThreadSafeTrieRawHashMapBase::TrieContentBaseSize, + "Check header assumption!"); + +class TrieSubtrie final + : public TrieNode, + private TrailingObjects> { +public: + using Slot = LazyAtomicPointer; + + Slot &get(size_t I) { return getTrailingObjects()[I]; } + TrieNode *load(size_t I) { return get(I).load(); } + + unsigned size() const { return Size; } + + TrieSubtrie * + sink(size_t I, TrieContent &Content, size_t NumSubtrieBits, size_t NewI, + function_ref)> Saver); + + static std::unique_ptr create(size_t StartBit, size_t NumBits); + + explicit TrieSubtrie(size_t StartBit, size_t NumBits); + + static bool classof(const TrieNode *TN) { return TN->IsSubtrie; } + + static constexpr size_t sizeToAlloc(unsigned NumBits) { + assert(NumBits < 20 && "Tries should have fewer than ~1M slots"); + size_t Count = 1u << NumBits; + return totalSizeToAlloc>(Count); + } + +private: + // FIXME: Use a bitset to speed up access: + // + // std::array, NumSlots/64> IsSet; + // + // This will avoid needing to visit sparsely filled slots in + // \a ThreadSafeTrieRawHashMapBase::destroyImpl() when there's a non-trivial + // destructor. + // + // It would also greatly speed up iteration, if we add that some day, and + // allow get() to return one level sooner. + // + // This would be the algorithm for updating IsSet (after updating Slots): + // + // std::atomic &Bits = IsSet[I.High]; + // const uint64_t NewBit = 1ULL << I.Low; + // uint64_t Old = 0; + // while (!Bits.compare_exchange_weak(Old, Old | NewBit)) + // ; + + // For debugging. + unsigned StartBit = 0; + unsigned NumBits = 0; + unsigned Size = 0; + friend class llvm::ThreadSafeTrieRawHashMapBase; + friend class TrailingObjects; + +public: + /// Linked list for ownership of tries. The pointer is owned by TrieSubtrie. + std::atomic Next; +}; +} // end namespace + +std::unique_ptr TrieSubtrie::create(size_t StartBit, + size_t NumBits) { + void *Memory = ::operator new(sizeToAlloc(NumBits)); + TrieSubtrie *S = ::new (Memory) TrieSubtrie(StartBit, NumBits); + return std::unique_ptr(S); +} + +TrieSubtrie::TrieSubtrie(size_t StartBit, size_t NumBits) + : TrieNode(true), StartBit(StartBit), NumBits(NumBits), Size(1u << NumBits), + Next(nullptr) { + for (unsigned I = 0; I < Size; ++I) + new (&get(I)) Slot(nullptr); + + static_assert( + std::is_trivially_destructible>::value, + "Expected no work in destructor for TrieNode"); +} + +// Sink the nodes down sub-trie when the object being inserted collides with +// the index of existing object in the trie. In this case, a new sub-trie needs +// to be allocated to hold existing object. +TrieSubtrie *TrieSubtrie::sink( + size_t I, TrieContent &Content, size_t NumSubtrieBits, size_t NewI, + function_ref)> Saver) { + // Create a new sub-trie that points to the existing object with the new + // index for the next level. + assert(NumSubtrieBits > 0); + std::unique_ptr S = create(StartBit + NumBits, NumSubtrieBits); + + assert(NewI < Size); + S->get(NewI).store(&Content); + + // Using compare_exchange to atomically add back the new sub-trie to the trie + // in the place of the exsiting object. + TrieNode *ExistingNode = &Content; + assert(I < Size); + if (get(I).compare_exchange_strong(ExistingNode, S.get())) + return Saver(std::move(S)); + + // Another thread created a subtrie already. Return it and let "S" be + // destructed. + return cast(ExistingNode); +} + +class ThreadSafeTrieRawHashMapBase::ImplType final + : private TrailingObjects { +public: + static std::unique_ptr create(size_t StartBit, size_t NumBits) { + size_t Size = sizeof(ImplType) + TrieSubtrie::sizeToAlloc(NumBits); + void *Memory = ::operator new(Size); + ImplType *Impl = ::new (Memory) ImplType(StartBit, NumBits); + return std::unique_ptr(Impl); + } + + // Save the Subtrie into the ownship list of the trie structure in a + // thread-safe way. The ownership transfer is done by compare_exchange the + // pointer value inside the unique_ptr. + TrieSubtrie *save(std::unique_ptr S) { + assert(!S->Next && "Expected S to a freshly-constructed leaf"); + + TrieSubtrie *CurrentHead = nullptr; + // Add ownership of "S" to front of the list, so that Root -> S -> + // Root.Next. This works by repeatedly setting S->Next to a candidate value + // of Root.Next (initially nullptr), then setting Root.Next to S once the + // candidate matches reality. + while (!getRoot()->Next.compare_exchange_weak(CurrentHead, S.get())) + S->Next.exchange(CurrentHead); + + // Ownership transferred to subtrie successfully. Release the unique_ptr. + return S.release(); + } + + // Get the root which is the trailing object. + TrieSubtrie *getRoot() { return getTrailingObjects(); } + + static void *operator new(size_t Size) { return ::operator new(Size); } + void operator delete(void *Ptr) { ::operator delete(Ptr); } + + /// FIXME: This should take a function that allocates and constructs the + /// content lazily (taking the hash as a separate parameter), in case of + /// collision. + ThreadSafeAllocator ContentAlloc; + +private: + friend class TrailingObjects; + + ImplType(size_t StartBit, size_t NumBits) { + ::new (getRoot()) TrieSubtrie(StartBit, NumBits); + } +}; + +ThreadSafeTrieRawHashMapBase::ImplType & +ThreadSafeTrieRawHashMapBase::getOrCreateImpl() { + if (ImplType *Impl = ImplPtr.load()) + return *Impl; + + // Create a new ImplType and store it if another thread doesn't do so first. + // If another thread wins this one is destroyed locally. + std::unique_ptr Impl = ImplType::create(0, NumRootBits); + ImplType *ExistingImpl = nullptr; + + // If the ownership transferred succesfully, release unique_ptr and return + // the pointer to the new ImplType. + if (ImplPtr.compare_exchange_strong(ExistingImpl, Impl.get())) + return *Impl.release(); + + // Already created, return the existing ImplType. + return *ExistingImpl; +} + +ThreadSafeTrieRawHashMapBase::PointerBase +ThreadSafeTrieRawHashMapBase::find(ArrayRef Hash) const { + assert(!Hash.empty() && "Uninitialized hash"); + + ImplType *Impl = ImplPtr.load(); + if (!Impl) + return PointerBase(); + + TrieSubtrie *S = Impl->getRoot(); + TrieHashIndexGenerator IndexGen{NumRootBits, NumSubtrieBits, Hash}; + size_t Index = IndexGen.next(); + while (Index != IndexGen.end()) { + // Try to set the content. + TrieNode *Existing = S->get(Index); + if (!Existing) + return PointerBase(S, Index, *IndexGen.StartBit); + + // Check for an exact match. + if (auto *ExistingContent = dyn_cast(Existing)) + return ExistingContent->getHash() == Hash + ? PointerBase(ExistingContent->getValuePointer()) + : PointerBase(S, Index, *IndexGen.StartBit); + + Index = IndexGen.next(); + S = cast(Existing); + } + llvm_unreachable("failed to locate the node after consuming all hash bytes"); +} + +ThreadSafeTrieRawHashMapBase::PointerBase ThreadSafeTrieRawHashMapBase::insert( + PointerBase Hint, ArrayRef Hash, + function_ref Hash)> + Constructor) { + assert(!Hash.empty() && "Uninitialized hash"); + + ImplType &Impl = getOrCreateImpl(); + TrieSubtrie *S = Impl.getRoot(); + TrieHashIndexGenerator IndexGen{NumRootBits, NumSubtrieBits, Hash}; + size_t Index; + if (Hint.isHint()) { + S = static_cast(Hint.P); + Index = IndexGen.hint(Hint.I, Hint.B); + } else { + Index = IndexGen.next(); + } + + while (Index != IndexGen.end()) { + // Load the node from the slot, allocating and calling the constructor if + // the slot is empty. + bool Generated = false; + TrieNode &Existing = S->get(Index).loadOrGenerate([&]() { + Generated = true; + + // Construct the value itself at the tail. + uint8_t *Memory = reinterpret_cast( + Impl.ContentAlloc.Allocate(ContentAllocSize, ContentAllocAlign)); + const uint8_t *HashStorage = Constructor(Memory + ContentOffset, Hash); + + // Construct the TrieContent header, passing in the offset to the hash. + TrieContent *Content = ::new (Memory) + TrieContent(ContentOffset, Hash.size(), HashStorage - Memory); + assert(Hash == Content->getHash() && "Hash not properly initialized"); + return Content; + }); + // If we just generated it, return it! + if (Generated) + return PointerBase(cast(Existing).getValuePointer()); + + if (auto *ST = dyn_cast(&Existing)) { + S = ST; + Index = IndexGen.next(); + continue; + } + + // Return the existing content if it's an exact match! + auto &ExistingContent = cast(Existing); + if (ExistingContent.getHash() == Hash) + return PointerBase(ExistingContent.getValuePointer()); + + // Sink the existing content as long as the indexes match. + size_t NextIndex = IndexGen.next(); + while (NextIndex != IndexGen.end()) { + size_t NewIndexForExistingContent = + IndexGen.getCollidingBits(ExistingContent.getHash()); + S = S->sink(Index, ExistingContent, IndexGen.getNumBits(), + NewIndexForExistingContent, + [&Impl](std::unique_ptr S) { + return Impl.save(std::move(S)); + }); + Index = NextIndex; + + // Found the difference. + if (NextIndex != NewIndexForExistingContent) + break; + + NextIndex = IndexGen.next(); + } + } + llvm_unreachable("failed to insert the node after consuming all hash bytes"); +} + +ThreadSafeTrieRawHashMapBase::ThreadSafeTrieRawHashMapBase( + size_t ContentAllocSize, size_t ContentAllocAlign, size_t ContentOffset, + std::optional NumRootBits, std::optional NumSubtrieBits) + : ContentAllocSize(ContentAllocSize), ContentAllocAlign(ContentAllocAlign), + ContentOffset(ContentOffset), + NumRootBits(NumRootBits ? *NumRootBits : DefaultNumRootBits), + NumSubtrieBits(NumSubtrieBits ? *NumSubtrieBits : DefaultNumSubtrieBits), + ImplPtr(nullptr) { + // Assertion checks for reasonable configuration. The settings below are not + // hard limits on most platforms, but a reasonable configuration should fall + // within those limits. + assert((!NumRootBits || *NumRootBits < 20) && + "Root should have fewer than ~1M slots"); + assert((!NumSubtrieBits || *NumSubtrieBits < 10) && + "Subtries should have fewer than ~1K slots"); +} + +ThreadSafeTrieRawHashMapBase::ThreadSafeTrieRawHashMapBase( + ThreadSafeTrieRawHashMapBase &&RHS) + : ContentAllocSize(RHS.ContentAllocSize), + ContentAllocAlign(RHS.ContentAllocAlign), + ContentOffset(RHS.ContentOffset), NumRootBits(RHS.NumRootBits), + NumSubtrieBits(RHS.NumSubtrieBits) { + // Steal the root from RHS. + ImplPtr = RHS.ImplPtr.exchange(nullptr); +} + +ThreadSafeTrieRawHashMapBase::~ThreadSafeTrieRawHashMapBase() { + assert(!ImplPtr.load() && "Expected subclass to call destroyImpl()"); +} + +void ThreadSafeTrieRawHashMapBase::destroyImpl( + function_ref Destructor) { + std::unique_ptr Impl(ImplPtr.exchange(nullptr)); + if (!Impl) + return; + + // Destroy content nodes throughout trie. Avoid destroying any subtries since + // we need TrieNode::classof() to find the content nodes. + // + // FIXME: Once we have bitsets (see FIXME in TrieSubtrie class), use them + // facilitate sparse iteration here. + if (Destructor) + for (TrieSubtrie *Trie = Impl->getRoot(); Trie; Trie = Trie->Next.load()) + for (unsigned I = 0; I < Trie->size(); ++I) + if (auto *Content = dyn_cast_or_null(Trie->load(I))) + Destructor(Content->getValuePointer()); + + // Destroy the subtries. Incidentally, this destroys them in the reverse order + // of saving. + TrieSubtrie *Trie = Impl->getRoot()->Next; + while (Trie) { + TrieSubtrie *Next = Trie->Next.exchange(nullptr); + delete Trie; + Trie = Next; + } +} + +ThreadSafeTrieRawHashMapBase::PointerBase +ThreadSafeTrieRawHashMapBase::getRoot() const { + ImplType *Impl = ImplPtr.load(); + if (!Impl) + return PointerBase(); + return PointerBase(Impl->getRoot()); +} + +unsigned ThreadSafeTrieRawHashMapBase::getStartBit( + ThreadSafeTrieRawHashMapBase::PointerBase P) const { + assert(!P.isHint() && "Not a valid trie"); + if (!P.P) + return 0; + if (auto *S = dyn_cast((TrieNode *)P.P)) + return S->StartBit; + return 0; +} + +unsigned ThreadSafeTrieRawHashMapBase::getNumBits( + ThreadSafeTrieRawHashMapBase::PointerBase P) const { + assert(!P.isHint() && "Not a valid trie"); + if (!P.P) + return 0; + if (auto *S = dyn_cast((TrieNode *)P.P)) + return S->NumBits; + return 0; +} + +unsigned ThreadSafeTrieRawHashMapBase::getNumSlotUsed( + ThreadSafeTrieRawHashMapBase::PointerBase P) const { + assert(!P.isHint() && "Not a valid trie"); + if (!P.P) + return 0; + auto *S = dyn_cast((TrieNode *)P.P); + if (!S) + return 0; + unsigned Num = 0; + for (unsigned I = 0, E = S->size(); I < E; ++I) + if (auto *E = S->load(I)) + ++Num; + return Num; +} + +std::string ThreadSafeTrieRawHashMapBase::getTriePrefixAsString( + ThreadSafeTrieRawHashMapBase::PointerBase P) const { + assert(!P.isHint() && "Not a valid trie"); + if (!P.P) + return ""; + + auto *S = dyn_cast((TrieNode *)P.P); + if (!S || !S->IsSubtrie) + return ""; + + // Find a TrieContent node which has hash stored. Depth search following the + // first used slot until a TrieContent node is found. + TrieSubtrie *Current = S; + TrieContent *Node = nullptr; + while (Current) { + TrieSubtrie *Next = nullptr; + // Find first used slot in the trie. + for (unsigned I = 0, E = Current->size(); I < E; ++I) { + auto *S = Current->load(I); + if (!S) + continue; + + if (auto *Content = dyn_cast(S)) + Node = Content; + else if (auto *Sub = dyn_cast(S)) + Next = Sub; + break; + } + + // Found the node. + if (Node) + break; + + // Continue to the next level if the node is not found. + Current = Next; + } + + assert(Node && "malformed trie, cannot find TrieContent on leaf node"); + // The prefix for the current trie is the first `StartBit` of the content + // stored underneath this subtrie. + std::string Str; + raw_string_ostream SS(Str); + + unsigned StartFullBytes = (S->StartBit + 1) / 8 - 1; + SS << toHex(toStringRef(Node->getHash()).take_front(StartFullBytes), + /*LowerCase=*/true); + + // For the part of the prefix that doesn't fill a byte, print raw bit values. + std::string Bits; + for (unsigned I = StartFullBytes * 8, E = S->StartBit; I < E; ++I) { + unsigned Index = I / 8; + unsigned Offset = 7 - I % 8; + Bits.push_back('0' + ((Node->getHash()[Index] >> Offset) & 1)); + } + + if (!Bits.empty()) + SS << "[" << Bits << "]"; + + return SS.str(); +} + +unsigned ThreadSafeTrieRawHashMapBase::getNumTries() const { + ImplType *Impl = ImplPtr.load(); + if (!Impl) + return 0; + unsigned Num = 0; + for (TrieSubtrie *Trie = Impl->getRoot(); Trie; Trie = Trie->Next.load()) + ++Num; + return Num; +} + +ThreadSafeTrieRawHashMapBase::PointerBase +ThreadSafeTrieRawHashMapBase::getNextTrie( + ThreadSafeTrieRawHashMapBase::PointerBase P) const { + assert(!P.isHint() && "Not a valid trie"); + if (!P.P) + return PointerBase(); + auto *S = dyn_cast((TrieNode *)P.P); + if (!S) + return PointerBase(); + if (auto *E = S->Next.load()) + return PointerBase(E); + return PointerBase(); +} diff --git a/llvm/unittests/ADT/CMakeLists.txt b/llvm/unittests/ADT/CMakeLists.txt index 745e4d9fb74a4..b0077d5b54a3e 100644 --- a/llvm/unittests/ADT/CMakeLists.txt +++ b/llvm/unittests/ADT/CMakeLists.txt @@ -86,6 +86,7 @@ add_llvm_unittest(ADTTests StringSetTest.cpp StringSwitchTest.cpp TinyPtrVectorTest.cpp + TrieRawHashMapTest.cpp TwineTest.cpp TypeSwitchTest.cpp TypeTraitsTest.cpp diff --git a/llvm/unittests/ADT/TrieRawHashMapTest.cpp b/llvm/unittests/ADT/TrieRawHashMapTest.cpp new file mode 100644 index 0000000000000..c9081f547812e --- /dev/null +++ b/llvm/unittests/ADT/TrieRawHashMapTest.cpp @@ -0,0 +1,346 @@ +//===- TrieRawHashMapTest.cpp ---------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/ADT/TrieRawHashMap.h" +#include "llvm/ADT/Twine.h" +#include "llvm/Support/Endian.h" +#include "llvm/Support/SHA1.h" +#include "gtest/gtest.h" + +using namespace llvm; + +namespace llvm { +class TrieRawHashMapTestHelper { +public: + TrieRawHashMapTestHelper() = default; + + void setTrie(ThreadSafeTrieRawHashMapBase *T) { Trie = T; } + + ThreadSafeTrieRawHashMapBase::PointerBase getRoot() const { + return Trie->getRoot(); + } + unsigned getStartBit(ThreadSafeTrieRawHashMapBase::PointerBase P) const { + return Trie->getStartBit(P); + } + unsigned getNumBits(ThreadSafeTrieRawHashMapBase::PointerBase P) const { + return Trie->getNumBits(P); + } + unsigned getNumSlotUsed(ThreadSafeTrieRawHashMapBase::PointerBase P) const { + return Trie->getNumSlotUsed(P); + } + unsigned getNumTries() const { return Trie->getNumTries(); } + std::string + getTriePrefixAsString(ThreadSafeTrieRawHashMapBase::PointerBase P) const { + return Trie->getTriePrefixAsString(P); + } + ThreadSafeTrieRawHashMapBase::PointerBase + getNextTrie(ThreadSafeTrieRawHashMapBase::PointerBase P) const { + return Trie->getNextTrie(P); + } + +private: + ThreadSafeTrieRawHashMapBase *Trie = nullptr; +}; +} // namespace llvm + +namespace { +template +class SimpleTrieHashMapTest : public TrieRawHashMapTestHelper, + public ::testing::Test { +public: + using NumType = DataType; + using HashType = std::array; + using TrieType = ThreadSafeTrieRawHashMap; + + TrieType &createTrie(size_t RootBits, size_t SubtrieBits) { + auto &Ret = Trie.emplace(RootBits, SubtrieBits); + TrieRawHashMapTestHelper::setTrie(&Ret); + return Ret; + } + + void destroyTrie() { Trie.reset(); } + ~SimpleTrieHashMapTest() { destroyTrie(); } + + // Use the number itself as hash to test the pathological case. + static HashType hash(uint64_t Num) { + uint64_t HashN = + llvm::support::endian::byte_swap(Num, llvm::endianness::big); + HashType Hash; + memcpy(&Hash[0], &HashN, sizeof(HashType)); + return Hash; + }; + +private: + std::optional Trie; +}; + +using SmallNodeTrieTest = SimpleTrieHashMapTest; + +TEST_F(SmallNodeTrieTest, TrieAllocation) { + NumType Numbers[] = { + 0x0, std::numeric_limits::max(), 0x1, 0x2, + 0x3, std::numeric_limits::max() - 1u, + }; + + unsigned ExpectedTries[] = { + 1, // Allocate Root. + 1, // Both on the root. + 64, // 0 and 1 sinks all the way down. + 64, // no new allocation needed. + 65, // need a new node between 2 and 3. + 65 + 63, // 63 new allocation to sink two big numbers all the way. + }; + + const char *ExpectedPrefix[] = { + "", // Root. + "", // Root. + "00000000000000[0000000]", + "00000000000000[0000000]", + "00000000000000[0000001]", + "ffffffffffffff[1111111]", + }; + + // Use root and subtrie sizes of 1 so this gets sunk quite deep. + auto &Trie = createTrie(/*RootBits=*/1, /*SubtrieBits=*/1); + + for (unsigned I = 0; I < 6; ++I) { + // Lookup first to exercise hint code for deep tries. + TrieType::pointer Lookup = Trie.find(hash(Numbers[I])); + EXPECT_FALSE(Lookup); + + Trie.insert(Lookup, TrieType::value_type(hash(Numbers[I]), Numbers[I])); + EXPECT_EQ(getNumTries(), ExpectedTries[I]); + EXPECT_EQ(getTriePrefixAsString(getNextTrie(getRoot())), ExpectedPrefix[I]); + } +} + +TEST_F(SmallNodeTrieTest, TrieStructure) { + NumType Numbers[] = { + // Three numbers that will nest deeply to test (1) sinking subtries and + // (2) deep, non-trivial hints. + std::numeric_limits::max(), + std::numeric_limits::max() - 2u, + std::numeric_limits::max() - 3u, + // One number to stay at the top-level. + 0x37, + }; + + // Use root and subtrie sizes of 1 so this gets sunk quite deep. + auto &Trie = createTrie(/*RootBits=*/1, /*SubtrieBits=*/1); + + for (NumType N : Numbers) { + // Lookup first to exercise hint code for deep tries. + TrieType::pointer Lookup = Trie.find(hash(N)); + EXPECT_FALSE(Lookup); + + Trie.insert(Lookup, TrieType::value_type(hash(N), N)); + } + for (NumType N : Numbers) { + TrieType::pointer Lookup = Trie.find(hash(N)); + EXPECT_TRUE(Lookup); + if (!Lookup) + continue; + EXPECT_EQ(hash(N), Lookup->Hash); + EXPECT_EQ(N, Lookup->Data); + + // Confirm a subsequent insertion fails to overwrite by trying to insert a + // bad value. + auto Result = Trie.insert(Lookup, TrieType::value_type(hash(N), N - 1)); + EXPECT_EQ(N, Result->Data); + } + + // Check the trie so we can confirm the structure is correct. Each subtrie + // should have 2 slots. The root's index=0 should have the content for + // 0x37 directly, and index=1 should be a linked-list of subtries, finally + // ending with content for (max-2) and (max-3). + // + // Note: This structure is not exhaustive (too expensive to update tests), + // but it does test that the dump format is somewhat readable and that the + // basic structure is correct. + // + // Note: This test requires that the trie reads bytes starting from index 0 + // of the array of uint8_t, and then reads each byte's bits from high to low. + + // Check the Trie. + // We should allocated a total of 64 SubTries for 64 bit hash. + ASSERT_EQ(getNumTries(), 64u); + // Check the root trie. Two slots and both are used. + ASSERT_EQ(getNumSlotUsed(getRoot()), 2u); + // Check last subtrie. + // Last allocated trie is the next node in the allocation chain. + auto LastAlloctedSubTrie = getNextTrie(getRoot()); + ASSERT_EQ(getTriePrefixAsString(LastAlloctedSubTrie), + "ffffffffffffff[1111110]"); + ASSERT_EQ(getStartBit(LastAlloctedSubTrie), 63u); + ASSERT_EQ(getNumBits(LastAlloctedSubTrie), 1u); + ASSERT_EQ(getNumSlotUsed(LastAlloctedSubTrie), 2u); +} + +TEST_F(SmallNodeTrieTest, TrieStructureSmallFinalSubtrie) { + NumType Numbers[] = { + // Three numbers that will nest deeply to test (1) sinking subtries and + // (2) deep, non-trivial hints. + std::numeric_limits::max(), + std::numeric_limits::max() - 2u, + std::numeric_limits::max() - 3u, + // One number to stay at the top-level. + 0x37, + }; + + // Use subtrie size of 5 to avoid hitting 64 evenly, making the final subtrie + // small. + auto &Trie = createTrie(/*RootBits=*/8, /*SubtrieBits=*/5); + + for (NumType N : Numbers) { + // Lookup first to exercise hint code for deep tries. + TrieType::pointer Lookup = Trie.find(hash(N)); + EXPECT_FALSE(Lookup); + + Trie.insert(Lookup, TrieType::value_type(hash(N), N)); + } + for (NumType N : Numbers) { + TrieType::pointer Lookup = Trie.find(hash(N)); + ASSERT_TRUE(Lookup); + EXPECT_EQ(hash(N), Lookup->Hash); + EXPECT_EQ(N, Lookup->Data); + + // Confirm a subsequent insertion fails to overwrite by trying to insert a + // bad value. + auto Result = Trie.insert(Lookup, TrieType::value_type(hash(N), N - 1)); + EXPECT_EQ(N, Result->Data); + } + + // Check the trie so we can confirm the structure is correct. The root + // should have 2^8=256 slots, most subtries should have 2^5=32 slots, and the + // deepest subtrie should have 2^1=2 slots (since (64-8)mod(5)=1). + // should have 2 slots. The root's index=0 should have the content for + // 0x37 directly, and index=1 should be a linked-list of subtries, finally + // ending with content for (max-2) and (max-3). + // + // Note: This structure is not exhaustive (too expensive to update tests), + // but it does test that the dump format is somewhat readable and that the + // basic structure is correct. + // + // Note: This test requires that the trie reads bytes starting from index 0 + // of the array of uint8_t, and then reads each byte's bits from high to low. + + // Check the Trie. + // 64 bit hash = 8 + 5 * 11 + 1, so 1 root, 11 8bit subtrie and 1 last level + // subtrie, 13 total. + ASSERT_EQ(getNumTries(), 13u); + // Check the root trie. Two slots and both are used. + ASSERT_EQ(getNumSlotUsed(getRoot()), 2u); + // Check last subtrie. + // Last allocated trie is the next node in the allocation chain. + auto LastAlloctedSubTrie = getNextTrie(getRoot()); + ASSERT_EQ(getTriePrefixAsString(LastAlloctedSubTrie), + "ffffffffffffff[1111110]"); + ASSERT_EQ(getStartBit(LastAlloctedSubTrie), 63u); + ASSERT_EQ(getNumBits(LastAlloctedSubTrie), 1u); + ASSERT_EQ(getNumSlotUsed(LastAlloctedSubTrie), 2u); +} + +TEST_F(SmallNodeTrieTest, TrieDestructionLoop) { + // Test destroying large Trie. Make sure there is no recursion that can + // overflow the stack. + + // Limit the tries to 2 slots (1 bit) to generate subtries at a higher rate. + auto &Trie = createTrie(/*NumRootBits=*/1, /*NumSubtrieBits=*/1); + + // Fill them up. Pick a MaxN high enough to cause a stack overflow in debug + // builds. + static constexpr uint64_t MaxN = 100000; + for (uint64_t N = 0; N != MaxN; ++N) { + HashType Hash = hash(N); + Trie.insert(TrieType::pointer(), TrieType::value_type(Hash, NumType{N})); + } + + // Destroy tries. If destruction is recursive and MaxN is high enough, these + // will both fail. + destroyTrie(); +} + +struct NumWithDestructorT { + uint64_t Num; + llvm::function_ref DestructorCallback; + ~NumWithDestructorT() { DestructorCallback(); } +}; + +using NodeWithDestructorTrieTest = SimpleTrieHashMapTest; + +TEST_F(NodeWithDestructorTrieTest, TrieDestructionLoop) { + // Test destroying large Trie. Make sure there is no recursion that can + // overflow the stack. + + // Limit the tries to 2 slots (1 bit) to generate subtries at a higher rate. + auto &Trie = createTrie(/*NumRootBits=*/1, /*NumSubtrieBits=*/1); + + // Fill them up. Pick a MaxN high enough to cause a stack overflow in debug + // builds. + static constexpr uint64_t MaxN = 100000; + + uint64_t DestructorCalled = 0; + auto DtorCallback = [&DestructorCalled]() { ++DestructorCalled; }; + for (uint64_t N = 0; N != MaxN; ++N) { + HashType Hash = hash(N); + Trie.insert(TrieType::pointer(), + TrieType::value_type(Hash, NumType{N, DtorCallback})); + } + // Reset the count after all the temporaries get destroyed. + DestructorCalled = 0; + + // Destroy tries. If destruction is recursive and MaxN is high enough, these + // will both fail. + destroyTrie(); + + // Count the number of destructor calls during `destroyTrie()`. + ASSERT_EQ(DestructorCalled, MaxN); +} + +using NumStrNodeTrieTest = SimpleTrieHashMapTest; + +TEST_F(NumStrNodeTrieTest, TrieInsertLazy) { + for (unsigned RootBits : {2, 3, 6, 10}) { + for (unsigned SubtrieBits : {2, 3, 4}) { + auto &Trie = createTrie(RootBits, SubtrieBits); + for (int I = 0, E = 1000; I != E; ++I) { + TrieType::pointer Lookup; + HashType H = hash(I); + if (I & 1) + Lookup = Trie.find(H); + + auto insertNum = [&](uint64_t Num) { + std::string S = Twine(I).str(); + auto Hash = hash(Num); + return Trie.insertLazy( + Hash, [&](TrieType::LazyValueConstructor C) { C(std::move(S)); }); + }; + auto S1 = insertNum(I); + // The address of the Data should be the same. + EXPECT_EQ(&S1->Data, &insertNum(I)->Data); + + auto insertStr = [&](std::string S) { + int Num = std::stoi(S); + return insertNum(Num); + }; + std::string S2 = S1->Data; + // The address of the Data should be the same. + EXPECT_EQ(&S1->Data, &insertStr(S2)->Data); + } + for (int I = 0, E = 1000; I != E; ++I) { + std::string S = Twine(I).str(); + TrieType::pointer Lookup = Trie.find(hash(I)); + EXPECT_TRUE(Lookup); + if (!Lookup) + continue; + EXPECT_EQ(S, Lookup->Data); + } + } + } +} +} // end anonymous namespace