diff --git a/extension/llm/tokenizer/string_integer_map.h b/extension/llm/tokenizer/string_integer_map.h new file mode 100644 index 00000000000..e8ff2d023e0 --- /dev/null +++ b/extension/llm/tokenizer/string_integer_map.h @@ -0,0 +1,569 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace executorch::extension::llm { + +/** + * StringIntegerMap is an immutable bidirectional map between strings and 64 bit + * unsigned integers. The element data is stored in a contiguous array and is + * shared between both the string buckets and the integer buckets, offering a + * compact representation. + * + * Variable sized integers are used internally, which are sized based on the + * data being stored. Custom hash functions are supported, with a stateful hash + * functor being optionally provided at construction time. + */ +template < + typename TStringHash = std::hash, + typename TIntegerHash = std::hash, + typename TAllocator = std::allocator> +class StringIntegerMap { + public: + /// @name Constructors + /// @{ + + /// Default constructor is deleted, as this container is intended to be + /// constructed with a map of strings to integers. + StringIntegerMap() = delete; + + /** + * Construct a StringIntegerMap from a map of strings to integers. Each + * string and integer in the map must be unique. + * @param map map of strings to integers + */ + explicit StringIntegerMap( + const std::unordered_map& map); + + /** + * Construct a StringIntegerMap from a map of strings to integers, explicitly + * initializing the integer and string hash objects. Each string and integer + * in the map must be unique. 
+ * @param map map of strings to integers + */ + StringIntegerMap( + const std::unordered_map& map, + TStringHash string_hasher, + TIntegerHash integer_hasher); + + /// @} + /// @name Accessors + /// @{ + + /** + * Attempts to retrieve the integer mapped for the given string. + * @param str string to lookup + * @return a std::optional containing the integer if the string was found, + * std::nullopt otherwise + */ + std::optional tryGetInteger(std::string_view str) const; + + /** + * Attempts to retrieve the string mapped for the given integer. + * @param integer integer to lookup + * @return a std::optional containing the string if the integer was found, + * std::nullopt otherwise + */ + std::optional tryGetString(std::uint64_t integer) const; + + /// @} + + private: + template + class VariableSizedInteger { + public: + VariableSizedInteger() = default; + + explicit VariableSizedInteger(TLogical max_value) { + while (max_value != 0) { + ++byte_count_; + max_value >>= 8; + } + + mask_ = (TLogical(1) << (byte_count_ * 8)) - TLogical(1); + } + + std::size_t getByteCount() const { + return byte_count_; + } + + TLogical getMask() const { + return mask_; + } + + std::uint8_t* write(std::uint8_t* target, TLogical value) const { + std::memcpy(target, &value, byte_count_); + return target + byte_count_; + } + + TLogical read(const std::uint8_t* source) const { + TLogical value; + std::memcpy(&value, source, sizeof(TLogical)); + return value & mask_; + } + + private: + std::size_t byte_count_ = 0; + TLogical mask_ = 0; + }; + + bool tryGetInteger(std::string_view str, std::uint64_t& result) const; + + bool tryGetString(std::uint64_t integer, std::string_view& result) const; + + std::size_t getBucketIndex(std::string_view value) const; + + std::size_t getBucketIndex(std::uint64_t value) const; + + static std::uint8_t getSmallHash(std::size_t hash); + + /// Get the string data and string small hash stored in the element buffer. + /// The hasher used for strings. 
+ const TStringHash string_hasher_ = {}; + + /// The hasher used for integers. + const TIntegerHash integer_hasher_ = {}; + + /// Integer bucket references. + std::vector integer_bucket_data_; + + /// Integer bucket elements. + /// Laid out as: + /// struct { + /// std::uint64_t integer; - Physically using integer_ bytes. + /// std::size_t string_size; - Physically using string_size_ bytes + /// std::size_t string_offset; - Physically using string_offset_ bytes + /// } + std::vector integer_element_data_; + + /// String bucket references. + std::vector string_bucket_data_; + + /// String bucket elements. + /// Laid out as: + /// struct { + /// std::uint64_t integer; - Physically using integer_ bytes. + /// std::size_t string_size; - Physically using string_size_ bytes + /// std::uint8_t small_hash; - Using std::uint8_t bytes. + /// char string[string_size]; - String data, not zero terminated. + /// } + std::vector string_element_data_; + + /// Number of hash buckets to use. + std::size_t bucket_count_ = 0; + + /// Variable sized element offset info. + VariableSizedInteger element_offset_; + + /// Variable size string offset info. + VariableSizedInteger string_offset_; + + /// Variable sized string size info. + VariableSizedInteger string_size_; + + /// Variable sized integer info. 
+ VariableSizedInteger integer_; +}; + +template +StringIntegerMap::StringIntegerMap( + const std::unordered_map& map) + : StringIntegerMap(map, TStringHash(), TIntegerHash()) {} + +template +StringIntegerMap::StringIntegerMap( + const std::unordered_map& map, + TStringHash string_hasher, + TIntegerHash integer_hasher) + : string_hasher_(string_hasher), integer_hasher_(integer_hasher) { + assert(map.size() <= std::numeric_limits::max()); + bucket_count_ = map.size(); + + struct BuilderElement { + std::uint64_t integer = 0; + std::string_view string; + std::size_t hash = 0; + std::size_t element_offset = 0; + }; + + std::vector builder_string_elements; + std::vector builder_integer_elements; + + // + // Calculate various item sizes and gather the builder elements. + // + + std::size_t largest_string_size = 0; + std::uint64_t largest_integer = 0; + std::size_t total_string_size = 0; + + for (const auto& [str, integer] : map) { + total_string_size += str.size(); + largest_string_size = std::max(largest_string_size, str.size()); + largest_integer = std::max(largest_integer, integer); + builder_string_elements.push_back({integer, str, string_hasher_(str)}); + builder_integer_elements.push_back( + {integer, str, integer_hasher_(integer)}); + } + + integer_ = VariableSizedInteger(largest_integer); + string_size_ = VariableSizedInteger(largest_string_size); + string_offset_ = VariableSizedInteger(total_string_size); + + const auto string_element_data_size = + ((integer_.getByteCount() + string_size_.getByteCount() + 1) * + map.size()) + + total_string_size; + const auto integer_element_size = integer_.getByteCount() + + string_offset_.getByteCount() + string_size_.getByteCount(); + const auto integer_element_data_size = integer_element_size * map.size(); + + element_offset_ = VariableSizedInteger( + std::max(string_element_data_size, integer_element_data_size)); + + string_bucket_data_.resize( + ((bucket_count_ + 1) * element_offset_.getByteCount()) + + 
sizeof(std::uint64_t)); + integer_bucket_data_.resize( + ((bucket_count_ + 1) * element_offset_.getByteCount()) + + sizeof(std::uint64_t)); + + // + // Set up terminal bucket indices. + // + + element_offset_.write( + string_bucket_data_.data() + + (bucket_count_ * element_offset_.getByteCount()), + string_element_data_size); + element_offset_.write( + integer_bucket_data_.data() + + (bucket_count_ * element_offset_.getByteCount()), + integer_element_data_size); + // + // Sort the builder elements. + // + + std::sort( + std::begin(builder_string_elements), + std::end(builder_string_elements), + [this](const BuilderElement& first, const BuilderElement& second) { + const auto first_bucket = first.hash % bucket_count_; + const auto second_bucket = second.hash % bucket_count_; + if (first_bucket == second_bucket) { + const auto first_small_hash = getSmallHash(first.hash); + const auto second_small_hash = getSmallHash(second.hash); + return first_small_hash < second_small_hash; + } + + return first_bucket < second_bucket; + }); + + std::sort( + std::begin(builder_integer_elements), + std::end(builder_integer_elements), + [this](const BuilderElement& first, const BuilderElement& second) { + const auto first_bucket = first.hash % bucket_count_; + const auto second_bucket = second.hash % bucket_count_; + if (first_bucket == second_bucket) { + return first.integer < second.integer; + } + + return first_bucket < second_bucket; + }); + + // + // Lay out the string elements and record their positions. 
+ // + + std::unordered_map + string_element_byte_index_map; + string_element_data_.resize(string_element_data_size + sizeof(std::uint64_t)); + auto* string_element = string_element_data_.data(); + for (auto& builder_element : builder_string_elements) { + builder_element.element_offset = + string_element - string_element_data_.data(); + + auto insert_result = string_element_byte_index_map.insert( + {builder_element.string, builder_element.element_offset}); + assert(insert_result.second); + (void)insert_result; + + string_element = integer_.write(string_element, builder_element.integer); + string_element = + string_size_.write(string_element, builder_element.string.size()); + *string_element = getSmallHash(builder_element.hash); + string_element++; + std::memcpy( + string_element, + builder_element.string.data(), + builder_element.string.size()); + string_element += builder_element.string.size(); + assert( + string_element >= string_element_data_.data() && + string_element <= + string_element_data_.data() + string_element_data_size); + } + + // + // Lay out the integer elements. 
+ // + + integer_element_data_.resize( + integer_element_data_size + sizeof(std::uint64_t)); + auto* integer_element = integer_element_data_.data(); + for (auto& builder_element : builder_integer_elements) { + builder_element.element_offset = + integer_element - integer_element_data_.data(); + auto string_element_byte_index_iter = + string_element_byte_index_map.find(builder_element.string); + assert( + string_element_byte_index_iter != + std::end(string_element_byte_index_map)); + integer_element = integer_.write(integer_element, builder_element.integer); + integer_element = + string_size_.write(integer_element, builder_element.string.size()); + integer_element = string_offset_.write( + integer_element, string_element_byte_index_iter->second); + assert( + integer_element >= integer_element_data_.data() && + integer_element <= + integer_element_data_.data() + integer_element_data_size); + } + + // + // Both the string elements and integer elements are laid out in order of + // their respective hashes. Generate the hash indexes for the string elements + // and integer elements. 
+ // + + auto builder_string_elements_iter = std::begin(builder_string_elements); + auto builder_integer_elements_iter = std::begin(builder_integer_elements); + + for (std::size_t bucket_idx = 0; bucket_idx < bucket_count_; ++bucket_idx) { + auto* string_bucket = string_bucket_data_.data() + + (bucket_idx * element_offset_.getByteCount()); + if (builder_string_elements_iter != std::end(builder_string_elements)) { + element_offset_.write( + string_bucket, builder_string_elements_iter->element_offset); + } else { + element_offset_.write(string_bucket, string_element_data_size); + } + + auto* integer_bucket = integer_bucket_data_.data() + + (bucket_idx * element_offset_.getByteCount()); + if (builder_integer_elements_iter != std::end(builder_integer_elements)) { + element_offset_.write( + integer_bucket, builder_integer_elements_iter->element_offset); + } else { + element_offset_.write(integer_bucket, integer_element_data_size); + } + + // + // Advance the string element iterator past all string elements that map + // into this bucket. + // + + while (builder_string_elements_iter != std::end(builder_string_elements) && + getBucketIndex(builder_string_elements_iter->string) == bucket_idx) { + ++builder_string_elements_iter; + } + + // + // Advance the integer element index past all integer elements that map into + // this bucket. + // + + while ( + builder_integer_elements_iter != std::end(builder_integer_elements) && + getBucketIndex(builder_integer_elements_iter->integer) == bucket_idx) { + ++builder_integer_elements_iter; + } + } +} + +template +std::optional +StringIntegerMap::tryGetInteger( + std::string_view str) const { + std::uint64_t result; + return tryGetInteger(str, result) ? 
std::optional(result) + : std::nullopt; +} + +template +bool StringIntegerMap::tryGetInteger( + std::string_view str, + std::uint64_t& result) const { + if (bucket_count_ == 0) { + return false; + } + + const auto hash = string_hasher_(str); + const auto bucket_index = hash % bucket_count_; + const auto small_hash = getSmallHash(hash); + + const auto* bucket_data = string_bucket_data_.data() + + (bucket_index * element_offset_.getByteCount()); + const auto lower_element_offset = element_offset_.read(bucket_data); + const auto upper_element_offset = + element_offset_.read(bucket_data + element_offset_.getByteCount()); + + const auto integer_size = integer_.getByteCount(); + const auto string_size_size = string_size_.getByteCount(); + + std::size_t element_size = 0; + auto* element_data_end = string_element_data_.data() + upper_element_offset; + for (auto* element_data = string_element_data_.data() + lower_element_offset; + element_data < element_data_end; + element_data += element_size) { + // + // Read the string length. + // + + const auto element_string_length = + string_size_.read(element_data + integer_size); + element_size = integer_size + string_size_size + 1 + element_string_length; + + // + // Read the string small hash. + // + + const auto element_small_hash = + element_data[integer_size + string_size_size]; + if (element_small_hash < small_hash) { + continue; + } else if (element_small_hash > small_hash) { + break; + } + + // + // Get a view on the string for a full comparison. + // + + std::string_view element_string( + reinterpret_cast( + element_data + integer_size + string_size_size + 1), + element_string_length); + if (str == element_string) { + result = integer_.read(element_data); + return true; + } + } + + return false; +} + +template +std::optional +StringIntegerMap::tryGetString( + std::uint64_t integer) const { + std::string_view result; + return tryGetString(integer, result) ? 
std::optional(result) + : std::nullopt; +} + +template +bool StringIntegerMap::tryGetString( + std::uint64_t integer, + std::string_view& result) const { + if (bucket_count_ == 0) { + return false; + } + + const auto bucket_index = getBucketIndex(integer); + + const auto* bucket_data = integer_bucket_data_.data() + + (bucket_index * element_offset_.getByteCount()); + const auto lower_element_offset = element_offset_.read(bucket_data); + const auto upper_element_offset = + element_offset_.read(bucket_data + element_offset_.getByteCount()); + + const auto integer_element_size = integer_.getByteCount() + + string_offset_.getByteCount() + string_size_.getByteCount(); + auto* element_data_end = integer_element_data_.data() + upper_element_offset; + for (auto* element_data = integer_element_data_.data() + lower_element_offset; + element_data < element_data_end; + element_data += integer_element_size) { + const auto element_integer = integer_.read(element_data); + if (element_integer == integer) { + const auto element_string_size = + string_size_.read(element_data + integer_.getByteCount()); + const auto element_string_offset = string_offset_.read( + element_data + integer_.getByteCount() + string_size_.getByteCount()); + const auto* string_element = + string_element_data_.data() + element_string_offset; + const auto* string_data = reinterpret_cast( + string_element + integer_.getByteCount() + + string_size_.getByteCount() + 1); + result = std::string_view(string_data, element_string_size); + return true; + } else if (element_integer > integer) { + break; + } + } + + return false; +} + +template +std::size_t +StringIntegerMap::getBucketIndex( + std::string_view value) const { + return string_hasher_(value) % bucket_count_; +} + +template +std::size_t +StringIntegerMap::getBucketIndex( + std::uint64_t value) const { + return integer_hasher_(value) % bucket_count_; +} + +template +std::uint8_t +StringIntegerMap::getSmallHash( + std::size_t hash) { + const auto shift = 
(sizeof(std::size_t) * 8) - 8; + return static_cast(hash >> shift); +} + +template < + typename TStringHash = std::hash, + typename TIntegerHash = std::hash, + typename TAllocator = std::allocator> +struct StringIntegerMapTypeBuilder { + using Map = StringIntegerMap; + + template + using WithStringHash = + StringIntegerMapTypeBuilder; + + template + using WithIntegerHash = + StringIntegerMapTypeBuilder; + + template + using WithAllocator = + StringIntegerMapTypeBuilder; +}; +} // namespace executorch::extension::llm diff --git a/extension/llm/tokenizer/targets.bzl b/extension/llm/tokenizer/targets.bzl index 1a590c7876f..7b545054390 100644 --- a/extension/llm/tokenizer/targets.bzl +++ b/extension/llm/tokenizer/targets.bzl @@ -83,6 +83,7 @@ def define_common_targets(): exported_headers = [ "tiktoken.h", "base64.h", + "string_integer_map.h", ], exported_deps = [ ":tokenizer_header", diff --git a/extension/llm/tokenizer/test/CMakeLists.txt b/extension/llm/tokenizer/test/CMakeLists.txt index ffc37f9e46f..9ddcf48518e 100644 --- a/extension/llm/tokenizer/test/CMakeLists.txt +++ b/extension/llm/tokenizer/test/CMakeLists.txt @@ -19,7 +19,7 @@ include(${EXECUTORCH_ROOT}/build/Test.cmake) set(test_env "RESOURCES_PATH=${EXECUTORCH_ROOT}/extension/llm/tokenizer/test/resources") -set(_test_srcs test_bpe_tokenizer.cpp test_tiktoken.cpp) +set(_test_srcs test_bpe_tokenizer.cpp test_tiktoken.cpp test_string_integer_map.cpp) et_cxx_test( extension_llm_tokenizer_test SOURCES ${_test_srcs} EXTRA_LIBS diff --git a/extension/llm/tokenizer/test/targets.bzl b/extension/llm/tokenizer/test/targets.bzl index 2c314a98230..8755ae6273f 100644 --- a/extension/llm/tokenizer/test/targets.bzl +++ b/extension/llm/tokenizer/test/targets.bzl @@ -22,6 +22,20 @@ def define_common_targets(): ], ) + runtime.cxx_test( + name = "test_string_integer_map", + srcs = [ + "test_string_integer_map.cpp", + ], + deps = [ + "//executorch/extension/llm/tokenizer:tiktoken", + ], + env = { + "RESOURCES_PATH": 
"$(location :resources)/resources", + }, + platforms = [CXX, ANDROID], # Cannot bundle resources on Apple platform. + ) + runtime.cxx_test( name = "test_bpe_tokenizer", srcs = [ @@ -59,3 +73,8 @@ def define_common_targets(): "resources/**", ]), ) + + runtime.export_file( + name = "test_tiktoken_tokenizer_model", + src = "resources/test_tiktoken_tokenizer.model", + ) diff --git a/extension/llm/tokenizer/test/test_string_integer_map.cpp b/extension/llm/tokenizer/test/test_string_integer_map.cpp new file mode 100644 index 00000000000..24a9853429d --- /dev/null +++ b/extension/llm/tokenizer/test/test_string_integer_map.cpp @@ -0,0 +1,318 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#ifdef EXECUTORCH_FB_BUCK +#include +#endif +#include +#include +#include +#include +#include +#include +#include + +#if defined(__APPLE__) || defined(WIN32) || defined(__linux__) +#define TEST_MEMORY_COMPARISON 1 + +#if defined(__APPLE__) +#include +#else +#include +#endif +#endif + +using namespace ::testing; +using ::executorch::extension::llm::StringIntegerMap; +using ::executorch::extension::llm::StringIntegerMapTypeBuilder; +using ::executorch::extension::llm::base64::decode; +using ::executorch::runtime::Error; +using ::executorch::runtime::Result; +using TokenizerMap = std::unordered_map; + +class StringIntegerMapTest : public Test { + public: + void SetUp() override { + ::executorch::runtime::runtime_init(); +#ifdef EXECUTORCH_FB_BUCK + modelPath_ = facebook::xplat::testing::getPathForTestResource( + "resources/test_tiktoken_tokenizer.model"); +#else + modelPath_ = std::getenv("RESOURCES_PATH") + + std::string("/test_tiktoken_tokenizer.model"); +#endif + } + + Result loadModel() { + std::ifstream file(modelPath_); + ET_CHECK_OR_RETURN_ERROR( + file, + InvalidArgument, + "failed to open encoder 
file: %s", + modelPath_.c_str()); + + TokenizerMap model; + for (std::string line; std::getline(file, line);) { + if (line.empty()) { + continue; + } + + auto pos = line.find(' '); + auto token = ET_UNWRAP(decode({line.data(), pos})); + uint64_t rank = 0; + try { + rank = std::stoul(line.substr(pos + 1)); + } catch (const std::exception&) { + ET_CHECK_OR_RETURN_ERROR( + false, InvalidArgument, "invalid encoder rank: %s", line.c_str()); + } + model[token] = rank; + } + + return model; + } + + std::string modelPath_; +}; + +#if defined(TEST_MEMORY_COMPARISON) && TEST_MEMORY_COMPARISON + +class TrackingAllocatorBase { + public: + static void reset(); + static std::size_t getSize(); + + protected: + static void* allocate(std::size_t size); + static void deallocate(void* ptr); + + static std::size_t size_; +}; + +void TrackingAllocatorBase::reset() { + size_ = 0; +} + +std::size_t TrackingAllocatorBase::getSize() { + return size_; +} + +void* TrackingAllocatorBase::allocate(std::size_t size) { + void* ptr = malloc(size); + if (!ptr) { + return nullptr; + } + +#if defined(WIN32) + size_ += _msize(ptr); +#elif defined(__APPLE__) + size_ += malloc_size(const_cast(ptr)); +#else + size_ += malloc_usable_size(ptr); +#endif + + return ptr; +} + +void TrackingAllocatorBase::deallocate(void* ptr) { + if (!ptr) { + return; + } + +#if defined(WIN32) + size_ -= _msize(ptr); +#elif defined(__APPLE__) + size_ -= malloc_size(const_cast(ptr)); +#else + size_ -= malloc_usable_size(ptr); +#endif + + free(ptr); +} + +std::size_t TrackingAllocatorBase::size_ = 0; + +template +class TrackingAllocator : public TrackingAllocatorBase { + public: + using value_type = T; + TrackingAllocator() noexcept = default; + template + explicit TrackingAllocator(TrackingAllocator const&) noexcept {} + + value_type* allocate(std::size_t count) { + return static_cast( + TrackingAllocatorBase::allocate(count * sizeof(value_type))); // NOLINT + } + + void deallocate(value_type* ptr, std::size_t /*count*/) 
noexcept { + TrackingAllocatorBase::deallocate(ptr); + } +}; + +template +bool operator==( + TrackingAllocator const&, + TrackingAllocator const&) noexcept { + return true; +} + +template +bool operator!=( + TrackingAllocator const& lhs, + TrackingAllocator const& rhs) noexcept { + return !(lhs == rhs); +} + +#endif + +TEST_F(StringIntegerMapTest, CreateFromModel) { + const auto res = loadModel(); + ASSERT_EQ(res.ok(), true); + const auto& model = res.get(); + StringIntegerMap map(model); + + for (const auto& [model_key, model_value] : model) { + EXPECT_THAT(map.tryGetInteger(model_key), testing::Optional(model_value)) + << model_key; + EXPECT_THAT(map.tryGetString(model_value), testing::Optional(model_key)) + << model_value; + } + + EXPECT_FALSE(map.tryGetInteger("Ich weiß nicht")); + EXPECT_FALSE(map.tryGetString(999999999)); +} + +#if defined(TEST_MEMORY_COMPARISON) && TEST_MEMORY_COMPARISON + +TEST_F(StringIntegerMapTest, MemoryConsumptionComparison) { + TrackingAllocatorBase::reset(); + EXPECT_EQ(TrackingAllocatorBase::getSize(), 0); + + const auto res = loadModel(); + ASSERT_EQ(res.ok(), true); + const auto& model = res.get(); + + std::size_t string_integer_map_size = 0; + { + typename StringIntegerMapTypeBuilder<>::WithAllocator< + TrackingAllocator>::Map map(model); + string_integer_map_size = TrackingAllocatorBase::getSize(); + } + + EXPECT_EQ(TrackingAllocatorBase::getSize(), 0); + + std::size_t unordered_map_size = 0; + { + std::unordered_map< + std::string, + std::uint64_t, + std::hash, + std::equal_to, + TrackingAllocator>> + strings_to_ints; + std::unordered_map< + std::uint64_t, + std::string, + std::hash, + std::equal_to, + TrackingAllocator>> + ints_to_strings; + for (const auto& [k, v] : model) { + strings_to_ints.emplace(k, v); + ints_to_strings.emplace(v, k); + } + + unordered_map_size = TrackingAllocatorBase::getSize(); + } + + EXPECT_LT(string_integer_map_size, unordered_map_size); + +#if 1 + std::cout << "string integer map size = " << 
string_integer_map_size + << std::endl; + std::cout << "unordered map size = " << unordered_map_size << std::endl; +#endif +} + +#endif + +template +struct FixedHash { + std::size_t operator()(const std::string_view& str) const { + if (str.empty()) { + return hash_offset; + } + + return str.size() - 1 + hash_offset; + } + + std::size_t operator()(std::uint64_t value) const { + if (value == 0) { + return hash_offset; + } + + return static_cast(std::log10(value)) + hash_offset; + } +}; + +template +class StringIntegerMapHashTest : public Test { + public: + using Container = typename StringIntegerMapTypeBuilder<>::WithIntegerHash< + THash>::template WithStringHash::Map; +}; + +using StringIntegerMapHashTestTypes = + ::testing::Types, FixedHash<1>, FixedHash<2>, FixedHash<3>>; +TYPED_TEST_SUITE(StringIntegerMapHashTest, StringIntegerMapHashTestTypes); + +TYPED_TEST(StringIntegerMapHashTest, HashCollisions) { + std::unordered_map source = { + {"a", 0}, + {"b", 1}, + {"c", 2}, + {"d", 3}, + }; + + typename TestFixture::Container map(source); + + // + // Check that the strings exist in the map. + // + + EXPECT_THAT(map.tryGetInteger("a"), Optional(0ull)); + EXPECT_THAT(map.tryGetInteger("b"), Optional(1ull)); + EXPECT_THAT(map.tryGetInteger("c"), Optional(2ull)); + EXPECT_THAT(map.tryGetInteger("d"), Optional(3ull)); + + EXPECT_FALSE(map.tryGetInteger("e")); + + // + // Check that the integers exist in the map. + // + + EXPECT_THAT(map.tryGetString(0), Optional(std::string_view("a"))); + EXPECT_THAT(map.tryGetString(1), Optional(std::string_view("b"))); + EXPECT_THAT(map.tryGetString(2), Optional(std::string_view("c"))); + EXPECT_THAT(map.tryGetString(3), Optional(std::string_view("d"))); + + EXPECT_FALSE(map.tryGetString(4)); + + // + // Test a lookup into the next bucket (which should be empty). 
+ // + + EXPECT_FALSE(map.tryGetInteger("aa")); + EXPECT_FALSE(map.tryGetInteger("aaa")); + EXPECT_FALSE(map.tryGetInteger("aaaa")); + + EXPECT_FALSE(map.tryGetString(10)); + EXPECT_FALSE(map.tryGetString(100)); + EXPECT_FALSE(map.tryGetString(1000)); +} diff --git a/extension/llm/tokenizer/tiktoken.cpp b/extension/llm/tokenizer/tiktoken.cpp index f99ac2e955e..725a3fe453d 100644 --- a/extension/llm/tokenizer/tiktoken.cpp +++ b/extension/llm/tokenizer/tiktoken.cpp @@ -29,7 +29,9 @@ #include #include #include +#include #include +#include using ::executorch::runtime::Error; using ::executorch::runtime::Result; @@ -97,6 +99,7 @@ static Result _load_encoder(const std::string& path) { Encoder encoder; std::string line; + std::unordered_set ranks; while (std::getline(file, line)) { auto [token, rank] = ET_UNWRAP(_parse(line)); @@ -105,28 +108,20 @@ static Result _load_encoder(const std::string& path) { InvalidArgument, "duplicate item: %s", line.c_str()); - } - return encoder; -} - -static Result _build_decoder(const Encoder& encoder) { - Decoder decoder; - for (const auto& [k, v] : encoder) { - decoder.emplace(v, k); + ET_CHECK_OR_RETURN_ERROR( + ranks.insert(rank).second, + InvalidArgument, + "duplicate rank: %s", + line.c_str()); } - ET_CHECK_OR_RETURN_ERROR( - encoder.size() == decoder.size(), - InvalidArgument, - "duplicate items in encoder"); - - return decoder; + return encoder; } static std::vector _byte_pair_merge( const std::string& piece, - const std::unordered_map& ranks, + const StringIntegerMap<>& ranks, std::function func) { // This is a vector of (start, rank). // The rank is of the byte pair starting at position start. 
@@ -145,10 +140,7 @@ static std::vector _byte_pair_merge( auto s = parts[start_idx].first; auto e = parts[start_idx + skip + 2].first; auto key = piece.substr(s, e - s); - auto iter = ranks.find(key); - if (iter != ranks.end()) { - return iter->second; - } + return ranks.tryGetInteger(key); } return std::nullopt; }; @@ -230,11 +222,11 @@ static std::vector _byte_pair_merge( static std::vector _byte_pair_encode( const std::string& piece, - const Encoder& encoder) { + const StringIntegerMap<>& tokenizer) { if (piece.size() == 1) { - auto iter = encoder.find(piece); - if (iter != encoder.end()) { - return std::vector({iter->second}); + const auto result = tokenizer.tryGetInteger(piece); + if (result) { + return std::vector({*result}); } else { // TODO: is it possible? return {}; @@ -242,11 +234,11 @@ static std::vector _byte_pair_encode( } return _byte_pair_merge( - piece, encoder, [&piece, &encoder](uint64_t start, uint64_t stop) { + piece, tokenizer, [&piece, &tokenizer](uint64_t start, uint64_t stop) { std::string key = piece.substr(start, stop - start); - auto iter = encoder.find(key); - if (iter != encoder.end()) { - return iter->second; + const auto result = tokenizer.tryGetInteger(key); + if (result) { + return *result; } else { // TODO: what if key does not exist? Should we return `unknown`? // assert(false); // ?? @@ -278,7 +270,7 @@ Tiktoken::_split_with_allowed_special_token( break; } - if (allowed_special.count(special) == 1) { + if (allowed_special.tryGetInteger(special).has_value()) { // Found an allowed special token, split the text with it. 
#if __cplusplus >= 202002L return std::make_pair( @@ -302,13 +294,13 @@ void Tiktoken::_encode( std::string piece; assert(_regex); while (re2::RE2::FindAndConsume(&input, *_regex, &piece)) { - auto iter = _encoder.find(piece); - if (iter != _encoder.end()) { + const auto result = _token_map->tryGetInteger(piece); + if (result) { last_piece_token_len = 1; - ret.push_back(iter->second); + ret.push_back(*result); continue; } - auto tokens = _byte_pair_encode(piece, _encoder); + auto tokens = _byte_pair_encode(piece, *_token_map); last_piece_token_len = tokens.size(); ret.insert(ret.end(), tokens.begin(), tokens.end()); } @@ -328,16 +320,14 @@ std::pair, uint64_t> Tiktoken::_encode_with_special_token( _encode(sub_input, tokens, last_piece_token_len); if (special) { - uint64_t token = 0; - try { - token = _special_token_encoder.at(*special); - } catch (const std::out_of_range&) { + const auto result = _special_token_map->tryGetInteger(*special); + if (!result) { // Should never go here, since special pattern includes all special // chars. 
ET_CHECK_MSG(false, "unknown special token: %s", special->c_str()); } - tokens.push_back(token); + tokens.push_back(*result); last_piece_token_len = 0; } else { break; @@ -380,11 +370,10 @@ Tiktoken::Tiktoken( } Error Tiktoken::load(const std::string& path) { - _encoder = ET_UNWRAP(_load_encoder(path)); - _special_token_encoder = _build_special_token_encoder(_encoder.size()); - - _decoder = ET_UNWRAP(_build_decoder(_encoder)); - _special_token_decoder = ET_UNWRAP(_build_decoder(_special_token_encoder)); + auto encoder = ET_UNWRAP(_load_encoder(path)); + _token_map.emplace(StringIntegerMap<>(encoder)); + auto special_token_encoder = _build_special_token_encoder(encoder.size()); + _special_token_map.emplace(StringIntegerMap<>(special_token_encoder)); _regex = _create_regex(_pattern); // Warmup re2 as it is slow on the first run, void the return value as it's @@ -392,14 +381,14 @@ Error Tiktoken::load(const std::string& path) { // https://github.com/google/re2/blob/6dcd83d60f7944926bfd308cc13979fc53dd69ca/re2/fuzzing/re2_fuzzer.cc#L136-L141 (void)_regex->ReverseProgramSize(); - _special_token_regex = _build_special_token_regex(_special_token_encoder); + _special_token_regex = _build_special_token_regex(special_token_encoder); // Same as above, warm up re2 (void)_special_token_regex->ReverseProgramSize(); // initialize vocab_size, bos_tok, eos_tok - vocab_size_ = _encoder.size() + _special_token_encoder.size(); - bos_tok_ = _special_token_encoder.at(_special_tokens->at(_bos_token_index)); - eos_tok_ = _special_token_encoder.at(_special_tokens->at(_eos_token_index)); + vocab_size_ = encoder.size() + special_token_encoder.size(); + bos_tok_ = special_token_encoder.at(_special_tokens->at(_bos_token_index)); + eos_tok_ = special_token_encoder.at(_special_tokens->at(_eos_token_index)); initialized_ = true; return Error::Ok; @@ -410,7 +399,7 @@ Tiktoken::encode(const std::string& text, int8_t bos, int8_t eos) const { if (!initialized_) { return Error::NotSupported; } - auto 
res = _encode_with_special_token(text, _special_token_encoder).first; + auto res = _encode_with_special_token(text, *_special_token_map).first; for (auto i = 0; i < bos; ++i) { res.insert(res.begin(), bos_tok_); } @@ -425,21 +414,19 @@ Result Tiktoken::decode(uint64_t prev, uint64_t cur) const { ET_CHECK_OK_OR_RETURN_ERROR(Tokenizer::decode_verify(cur)); std::string ret; - std::string token_bytes; - auto iter = _decoder.find(cur); - if (iter != _decoder.end()) { - token_bytes = iter->second; - } else { - iter = _special_token_decoder.find(cur); - if (iter != _special_token_decoder.end()) { - token_bytes = iter->second; - } else { + std::string_view token_bytes; + auto result = _token_map->tryGetString(cur); + if (!result) { + result = _special_token_map->tryGetString(cur); + if (!result) { ET_CHECK_MSG(false, "unknown token: %" PRIu64, cur); + } else { + token_bytes = *result; } + } else { + token_bytes = *result; } - ret += token_bytes; - - return ret; + return std::string(token_bytes); } // -------------------------public method end------------------------------- diff --git a/extension/llm/tokenizer/tiktoken.h b/extension/llm/tokenizer/tiktoken.h index 5201c07a184..d7d93b27597 100644 --- a/extension/llm/tokenizer/tiktoken.h +++ b/extension/llm/tokenizer/tiktoken.h @@ -8,6 +8,7 @@ #pragma once +#include #include #include #include @@ -68,10 +69,8 @@ class ET_EXPERIMENTAL Tiktoken : public Tokenizer { // Removed negative lookahead \s+(?!\S) since it's not supported by RE2. const std::string _pattern = R"((?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+)"; - Encoder _encoder; - Encoder _special_token_encoder; - Decoder _decoder; - Decoder _special_token_decoder; + std::optional> _token_map; + std::optional> _special_token_map; Re2UPtr _regex; Re2UPtr _special_token_regex;