 #include <fstream>
 #include <thread>

-void llama_ngram_cache_update(llama_ngram_cache & ngram_cache, int ngram_min, int ngram_max,
+void common_ngram_cache_update(common_ngram_cache & ngram_cache, int ngram_min, int ngram_max,
         std::vector<llama_token> & inp, int nnew, bool print_progress) {
     const int64_t t_start_ms = ggml_time_ms();
     const int64_t inp_size = inp.size();
@@ -20,16 +20,16 @@ void llama_ngram_cache_update(llama_ngram_cache & ngram_cache, int ngram_min, in
         const int64_t i_start = std::max(inp_size - nnew, ngram_size);
         for (int64_t i = i_start; i < inp_size; ++i) {
             const int64_t ngram_start = i - ngram_size;
-            llama_ngram ngram(&inp[ngram_start], ngram_size);
+            common_ngram ngram(&inp[ngram_start], ngram_size);
             const llama_token token = inp[i];

-            llama_ngram_cache::iterator part_it = ngram_cache.find(ngram);
+            common_ngram_cache::iterator part_it = ngram_cache.find(ngram);
             if (part_it == ngram_cache.end()) {
-                llama_ngram_cache_part part;
+                common_ngram_cache_part part;
                 part.emplace(token, 1);
                 ngram_cache.emplace(ngram, part);
             } else {
-                llama_ngram_cache_part::iterator token_count_it = part_it->second.find(token);
+                common_ngram_cache_part::iterator token_count_it = part_it->second.find(token);
                 if (token_count_it == part_it->second.end()) {
                     part_it->second.emplace(token, 1);
                 } else {
@@ -62,12 +62,12 @@ constexpr int draft_min_sample_size_strict[LLAMA_NGRAM_MAX] = { 4, 3, 2, 2};
 constexpr int draft_min_percent_strict[LLAMA_NGRAM_MAX] = {75, 66, 66, 66};

 // Helper function that tries to draft a token from only the static ngram cache:
-static llama_token try_draft(llama_ngram_cache & nc_static, const llama_ngram ngram_static) {
-    llama_ngram_cache::iterator part_static_it = nc_static.find(ngram_static);
+static llama_token try_draft(common_ngram_cache & nc_static, const common_ngram ngram_static) {
+    common_ngram_cache::iterator part_static_it = nc_static.find(ngram_static);
     if (part_static_it == nc_static.end()) {
         return -1;
     }
-    const llama_ngram_cache_part part_static = part_static_it->second;
+    const common_ngram_cache_part part_static = part_static_it->second;

     int max_count_static = 0;
     int sum_count_static = 0;
@@ -95,19 +95,19 @@ static llama_token try_draft(llama_ngram_cache & nc_static, const llama_ngram ng

 // Try to draft a token from primary cache (context/dynamic), validate with static cache:
 static llama_token try_draft(
-    llama_ngram_cache & nc_primary, const std::vector<llama_ngram> & ngrams_primary, llama_ngram_cache_part & part_static,
+    common_ngram_cache & nc_primary, const std::vector<common_ngram> & ngrams_primary, common_ngram_cache_part & part_static,
     const int * min_sample_size, const int * min_percent) {

     llama_token drafted_token = -1;

     for (int i = ngrams_primary.size()-1; i >= 0 && drafted_token == -1; --i) {
-        const llama_ngram ngram_primary = ngrams_primary[i];
+        const common_ngram ngram_primary = ngrams_primary[i];

-        llama_ngram_cache::iterator part_primary_it = nc_primary.find(ngram_primary);
+        common_ngram_cache::iterator part_primary_it = nc_primary.find(ngram_primary);
         if (part_primary_it == nc_primary.end()) {
             continue;
         }
-        const llama_ngram_cache_part part_primary = part_primary_it->second;
+        const common_ngram_cache_part part_primary = part_primary_it->second;

         int max_count_primary = 0;
         int max_count_static = 0;
@@ -117,7 +117,7 @@ static llama_token try_draft(
         for (std::pair<llama_token, int> token_count_primary : part_primary) {
             const llama_token token = token_count_primary.first;

-            llama_ngram_cache_part::iterator token_count_static_it = part_static.find(token);
+            common_ngram_cache_part::iterator token_count_static_it = part_static.find(token);

             const int32_t count_primary = token_count_primary.second;
             const int32_t count_static = token_count_static_it != part_static.end() ? 100*token_count_static_it->second : 1;
@@ -142,9 +142,9 @@ static llama_token try_draft(
     return drafted_token;
 }

-void llama_ngram_cache_draft(
+void common_ngram_cache_draft(
     std::vector<llama_token> & inp, std::vector<llama_token> & draft, int n_draft, int ngram_min, int ngram_max,
-    llama_ngram_cache & nc_context, llama_ngram_cache & nc_dynamic, llama_ngram_cache & nc_static
+    common_ngram_cache & nc_context, common_ngram_cache & nc_dynamic, common_ngram_cache & nc_static
 ) {
     GGML_ASSERT(draft.size() == 1);
     const int inp_size = inp.size();
@@ -157,21 +157,21 @@ void llama_ngram_cache_draft(
         llama_token drafted_token = -1;

         const int ngram_start_static = inp_size-LLAMA_NGRAM_STATIC + draft.size()-1;
-        llama_ngram ngram_static;
+        common_ngram ngram_static;
         for (int j = ngram_start_static; j < ngram_start_static + LLAMA_NGRAM_STATIC; ++j) {
             ngram_static.tokens[j-ngram_start_static] = get_token(inp, draft, j);
         }
-        llama_ngram_cache::iterator part_static_it = nc_static.find(ngram_static);
-        llama_ngram_cache_part part_static;
+        common_ngram_cache::iterator part_static_it = nc_static.find(ngram_static);
+        common_ngram_cache_part part_static;
         if (part_static_it != nc_static.end()) {
             part_static = part_static_it->second;
         }

         // cd = context + dynamic
-        std::vector<llama_ngram> ngrams_cd;
+        std::vector<common_ngram> ngrams_cd;
         for (int ngram_size_cd = ngram_min; ngram_size_cd <= ngram_max; ++ngram_size_cd) {
             const int ngram_start_cd = inp_size-ngram_size_cd + draft.size()-1;
-            llama_ngram ngram_cd;
+            common_ngram ngram_cd;
             for (int j = ngram_start_cd; j < ngram_start_cd + ngram_size_cd; ++j) {
                 ngram_cd.tokens[j-ngram_start_cd] = get_token(inp, draft, j);
             }
@@ -196,16 +196,16 @@ void llama_ngram_cache_draft(
     }
 }

-void llama_ngram_cache_save(llama_ngram_cache & ngram_cache, std::string & filename) {
+void common_ngram_cache_save(common_ngram_cache & ngram_cache, std::string & filename) {
     std::ofstream file_out(filename, std::ios::binary);
-    for (std::pair<llama_ngram, llama_ngram_cache_part> item : ngram_cache) {
-        const llama_ngram ngram = item.first;
-        llama_ngram_cache_part token_counts = item.second;
+    for (std::pair<common_ngram, common_ngram_cache_part> item : ngram_cache) {
+        const common_ngram ngram = item.first;
+        common_ngram_cache_part token_counts = item.second;
         GGML_ASSERT(!token_counts.empty());
         const int32_t ntokens = token_counts.size();
         GGML_ASSERT(ntokens > 0);

-        file_out.write(reinterpret_cast<const char *>(&ngram), sizeof(llama_ngram));
+        file_out.write(reinterpret_cast<const char *>(&ngram), sizeof(common_ngram));
         file_out.write(reinterpret_cast<const char *>(&ntokens), sizeof(int32_t));
         for (std::pair<llama_token, int32_t> item2 : token_counts) {
             const llama_token token = item2.first;
@@ -219,14 +219,14 @@ void llama_ngram_cache_save(llama_ngram_cache & ngram_cache, std::string & filen

 }

-llama_ngram_cache llama_ngram_cache_load(std::string & filename) {
+common_ngram_cache common_ngram_cache_load(std::string & filename) {
     std::ifstream hashmap_file(filename, std::ios::binary);
     if (!hashmap_file) {
         throw std::ifstream::failure("Unable to open file " + filename);
     }
-    llama_ngram_cache ngram_cache;
+    common_ngram_cache ngram_cache;

-    llama_ngram ngram;
+    common_ngram ngram;
     int32_t ntokens;
     llama_token token;
     int32_t count;
@@ -235,11 +235,11 @@ llama_ngram_cache llama_ngram_cache_load(std::string & filename) {
     char * ntokensc = reinterpret_cast<char *>(&ntokens);
     char * tokenc = reinterpret_cast<char *>(&token);
     char * countc = reinterpret_cast<char *>(&count);
-    while (hashmap_file.read(ngramc, sizeof(llama_ngram))) {
+    while (hashmap_file.read(ngramc, sizeof(common_ngram))) {
         GGML_ASSERT(!hashmap_file.eof());
         GGML_ASSERT(hashmap_file.read(ntokensc, sizeof(int32_t)));
         GGML_ASSERT(ntokens > 0);
-        llama_ngram_cache_part token_counts;
+        common_ngram_cache_part token_counts;

         for (int i = 0; i < ntokens; ++i) {
             GGML_ASSERT(!hashmap_file.eof());
@@ -257,12 +257,12 @@ llama_ngram_cache llama_ngram_cache_load(std::string & filename) {
     return ngram_cache;
 }

-void llama_ngram_cache_merge(llama_ngram_cache & ngram_cache_target, llama_ngram_cache & ngram_cache_add) {
-    for (std::pair<llama_ngram, llama_ngram_cache_part> ngram_part : ngram_cache_add) {
-        const llama_ngram ngram = ngram_part.first;
-        llama_ngram_cache_part part = ngram_part.second;
+void common_ngram_cache_merge(common_ngram_cache & ngram_cache_target, common_ngram_cache & ngram_cache_add) {
+    for (std::pair<common_ngram, common_ngram_cache_part> ngram_part : ngram_cache_add) {
+        const common_ngram ngram = ngram_part.first;
+        common_ngram_cache_part part = ngram_part.second;

-        llama_ngram_cache::iterator part_merged_it = ngram_cache_target.find(ngram);
+        common_ngram_cache::iterator part_merged_it = ngram_cache_target.find(ngram);
         if (part_merged_it == ngram_cache_target.end()) {
             ngram_cache_target.emplace(ngram, part);
             continue;
@@ -273,7 +273,7 @@ void llama_ngram_cache_merge(llama_ngram_cache & ngram_cache_target, llama_ngram
             const int32_t count = token_count.second;
             GGML_ASSERT(count > 0);

-            llama_ngram_cache_part::iterator token_count_merged_it = part_merged_it->second.find(token);
+            common_ngram_cache_part::iterator token_count_merged_it = part_merged_it->second.find(token);
             if (token_count_merged_it == part_merged_it->second.end()) {
                 part_merged_it->second.emplace(token, count);
                 continue;
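
The renamed functions are meant to be used together: update a context cache from the prompt, load a static cache from disk, then draft tokens against all three caches. Below is a minimal usage sketch, not part of this diff; the header path, helper name, static-cache path parameter, and the 1/4 n-gram bounds are assumptions for illustration only.

```cpp
// Sketch only: relies on the common_ngram_cache API shown in the diff above.
#include "ngram-cache.h"   // assumed header exposing the common_ngram_cache API

#include <string>
#include <vector>

// Hypothetical helper: builds the three caches and returns up to n_draft drafted tokens.
static std::vector<llama_token> draft_with_ngram_caches(
        std::vector<llama_token> & inp, llama_token id_last, int n_draft, std::string & static_cache_path) {
    // Context cache: built from the tokens of the current context (nnew == full size here).
    common_ngram_cache nc_context;
    common_ngram_cache_update(nc_context, /*ngram_min=*/1, /*ngram_max=*/4, inp, (int) inp.size(), false);

    // Dynamic cache: starts empty in this sketch; static cache: loaded from a pre-built file
    // (common_ngram_cache_load throws std::ifstream::failure if the file cannot be opened).
    common_ngram_cache nc_dynamic;
    common_ngram_cache nc_static = common_ngram_cache_load(static_cache_path);

    // The draft must be seeded with exactly one token (GGML_ASSERT(draft.size() == 1) above).
    std::vector<llama_token> draft = {id_last};
    common_ngram_cache_draft(inp, draft, n_draft, /*ngram_min=*/1, /*ngram_max=*/4,
                             nc_context, nc_dynamic, nc_static);
    return draft;
}
```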