#include <algorithm>
#include <cstdint>
#include <memory>
#include <optional>
#include <string>
#include <string_view>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>

// Third Party
#include <re2/re2.h>

// Local
#include <pytorch/tokenizers/error.h>
#include <pytorch/tokenizers/result.h>
#include <pytorch/tokenizers/string_integer_map.h>
#include <pytorch/tokenizers/tokenizer.h>
2629
2730namespace tokenizers {
2831namespace detail {
2932
30- using Encoder = std::unordered_map<std::string, uint64_t >;
31- using Decoder = std::unordered_map<uint64_t , std::string>;
3233using Re2UPtr = std::unique_ptr<re2::RE2>;
34+ using TokenMap = StringIntegerMap<>;
35+
36+ template <typename TToken, typename TRank>
37+ static Result<TokenMap> buildTokenMap (
38+ std::vector<std::pair<TToken, TRank>> container) {
39+ static_assert (
40+ std::is_same_v<TToken, std::string> ||
41+ std::is_same_v<TToken, std::string_view>,
42+ " TToken must be std::string or std::string_view" );
43+ static_assert (
44+ std::is_integral_v<TRank> && std::is_unsigned_v<TRank>,
45+ " TRank must be an unsigned integer" );
46+
47+ std::sort (
48+ container.begin (), container.end (), [](const auto & a, const auto & b) {
49+ return a.first < b.first ;
50+ });
51+
52+ auto duplicate_begin = std::unique (
53+ container.begin (), container.end (), [](const auto & a, const auto & b) {
54+ return a.first == b.first ;
55+ });
56+
57+ TK_CHECK_OR_RETURN_ERROR (
58+ duplicate_begin == container.end (),
59+ ParseFailure,
60+ " duplicate token: %s rank: %llu" ,
61+ duplicate_begin->first .c_str (),
62+ static_cast <unsigned long long >(duplicate_begin->second ));
63+
64+ std::sort (
65+ container.begin (), container.end (), [](const auto & a, const auto & b) {
66+ return a.second < b.second ;
67+ });
68+
69+ duplicate_begin = std::unique (
70+ container.begin (), container.end (), [](const auto & a, const auto & b) {
71+ return a.second == b.second ;
72+ });
73+
74+ TK_CHECK_OR_RETURN_ERROR (
75+ duplicate_begin == container.end (),
76+ ParseFailure,
77+ " duplicate rank: %llu"
78+ " token: %s" ,
79+ static_cast <unsigned long long >(duplicate_begin->second ),
80+ duplicate_begin->first .c_str ());
81+
82+ return TokenMap (container);
83+ };
84+
85+ template <typename TContainer, typename TTokenAccessor, typename TRankAccessor>
86+ static Result<TokenMap> buildTokenMap (
87+ const TContainer& container,
88+ TTokenAccessor token_accessor,
89+ TRankAccessor rank_accessor) {
90+ using TokenType = std::invoke_result_t <TTokenAccessor, const TContainer&>;
91+ using RankType = std::invoke_result_t <TRankAccessor, const TContainer&>;
92+
93+ static_assert (
94+ std::is_same_v<TokenType, std::string> ||
95+ std::is_same_v<TokenType, std::string_view>,
96+ " TokenType must be std::string or std::string_view" );
97+ static_assert (
98+ std::is_integral_v<RankType> && std::is_unsigned_v<RankType>,
99+ " RankType must be an unsigned integer" );
100+
101+ std::vector<std::pair<TokenType, RankType>> pairs;
102+ pairs.reserve (container.size ());
103+ for (const auto & value : container) {
104+ pairs.emplace_back (token_accessor (value), rank_accessor (value));
105+ }
106+
107+ return buildTokenMap (std::move (pairs));
108+ }
33109
34110class BPETokenizerBase : public Tokenizer {
35111 public:
@@ -46,22 +122,20 @@ class BPETokenizerBase : public Tokenizer {
46122 std::pair<std::optional<std::string>, re2::StringPiece>
47123 split_with_allowed_special_token_ (
48124 re2::StringPiece& input,
49- const Encoder & allowed_special) const ;
125+ const TokenMap & allowed_special) const ;
50126
51127 Result<std::pair<std::vector<uint64_t >, uint64_t >> encode_with_special_token_ (
52128 const std::string& text,
53- const Encoder & allowed_special) const ;
129+ const TokenMap & allowed_special) const ;
54130
55131 Result<std::vector<uint64_t >> byte_pair_encode_ (
56132 const std::string& piece,
57- const Encoder & encoder) const ;
133+ const TokenMap & encoder) const ;
58134
59135 // Protected members that can be overloaded by other BPE tokenizers
60136 Re2UPtr special_token_regex_;
61- Encoder encoder_;
62- Encoder special_token_encoder_;
63- Decoder decoder_;
64- Decoder special_token_decoder_;
137+ std::optional<TokenMap> token_map_;
138+ std::optional<TokenMap> special_token_map_;
65139
66140 private:
67141 virtual Error _encode (
0 commit comments