From 1818d8963e104eac7946268adb643eb3cff0cb8d Mon Sep 17 00:00:00 2001 From: Daniel Deptford Date: Tue, 11 Mar 2025 12:32:51 -0700 Subject: [PATCH] Implemented a more space efficient string<->integer map. (#9113) Summary: While investigating memory consumption, I noticed that the tiktoken loader was allocating 16MB of memory, mainly distributed over two std::unordered_maps. These two maps (the Encoder and Decoder types) in Tiktoken are inverses of each other, and the std::string objects contained therein are clones of each other. The allocation of a node in each map is 40 bytes (on aarch64 Android): * 2x doubly linked list pointers at 8 bytes each * 1 std::uint64_t (8 bytes) * 1 std::string (12 bytes, std::strings contain an internal buffer for small strings). Each node actually allocates 48 bytes of usable memory, as the allocator aligns the allocations to 16 byte boundaries. This implementation of the string/integer map has several features: * Sharing of the data payload between two hash indices. * Variable sized integers, variable sized string length fields and best fit allocation of string data. That is to say, the data payload elements are variable sized. 
The implemented unit tests track the memory size allocated between the old std::unordered_map method and the new StringIntegerMap method, yielding a ~6x improvement in the memory allocated: ``` string integer map size = 2623343 unordered map size = 16078928 ``` There was a significant speedup when looking up strings, although looking up integers was about the same: ```------------------------------------------------------------------------------------------------ Benchmark Time CPU Iterations ------------------------------------------------------------------------------------------------ BM_FindStringIntegerMapString/iterations:100 4722 us 4722 us 100 BM_FindStringIntegerMapInteger/iterations:100 529 us 529 us 100 BM_FindStringIntegerMapStringOptional/iterations:100 4714 us 4713 us 100 BM_FindStringIntegerMapIntegerOptional/iterations:100 537 us 536 us 100 BM_FindStdUnorderedMapString/iterations:100 7128 us 7127 us 100 BM_FindStdUnorderedMapInteger/iterations:100 536 us 536 us 100 ``` Reviewed By: swolchok, larryliu0820 Differential Revision: D69472841 --- extension/llm/tokenizer/string_integer_map.h | 569 ++++++++++++++++++ extension/llm/tokenizer/targets.bzl | 1 + extension/llm/tokenizer/test/CMakeLists.txt | 2 +- extension/llm/tokenizer/test/targets.bzl | 19 + .../test/test_string_integer_map.cpp | 318 ++++++++++ extension/llm/tokenizer/tiktoken.cpp | 105 ++-- extension/llm/tokenizer/tiktoken.h | 7 +- 7 files changed, 957 insertions(+), 64 deletions(-) create mode 100644 extension/llm/tokenizer/string_integer_map.h create mode 100644 extension/llm/tokenizer/test/test_string_integer_map.cpp diff --git a/extension/llm/tokenizer/string_integer_map.h b/extension/llm/tokenizer/string_integer_map.h new file mode 100644 index 00000000000..e8ff2d023e0 --- /dev/null +++ b/extension/llm/tokenizer/string_integer_map.h @@ -0,0 +1,569 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. 
+ * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace executorch::extension::llm { + +/** + * StringIntegerMap is an immutable bidirectional map between strings and 64 bit + * unsigned integers. The element data is stored in a contiguous array and is + * shared between both the string buckets and the integer buckets, offering a + * compact representation. + * + * Variable sized integers are used internally, which are sized based on the + * data being stored. Custom hash functions are supported, with a stateful hash + * functor being optionally provided at construction time. + */ +template < + typename TStringHash = std::hash, + typename TIntegerHash = std::hash, + typename TAllocator = std::allocator> +class StringIntegerMap { + public: + /// @name Constructors + /// @{ + + /// Default constructor is deleted, as this container is intended to be + /// constructed with a map of strings to integers. + StringIntegerMap() = delete; + + /** + * Construct a StringIntegerMap from a map of strings to integers. Each + * string and integer in the map must be unique. + * @param map map of strings to integers + */ + explicit StringIntegerMap( + const std::unordered_map& map); + + /** + * Construct a StringIntegerMap from a map of strings to integers, explicitly + * initializing the integer and string hash objects. Each string and integer + * in the map must be unique. + * @param map map of strings to integers + */ + StringIntegerMap( + const std::unordered_map& map, + TStringHash string_hasher, + TIntegerHash integer_hasher); + + /// @} + /// @name Accessors + /// @{ + + /** + * Attempts to retrieve the integer mapped for the given string. 
+ * @param str string to lookup + * @return a std::optional containing the integer if the string was found, + * std::nullopt otherwise + */ + std::optional tryGetInteger(std::string_view str) const; + + /** + * Attempts to retrieve the string mapped for the given integer. + * @param integer integer to lookup + * @return a std::optional containing the string if the integer was found, + * std::nullopt otherwise + */ + std::optional tryGetString(std::uint64_t integer) const; + + /// @} + + private: + template + class VariableSizedInteger { + public: + VariableSizedInteger() = default; + + explicit VariableSizedInteger(TLogical max_value) { + while (max_value != 0) { + ++byte_count_; + max_value >>= 8; + // Grow mask per byte; a single shift by the full type width is UB. + mask_ = (mask_ << 8) | TLogical(0xFF); + } + } + + std::size_t getByteCount() const { + return byte_count_; + } + + TLogical getMask() const { + return mask_; + } + + std::uint8_t* write(std::uint8_t* target, TLogical value) const { + std::memcpy(target, &value, byte_count_); + return target + byte_count_; + } + + TLogical read(const std::uint8_t* source) const { + TLogical value; + std::memcpy(&value, source, sizeof(TLogical)); + return value & mask_; + } + + private: + std::size_t byte_count_ = 0; + TLogical mask_ = 0; + }; + + bool tryGetInteger(std::string_view str, std::uint64_t& result) const; + + bool tryGetString(std::uint64_t integer, std::string_view& result) const; + + std::size_t getBucketIndex(std::string_view value) const; + + std::size_t getBucketIndex(std::uint64_t value) const; + + static std::uint8_t getSmallHash(std::size_t hash); + + /// The hasher used for strings. Its result selects the string bucket and + /// yields the small hash stored with each string element. + const TStringHash string_hasher_ = {}; + + /// The hasher used for integers. + const TIntegerHash integer_hasher_ = {}; + + /// Integer bucket references. + std::vector integer_bucket_data_; + + /// Integer bucket elements. 
+ /// Laid out as: + /// struct { + /// std::uint64_t integer; - Physically using integer_ bytes. + /// std::size_t string_size; - Physically using string_size_ bytes + /// std::size_t string_offset; - Physically using string_offset_ bytes + /// } + std::vector integer_element_data_; + + /// String bucket references. + std::vector string_bucket_data_; + + /// String bucket elements. + /// Laid out as: + /// struct { + /// std::uint64_t integer; - Physically using integer_ bytes. + /// std::size_t string_size; - Physically using string_size_ bytes + /// std::uint8_t small_hash; - Using std::uint8_t bytes. + /// char string[string_size]; - String data, not zero terminated. + /// } + std::vector string_element_data_; + + /// Number of hash buckets to use. + std::size_t bucket_count_ = 0; + + /// Variable sized element offset info. + VariableSizedInteger element_offset_; + + /// Variable size string offset info. + VariableSizedInteger string_offset_; + + /// Variable sized string size info. + VariableSizedInteger string_size_; + + /// Variable sized integer info. + VariableSizedInteger integer_; +}; + +template +StringIntegerMap::StringIntegerMap( + const std::unordered_map& map) + : StringIntegerMap(map, TStringHash(), TIntegerHash()) {} + +template +StringIntegerMap::StringIntegerMap( + const std::unordered_map& map, + TStringHash string_hasher, + TIntegerHash integer_hasher) + : string_hasher_(string_hasher), integer_hasher_(integer_hasher) { + assert(map.size() <= std::numeric_limits::max()); + bucket_count_ = map.size(); + + struct BuilderElement { + std::uint64_t integer = 0; + std::string_view string; + std::size_t hash = 0; + std::size_t element_offset = 0; + }; + + std::vector builder_string_elements; + std::vector builder_integer_elements; + + // + // Calculate various item sizes and gather the builder elements. 
+ // + + std::size_t largest_string_size = 0; + std::uint64_t largest_integer = 0; + std::size_t total_string_size = 0; + + for (const auto& [str, integer] : map) { + total_string_size += str.size(); + largest_string_size = std::max(largest_string_size, str.size()); + largest_integer = std::max(largest_integer, integer); + builder_string_elements.push_back({integer, str, string_hasher_(str)}); + builder_integer_elements.push_back( + {integer, str, integer_hasher_(integer)}); + } + + integer_ = VariableSizedInteger(largest_integer); + string_size_ = VariableSizedInteger(largest_string_size); + string_offset_ = VariableSizedInteger(total_string_size); + + const auto string_element_data_size = + ((integer_.getByteCount() + string_size_.getByteCount() + 1) * + map.size()) + + total_string_size; + const auto integer_element_size = integer_.getByteCount() + + string_offset_.getByteCount() + string_size_.getByteCount(); + const auto integer_element_data_size = integer_element_size * map.size(); + + element_offset_ = VariableSizedInteger( + std::max(string_element_data_size, integer_element_data_size)); + + string_bucket_data_.resize( + ((bucket_count_ + 1) * element_offset_.getByteCount()) + + sizeof(std::uint64_t)); + integer_bucket_data_.resize( + ((bucket_count_ + 1) * element_offset_.getByteCount()) + + sizeof(std::uint64_t)); + + // + // Set up terminal bucket indices. + // + + element_offset_.write( + string_bucket_data_.data() + + (bucket_count_ * element_offset_.getByteCount()), + string_element_data_size); + element_offset_.write( + integer_bucket_data_.data() + + (bucket_count_ * element_offset_.getByteCount()), + integer_element_data_size); + // + // Sort the builder elements. 
+ // + + std::sort( + std::begin(builder_string_elements), + std::end(builder_string_elements), + [this](const BuilderElement& first, const BuilderElement& second) { + const auto first_bucket = first.hash % bucket_count_; + const auto second_bucket = second.hash % bucket_count_; + if (first_bucket == second_bucket) { + const auto first_small_hash = getSmallHash(first.hash); + const auto second_small_hash = getSmallHash(second.hash); + return first_small_hash < second_small_hash; + } + + return first_bucket < second_bucket; + }); + + std::sort( + std::begin(builder_integer_elements), + std::end(builder_integer_elements), + [this](const BuilderElement& first, const BuilderElement& second) { + const auto first_bucket = first.hash % bucket_count_; + const auto second_bucket = second.hash % bucket_count_; + if (first_bucket == second_bucket) { + return first.integer < second.integer; + } + + return first_bucket < second_bucket; + }); + + // + // Lay out the string elements and record their positions. 
+ // + + std::unordered_map + string_element_byte_index_map; + string_element_data_.resize(string_element_data_size + sizeof(std::uint64_t)); + auto* string_element = string_element_data_.data(); + for (auto& builder_element : builder_string_elements) { + builder_element.element_offset = + string_element - string_element_data_.data(); + + auto insert_result = string_element_byte_index_map.insert( + {builder_element.string, builder_element.element_offset}); + assert(insert_result.second); + (void)insert_result; + + string_element = integer_.write(string_element, builder_element.integer); + string_element = + string_size_.write(string_element, builder_element.string.size()); + *string_element = getSmallHash(builder_element.hash); + string_element++; + std::memcpy( + string_element, + builder_element.string.data(), + builder_element.string.size()); + string_element += builder_element.string.size(); + assert( + string_element >= string_element_data_.data() && + string_element <= + string_element_data_.data() + string_element_data_size); + } + + // + // Lay out the integer elements. 
+ // + + integer_element_data_.resize( + integer_element_data_size + sizeof(std::uint64_t)); + auto* integer_element = integer_element_data_.data(); + for (auto& builder_element : builder_integer_elements) { + builder_element.element_offset = + integer_element - integer_element_data_.data(); + auto string_element_byte_index_iter = + string_element_byte_index_map.find(builder_element.string); + assert( + string_element_byte_index_iter != + std::end(string_element_byte_index_map)); + integer_element = integer_.write(integer_element, builder_element.integer); + integer_element = + string_size_.write(integer_element, builder_element.string.size()); + integer_element = string_offset_.write( + integer_element, string_element_byte_index_iter->second); + assert( + integer_element >= integer_element_data_.data() && + integer_element <= + integer_element_data_.data() + integer_element_data_size); + } + + // + // Both the string elements and integer elements are laid out in order of + // their respective hashes. Generate the hash indexes for the string elements + // and integer elements. 
+ // + + auto builder_string_elements_iter = std::begin(builder_string_elements); + auto builder_integer_elements_iter = std::begin(builder_integer_elements); + + for (std::size_t bucket_idx = 0; bucket_idx < bucket_count_; ++bucket_idx) { + auto* string_bucket = string_bucket_data_.data() + + (bucket_idx * element_offset_.getByteCount()); + if (builder_string_elements_iter != std::end(builder_string_elements)) { + element_offset_.write( + string_bucket, builder_string_elements_iter->element_offset); + } else { + element_offset_.write(string_bucket, string_element_data_size); + } + + auto* integer_bucket = integer_bucket_data_.data() + + (bucket_idx * element_offset_.getByteCount()); + if (builder_integer_elements_iter != std::end(builder_integer_elements)) { + element_offset_.write( + integer_bucket, builder_integer_elements_iter->element_offset); + } else { + element_offset_.write(integer_bucket, integer_element_data_size); + } + + // + // Advance the string element iterator past all string elements that map + // into this bucket. + // + + while (builder_string_elements_iter != std::end(builder_string_elements) && + getBucketIndex(builder_string_elements_iter->string) == bucket_idx) { + ++builder_string_elements_iter; + } + + // + // Advance the integer element index past all integer elements that map into + // this bucket. + // + + while ( + builder_integer_elements_iter != std::end(builder_integer_elements) && + getBucketIndex(builder_integer_elements_iter->integer) == bucket_idx) { + ++builder_integer_elements_iter; + } + } +} + +template +std::optional +StringIntegerMap::tryGetInteger( + std::string_view str) const { + std::uint64_t result; + return tryGetInteger(str, result) ? 
std::optional(result) + : std::nullopt; +} + +template +bool StringIntegerMap::tryGetInteger( + std::string_view str, + std::uint64_t& result) const { + if (bucket_count_ == 0) { + return false; + } + + const auto hash = string_hasher_(str); + const auto bucket_index = hash % bucket_count_; + const auto small_hash = getSmallHash(hash); + + const auto* bucket_data = string_bucket_data_.data() + + (bucket_index * element_offset_.getByteCount()); + const auto lower_element_offset = element_offset_.read(bucket_data); + const auto upper_element_offset = + element_offset_.read(bucket_data + element_offset_.getByteCount()); + + const auto integer_size = integer_.getByteCount(); + const auto string_size_size = string_size_.getByteCount(); + + std::size_t element_size = 0; + auto* element_data_end = string_element_data_.data() + upper_element_offset; + for (auto* element_data = string_element_data_.data() + lower_element_offset; + element_data < element_data_end; + element_data += element_size) { + // + // Read the string length. + // + + const auto element_string_length = + string_size_.read(element_data + integer_size); + element_size = integer_size + string_size_size + 1 + element_string_length; + + // + // Read the string small hash. + // + + const auto element_small_hash = + element_data[integer_size + string_size_size]; + if (element_small_hash < small_hash) { + continue; + } else if (element_small_hash > small_hash) { + break; + } + + // + // Get a view on the string for a full comparison. + // + + std::string_view element_string( + reinterpret_cast( + element_data + integer_size + string_size_size + 1), + element_string_length); + if (str == element_string) { + result = integer_.read(element_data); + return true; + } + } + + return false; +} + +template +std::optional +StringIntegerMap::tryGetString( + std::uint64_t integer) const { + std::string_view result; + return tryGetString(integer, result) ? 
std::optional(result) + : std::nullopt; +} + +template +bool StringIntegerMap::tryGetString( + std::uint64_t integer, + std::string_view& result) const { + if (bucket_count_ == 0) { + return false; + } + + const auto bucket_index = getBucketIndex(integer); + + const auto* bucket_data = integer_bucket_data_.data() + + (bucket_index * element_offset_.getByteCount()); + const auto lower_element_offset = element_offset_.read(bucket_data); + const auto upper_element_offset = + element_offset_.read(bucket_data + element_offset_.getByteCount()); + + const auto integer_element_size = integer_.getByteCount() + + string_offset_.getByteCount() + string_size_.getByteCount(); + auto* element_data_end = integer_element_data_.data() + upper_element_offset; + for (auto* element_data = integer_element_data_.data() + lower_element_offset; + element_data < element_data_end; + element_data += integer_element_size) { + const auto element_integer = integer_.read(element_data); + if (element_integer == integer) { + const auto element_string_size = + string_size_.read(element_data + integer_.getByteCount()); + const auto element_string_offset = string_offset_.read( + element_data + integer_.getByteCount() + string_size_.getByteCount()); + const auto* string_element = + string_element_data_.data() + element_string_offset; + const auto* string_data = reinterpret_cast( + string_element + integer_.getByteCount() + + string_size_.getByteCount() + 1); + result = std::string_view(string_data, element_string_size); + return true; + } else if (element_integer > integer) { + break; + } + } + + return false; +} + +template +std::size_t +StringIntegerMap::getBucketIndex( + std::string_view value) const { + return string_hasher_(value) % bucket_count_; +} + +template +std::size_t +StringIntegerMap::getBucketIndex( + std::uint64_t value) const { + return integer_hasher_(value) % bucket_count_; +} + +template +std::uint8_t +StringIntegerMap::getSmallHash( + std::size_t hash) { + const auto shift = 
(sizeof(std::size_t) * 8) - 8; + return static_cast(hash >> shift); +} + +template < + typename TStringHash = std::hash, + typename TIntegerHash = std::hash, + typename TAllocator = std::allocator> +struct StringIntegerMapTypeBuilder { + using Map = StringIntegerMap; + + template + using WithStringHash = + StringIntegerMapTypeBuilder; + + template + using WithIntegerHash = + StringIntegerMapTypeBuilder; + + template + using WithAllocator = + StringIntegerMapTypeBuilder; +}; +} // namespace executorch::extension::llm diff --git a/extension/llm/tokenizer/targets.bzl b/extension/llm/tokenizer/targets.bzl index 1a590c7876f..7b545054390 100644 --- a/extension/llm/tokenizer/targets.bzl +++ b/extension/llm/tokenizer/targets.bzl @@ -83,6 +83,7 @@ def define_common_targets(): exported_headers = [ "tiktoken.h", "base64.h", + "string_integer_map.h", ], exported_deps = [ ":tokenizer_header", diff --git a/extension/llm/tokenizer/test/CMakeLists.txt b/extension/llm/tokenizer/test/CMakeLists.txt index ffc37f9e46f..9ddcf48518e 100644 --- a/extension/llm/tokenizer/test/CMakeLists.txt +++ b/extension/llm/tokenizer/test/CMakeLists.txt @@ -19,7 +19,7 @@ include(${EXECUTORCH_ROOT}/build/Test.cmake) set(test_env "RESOURCES_PATH=${EXECUTORCH_ROOT}/extension/llm/tokenizer/test/resources") -set(_test_srcs test_bpe_tokenizer.cpp test_tiktoken.cpp) +set(_test_srcs test_bpe_tokenizer.cpp test_tiktoken.cpp test_string_integer_map.cpp) et_cxx_test( extension_llm_tokenizer_test SOURCES ${_test_srcs} EXTRA_LIBS diff --git a/extension/llm/tokenizer/test/targets.bzl b/extension/llm/tokenizer/test/targets.bzl index 2c314a98230..8755ae6273f 100644 --- a/extension/llm/tokenizer/test/targets.bzl +++ b/extension/llm/tokenizer/test/targets.bzl @@ -22,6 +22,20 @@ def define_common_targets(): ], ) + runtime.cxx_test( + name = "test_string_integer_map", + srcs = [ + "test_string_integer_map.cpp", + ], + deps = [ + "//executorch/extension/llm/tokenizer:tiktoken", + ], + env = { + "RESOURCES_PATH": 
"$(location :resources)/resources", + }, + platforms = [CXX, ANDROID], # Cannot bundle resources on Apple platform. + ) + runtime.cxx_test( name = "test_bpe_tokenizer", srcs = [ @@ -59,3 +73,8 @@ def define_common_targets(): "resources/**", ]), ) + + runtime.export_file( + name = "test_tiktoken_tokenizer_model", + src = "resources/test_tiktoken_tokenizer.model", + ) diff --git a/extension/llm/tokenizer/test/test_string_integer_map.cpp b/extension/llm/tokenizer/test/test_string_integer_map.cpp new file mode 100644 index 00000000000..24a9853429d --- /dev/null +++ b/extension/llm/tokenizer/test/test_string_integer_map.cpp @@ -0,0 +1,318 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#ifdef EXECUTORCH_FB_BUCK +#include +#endif +#include +#include +#include +#include +#include +#include +#include + +#if defined(__APPLE__) || defined(WIN32) || defined(__linux__) +#define TEST_MEMORY_COMPARISON 1 + +#if defined(__APPLE__) +#include +#else +#include +#endif +#endif + +using namespace ::testing; +using ::executorch::extension::llm::StringIntegerMap; +using ::executorch::extension::llm::StringIntegerMapTypeBuilder; +using ::executorch::extension::llm::base64::decode; +using ::executorch::runtime::Error; +using ::executorch::runtime::Result; +using TokenizerMap = std::unordered_map; + +class StringIntegerMapTest : public Test { + public: + void SetUp() override { + ::executorch::runtime::runtime_init(); +#ifdef EXECUTORCH_FB_BUCK + modelPath_ = facebook::xplat::testing::getPathForTestResource( + "resources/test_tiktoken_tokenizer.model"); +#else + modelPath_ = std::getenv("RESOURCES_PATH") + + std::string("/test_tiktoken_tokenizer.model"); +#endif + } + + Result loadModel() { + std::ifstream file(modelPath_); + ET_CHECK_OR_RETURN_ERROR( + file, + InvalidArgument, + "failed to open encoder 
file: %s", + modelPath_.c_str()); + + TokenizerMap model; + for (std::string line; std::getline(file, line);) { + if (line.empty()) { + continue; + } + + auto pos = line.find(' '); + auto token = ET_UNWRAP(decode({line.data(), pos})); + uint64_t rank = 0; + try { + rank = std::stoul(line.substr(pos + 1)); + } catch (const std::exception&) { + ET_CHECK_OR_RETURN_ERROR( + false, InvalidArgument, "invalid encoder rank: %s", line.c_str()); + } + model[token] = rank; + } + + return model; + } + + std::string modelPath_; +}; + +#if defined(TEST_MEMORY_COMPARISON) && TEST_MEMORY_COMPARISON + +class TrackingAllocatorBase { + public: + static void reset(); + static std::size_t getSize(); + + protected: + static void* allocate(std::size_t size); + static void deallocate(void* ptr); + + static std::size_t size_; +}; + +void TrackingAllocatorBase::reset() { + size_ = 0; +} + +std::size_t TrackingAllocatorBase::getSize() { + return size_; +} + +void* TrackingAllocatorBase::allocate(std::size_t size) { + void* ptr = malloc(size); + if (!ptr) { + return nullptr; + } + +#if defined(WIN32) + size_ += _msize(ptr); +#elif defined(__APPLE__) + size_ += malloc_size(const_cast(ptr)); +#else + size_ += malloc_usable_size(ptr); +#endif + + return ptr; +} + +void TrackingAllocatorBase::deallocate(void* ptr) { + if (!ptr) { + return; + } + +#if defined(WIN32) + size_ -= _msize(ptr); +#elif defined(__APPLE__) + size_ -= malloc_size(const_cast(ptr)); +#else + size_ -= malloc_usable_size(ptr); +#endif + + free(ptr); +} + +std::size_t TrackingAllocatorBase::size_ = 0; + +template +class TrackingAllocator : public TrackingAllocatorBase { + public: + using value_type = T; + TrackingAllocator() noexcept = default; + template + explicit TrackingAllocator(TrackingAllocator const&) noexcept {} + + value_type* allocate(std::size_t count) { + return static_cast( + TrackingAllocatorBase::allocate(count * sizeof(value_type))); // NOLINT + } + + void deallocate(value_type* ptr, std::size_t /*count*/) 
noexcept { + TrackingAllocatorBase::deallocate(ptr); + } +}; + +template +bool operator==( + TrackingAllocator const&, + TrackingAllocator const&) noexcept { + return true; +} + +template +bool operator!=( + TrackingAllocator const& lhs, + TrackingAllocator const& rhs) noexcept { + return !(lhs == rhs); +} + +#endif + +TEST_F(StringIntegerMapTest, CreateFromModel) { + const auto res = loadModel(); + ASSERT_EQ(res.ok(), true); + const auto& model = res.get(); + StringIntegerMap map(model); + + for (const auto& [model_key, model_value] : model) { + EXPECT_THAT(map.tryGetInteger(model_key), testing::Optional(model_value)) + << model_key; + EXPECT_THAT(map.tryGetString(model_value), testing::Optional(model_key)) + << model_value; + } + + EXPECT_FALSE(map.tryGetInteger("Ich weiß nicht")); + EXPECT_FALSE(map.tryGetString(999999999)); +} + +#if defined(TEST_MEMORY_COMPARISON) && TEST_MEMORY_COMPARISON + +TEST_F(StringIntegerMapTest, MemoryConsumptionComparison) { + TrackingAllocatorBase::reset(); + EXPECT_EQ(TrackingAllocatorBase::getSize(), 0); + + const auto res = loadModel(); + ASSERT_EQ(res.ok(), true); + const auto& model = res.get(); + + std::size_t string_integer_map_size = 0; + { + typename StringIntegerMapTypeBuilder<>::WithAllocator< + TrackingAllocator>::Map map(model); + string_integer_map_size = TrackingAllocatorBase::getSize(); + } + + EXPECT_EQ(TrackingAllocatorBase::getSize(), 0); + + std::size_t unordered_map_size = 0; + { + std::unordered_map< + std::string, + std::uint64_t, + std::hash, + std::equal_to, + TrackingAllocator>> + strings_to_ints; + std::unordered_map< + std::uint64_t, + std::string, + std::hash, + std::equal_to, + TrackingAllocator>> + ints_to_strings; + for (const auto& [k, v] : model) { + strings_to_ints.emplace(k, v); + ints_to_strings.emplace(v, k); + } + + unordered_map_size = TrackingAllocatorBase::getSize(); + } + + EXPECT_LT(string_integer_map_size, unordered_map_size); + +#if 1 + std::cout << "string integer map size = " << 
string_integer_map_size + << std::endl; + std::cout << "unordered map size = " << unordered_map_size << std::endl; +#endif +} + +#endif + +template +struct FixedHash { + std::size_t operator()(const std::string_view& str) const { + if (str.empty()) { + return hash_offset; + } + + return str.size() - 1 + hash_offset; + } + + std::size_t operator()(std::uint64_t value) const { + if (value == 0) { + return hash_offset; + } + + return static_cast(std::log10(value)) + hash_offset; + } +}; + +template +class StringIntegerMapHashTest : public Test { + public: + using Container = typename StringIntegerMapTypeBuilder<>::WithIntegerHash< + THash>::template WithStringHash::Map; +}; + +using StringIntegerMapHashTestTypes = + ::testing::Types, FixedHash<1>, FixedHash<2>, FixedHash<3>>; +TYPED_TEST_SUITE(StringIntegerMapHashTest, StringIntegerMapHashTestTypes); + +TYPED_TEST(StringIntegerMapHashTest, HashCollisions) { + std::unordered_map source = { + {"a", 0}, + {"b", 1}, + {"c", 2}, + {"d", 3}, + }; + + typename TestFixture::Container map(source); + + // + // Check that the strings exist in the map. + // + + EXPECT_THAT(map.tryGetInteger("a"), Optional(0ull)); + EXPECT_THAT(map.tryGetInteger("b"), Optional(1ull)); + EXPECT_THAT(map.tryGetInteger("c"), Optional(2ull)); + EXPECT_THAT(map.tryGetInteger("d"), Optional(3ull)); + + EXPECT_FALSE(map.tryGetInteger("e")); + + // + // Check that the integers exist in the map. + // + + EXPECT_THAT(map.tryGetString(0), Optional(std::string_view("a"))); + EXPECT_THAT(map.tryGetString(1), Optional(std::string_view("b"))); + EXPECT_THAT(map.tryGetString(2), Optional(std::string_view("c"))); + EXPECT_THAT(map.tryGetString(3), Optional(std::string_view("d"))); + + EXPECT_FALSE(map.tryGetString(4)); + + // + // Test a lookup into the next bucket (which should be empty). 
+ // + + EXPECT_FALSE(map.tryGetInteger("aa")); + EXPECT_FALSE(map.tryGetInteger("aaa")); + EXPECT_FALSE(map.tryGetInteger("aaaa")); + + EXPECT_FALSE(map.tryGetString(10)); + EXPECT_FALSE(map.tryGetString(100)); + EXPECT_FALSE(map.tryGetString(1000)); +} diff --git a/extension/llm/tokenizer/tiktoken.cpp b/extension/llm/tokenizer/tiktoken.cpp index f99ac2e955e..725a3fe453d 100644 --- a/extension/llm/tokenizer/tiktoken.cpp +++ b/extension/llm/tokenizer/tiktoken.cpp @@ -29,7 +29,9 @@ #include #include #include +#include #include +#include using ::executorch::runtime::Error; using ::executorch::runtime::Result; @@ -97,6 +99,7 @@ static Result _load_encoder(const std::string& path) { Encoder encoder; std::string line; + std::unordered_set ranks; while (std::getline(file, line)) { auto [token, rank] = ET_UNWRAP(_parse(line)); @@ -105,28 +108,20 @@ static Result _load_encoder(const std::string& path) { InvalidArgument, "duplicate item: %s", line.c_str()); - } - return encoder; -} - -static Result _build_decoder(const Encoder& encoder) { - Decoder decoder; - for (const auto& [k, v] : encoder) { - decoder.emplace(v, k); + ET_CHECK_OR_RETURN_ERROR( + ranks.insert(rank).second, + InvalidArgument, + "duplicate rank: %s", + line.c_str()); } - ET_CHECK_OR_RETURN_ERROR( - encoder.size() == decoder.size(), - InvalidArgument, - "duplicate items in encoder"); - - return decoder; + return encoder; } static std::vector _byte_pair_merge( const std::string& piece, - const std::unordered_map& ranks, + const StringIntegerMap<>& ranks, std::function func) { // This is a vector of (start, rank). // The rank is of the byte pair starting at position start. 
@@ -145,10 +140,7 @@ static std::vector _byte_pair_merge( auto s = parts[start_idx].first; auto e = parts[start_idx + skip + 2].first; auto key = piece.substr(s, e - s); - auto iter = ranks.find(key); - if (iter != ranks.end()) { - return iter->second; - } + return ranks.tryGetInteger(key); } return std::nullopt; }; @@ -230,11 +222,11 @@ static std::vector _byte_pair_merge( static std::vector _byte_pair_encode( const std::string& piece, - const Encoder& encoder) { + const StringIntegerMap<>& tokenizer) { if (piece.size() == 1) { - auto iter = encoder.find(piece); - if (iter != encoder.end()) { - return std::vector({iter->second}); + const auto result = tokenizer.tryGetInteger(piece); + if (result) { + return std::vector<uint64_t>{*result}; // one element, not *result zeros } else { // TODO: is it possible? return {}; @@ -242,11 +234,11 @@ static std::vector _byte_pair_encode( } return _byte_pair_merge( - piece, encoder, [&piece, &encoder](uint64_t start, uint64_t stop) { + piece, tokenizer, [&piece, &tokenizer](uint64_t start, uint64_t stop) { std::string key = piece.substr(start, stop - start); - auto iter = encoder.find(key); - if (iter != encoder.end()) { - return iter->second; + const auto result = tokenizer.tryGetInteger(key); + if (result) { + return *result; } else { // TODO: what if key does not exist? Should we return `unknown`? // assert(false); // ?? @@ -278,7 +270,7 @@ Tiktoken::_split_with_allowed_special_token( break; } - if (allowed_special.count(special) == 1) { + if (allowed_special.tryGetInteger(special).has_value()) { // Found an allowed special token, split the text with it. 
#if __cplusplus >= 202002L return std::make_pair( @@ -302,13 +294,13 @@ void Tiktoken::_encode( std::string piece; assert(_regex); while (re2::RE2::FindAndConsume(&input, *_regex, &piece)) { - auto iter = _encoder.find(piece); - if (iter != _encoder.end()) { + const auto result = _token_map->tryGetInteger(piece); + if (result) { last_piece_token_len = 1; - ret.push_back(iter->second); + ret.push_back(*result); continue; } - auto tokens = _byte_pair_encode(piece, _encoder); + auto tokens = _byte_pair_encode(piece, *_token_map); last_piece_token_len = tokens.size(); ret.insert(ret.end(), tokens.begin(), tokens.end()); } @@ -328,16 +320,14 @@ std::pair, uint64_t> Tiktoken::_encode_with_special_token( _encode(sub_input, tokens, last_piece_token_len); if (special) { - uint64_t token = 0; - try { - token = _special_token_encoder.at(*special); - } catch (const std::out_of_range&) { + const auto result = _special_token_map->tryGetInteger(*special); + if (!result) { // Should never go here, since special pattern includes all special // chars. 
ET_CHECK_MSG(false, "unknown special token: %s", special->c_str()); } - tokens.push_back(token); + tokens.push_back(*result); last_piece_token_len = 0; } else { break; @@ -380,11 +370,10 @@ Tiktoken::Tiktoken( } Error Tiktoken::load(const std::string& path) { - _encoder = ET_UNWRAP(_load_encoder(path)); - _special_token_encoder = _build_special_token_encoder(_encoder.size()); - - _decoder = ET_UNWRAP(_build_decoder(_encoder)); - _special_token_decoder = ET_UNWRAP(_build_decoder(_special_token_encoder)); + auto encoder = ET_UNWRAP(_load_encoder(path)); + _token_map.emplace(StringIntegerMap<>(encoder)); + auto special_token_encoder = _build_special_token_encoder(encoder.size()); + _special_token_map.emplace(StringIntegerMap<>(special_token_encoder)); _regex = _create_regex(_pattern); // Warmup re2 as it is slow on the first run, void the return value as it's @@ -392,14 +381,14 @@ Error Tiktoken::load(const std::string& path) { // https://github.com/google/re2/blob/6dcd83d60f7944926bfd308cc13979fc53dd69ca/re2/fuzzing/re2_fuzzer.cc#L136-L141 (void)_regex->ReverseProgramSize(); - _special_token_regex = _build_special_token_regex(_special_token_encoder); + _special_token_regex = _build_special_token_regex(special_token_encoder); // Same as above, warm up re2 (void)_special_token_regex->ReverseProgramSize(); // initialize vocab_size, bos_tok, eos_tok - vocab_size_ = _encoder.size() + _special_token_encoder.size(); - bos_tok_ = _special_token_encoder.at(_special_tokens->at(_bos_token_index)); - eos_tok_ = _special_token_encoder.at(_special_tokens->at(_eos_token_index)); + vocab_size_ = encoder.size() + special_token_encoder.size(); + bos_tok_ = special_token_encoder.at(_special_tokens->at(_bos_token_index)); + eos_tok_ = special_token_encoder.at(_special_tokens->at(_eos_token_index)); initialized_ = true; return Error::Ok; @@ -410,7 +399,7 @@ Tiktoken::encode(const std::string& text, int8_t bos, int8_t eos) const { if (!initialized_) { return Error::NotSupported; } - auto 
res = _encode_with_special_token(text, _special_token_encoder).first; + auto res = _encode_with_special_token(text, *_special_token_map).first; for (auto i = 0; i < bos; ++i) { res.insert(res.begin(), bos_tok_); } @@ -425,21 +414,19 @@ Result Tiktoken::decode(uint64_t prev, uint64_t cur) const { ET_CHECK_OK_OR_RETURN_ERROR(Tokenizer::decode_verify(cur)); std::string ret; - std::string token_bytes; - auto iter = _decoder.find(cur); - if (iter != _decoder.end()) { - token_bytes = iter->second; - } else { - iter = _special_token_decoder.find(cur); - if (iter != _special_token_decoder.end()) { - token_bytes = iter->second; - } else { + std::string_view token_bytes; + auto result = _token_map->tryGetString(cur); + if (!result) { + result = _special_token_map->tryGetString(cur); + if (!result) { ET_CHECK_MSG(false, "unknown token: %" PRIu64, cur); + } else { + token_bytes = *result; } + } else { + token_bytes = *result; } - ret += token_bytes; - - return ret; + return std::string(token_bytes); } // -------------------------public method end------------------------------- diff --git a/extension/llm/tokenizer/tiktoken.h b/extension/llm/tokenizer/tiktoken.h index 5201c07a184..d7d93b27597 100644 --- a/extension/llm/tokenizer/tiktoken.h +++ b/extension/llm/tokenizer/tiktoken.h @@ -8,6 +8,7 @@ #pragma once +#include #include #include #include @@ -68,10 +69,8 @@ class ET_EXPERIMENTAL Tiktoken : public Tokenizer { // Removed negative lookahead \s+(?!\S) since it's not supported by RE2. const std::string _pattern = R"((?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+)"; - Encoder _encoder; - Encoder _special_token_encoder; - Decoder _decoder; - Decoder _special_token_decoder; + std::optional> _token_map; + std::optional> _special_token_map; Re2UPtr _regex; Re2UPtr _special_token_regex;