@@ -9,6 +9,9 @@
 #include <nlohmann/json.hpp>

 #if defined(_WIN32)
+#    ifndef NOMINMAX
+#        define NOMINMAX
+#    endif
 #    include <windows.h>
 #    include <io.h>
 #else
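The `NOMINMAX` guard is what lets the `std::numeric_limits<int32_t>::min()` call in the second hunk compile on Windows: without it, `windows.h` defines function-style `min`/`max` macros that swallow the member call. A minimal sketch (not from this commit) of the failure mode and the two usual fixes:

```cpp
#include <limits>

#ifndef NOMINMAX
#    define NOMINMAX   // must be defined before <windows.h> is first included
#endif
#if defined(_WIN32)
#    include <windows.h>
#endif

int main() {
    // Without NOMINMAX, windows.h's min()/max() macros would expand here
    // and the member calls would fail to compile.
    int lowest  = std::numeric_limits<int>::min();
    // Parenthesizing the name also blocks macro expansion, for code that
    // cannot control when windows.h is included:
    int highest = (std::numeric_limits<int>::max)();
    return lowest < highest ? 0 : 1;
}
```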
@@ -940,16 +943,29 @@ static int apply_chat_template(const struct common_chat_templates * tmpls, Llama
 static int tokenize_prompt(const llama_vocab * vocab, const std::string & prompt,
                            std::vector<llama_token> & prompt_tokens, const LlamaData & llama_data) {
     const bool is_first = llama_memory_seq_pos_max(llama_get_memory(llama_data.context.get()), 0) == -1;
-
-    const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, is_first, true);
-    prompt_tokens.resize(n_prompt_tokens);
-    if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), is_first,
-                       true) < 0) {
-        printe("failed to tokenize the prompt\n");
+    int n_tokens = prompt.size() + 2 * is_first;
+    prompt_tokens.resize(n_tokens);
+    n_tokens = llama_tokenize(vocab, prompt.c_str(), prompt.size(),
+                              prompt_tokens.data(), prompt_tokens.size(),
+                              is_first, /*parse_special =*/true);
+    if (n_tokens == std::numeric_limits<int32_t>::min()) {
+        printe("tokenization failed: input too large\n");
         return -1;
     }
-
-    return n_prompt_tokens;
+    if (n_tokens < 0) {
+        prompt_tokens.resize(-n_tokens);
+        int check = llama_tokenize(vocab, prompt.c_str(), prompt.size(),
+                                   prompt_tokens.data(), prompt_tokens.size(),
+                                   is_first, /*parse_special =*/true);
+        if (check != -n_tokens) {
+            printe("failed to tokenize the prompt (size mismatch)\n");
+            return -1;
+        }
+        n_tokens = check;
+    } else {
+        prompt_tokens.resize(n_tokens);
+    }
+    return n_tokens;
 }

 // Check if we have enough space in the context to evaluate this batch
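For context, the rewritten `tokenize_prompt` relies on `llama_tokenize`'s return convention: the token count on success, the negated required count when the output buffer is too small, and `INT32_MIN` on overflow. A standalone sketch of the same grow-and-retry idiom, assuming `llama.h` is available (the `tokenize_with_retry` name is ours, and `fprintf` stands in for `printe`):

```cpp
#include <climits>
#include <cstdio>
#include <string>
#include <vector>

#include "llama.h"

// Guess one token per byte (plus room for BOS/EOS when add_special is set),
// then retry once with the exact size reported by llama_tokenize.
static int tokenize_with_retry(const llama_vocab * vocab, const std::string & text,
                               bool add_special, std::vector<llama_token> & out) {
    out.resize(text.size() + 2 * add_special);
    int n = llama_tokenize(vocab, text.c_str(), text.size(),
                           out.data(), out.size(), add_special, /*parse_special =*/true);
    if (n == INT32_MIN) {
        fprintf(stderr, "tokenization failed: input too large\n");
        return -1;
    }
    if (n < 0) {
        out.resize(-n); // -n is the exact number of tokens required
        n = llama_tokenize(vocab, text.c_str(), text.size(),
                           out.data(), out.size(), add_special, /*parse_special =*/true);
        if (n < 0) {
            fprintf(stderr, "failed to tokenize the text\n");
            return -1;
        }
    }
    out.resize(n); // shrink to the actual count
    return n;
}
```

Sizing the first buffer to the prompt's byte length works because a tokenizer emits at most one token per byte, so the first call should normally succeed and the retry path is defensive; it also saves the extra sizing pass (`llama_tokenize` with a `NULL` buffer) that the old code made on every prompt.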