Commit 432ff23

Example for showing the probability of the next token

1 parent 407c237 commit 432ff23

4 files changed: +290 -1 lines changed

examples/CMakeLists.txt
Lines changed: 1 addition & 1 deletion

@@ -17,7 +17,6 @@ else()
    add_subdirectory(batched)
    add_subdirectory(embedding)
    add_subdirectory(eval-callback)
-
    add_subdirectory(gguf-hash)
    add_subdirectory(gguf)
    add_subdirectory(gritlm)
@@ -35,6 +34,7 @@ else()
    add_subdirectory(training)
    add_subdirectory(diffusion)
    add_subdirectory(model-conversion)
+   add_subdirectory(prediction-next-token)
    if (NOT GGML_BACKEND_DL)
        add_subdirectory(convert-llama2c-to-ggml)
        # these examples use the backends directly and cannot be built with dynamic loading
examples/prediction-next-token/CMakeLists.txt (new file)
Lines changed: 5 additions & 0 deletions

set(TARGET prediction-next-token)
add_executable(${TARGET} main.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_17)
examples/prediction-next-token/README.md (new file)
Lines changed: 48 additions & 0 deletions
# llama.cpp/examples/prediction-next-token

This directory contains an example demonstrating **next-token prediction** using LLaMA models through [llama.cpp/GGML](https://github.com/ggml-org/llama.cpp).

The tool can be useful for checking and measuring fine-tuning results on example prompts.
(Currently CPU only.)

---
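
## Building

A minimal sketch of how to build the example, assuming the standard llama.cpp CMake workflow from the repository root (the target name comes from this example's `CMakeLists.txt`):

```bash
# configure once, then build only this example's target
cmake -B build
cmake --build build --config Release --target prediction-next-token
```

---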

## Usage

```
prediction-next-token --model <model_path> --prompt <prompt> [--hypothesis <first_word>]
```

or the short form:

```
prediction-next-token -m <model_path> -p <prompt> [-h <first_word>]
```

**Example:**

```bash
prediction-next-token -m "models\llama-3.2-1B-q4_k_m-128k.gguf" -p "Who invented E=mc^2?" -h "Einstein"
```
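
The program prints the ten most probable next tokens, one per line; the shape of the output (taken from the `printf` calls in `main.cpp`) is shown below with placeholders rather than real model output:

```
 1) id=<token_id> p=<probability> "<token_text>"
 2) id=<token_id> p=<probability> "<token_text>"
...
10) id=<token_id> p=<probability> "<token_text>"
```

If `--hypothesis` is given, every vocabulary token that is a prefix of the hypothesis word is then listed as `id=<token_id> p=<probability> "<prefix>"`, sorted by decreasing probability.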

---

### Notes for non-English UTF-8 text (e.g., Russian)

On **Windows**, it is recommended to use **Windows Terminal** and to switch the console to the UTF-8 code page first:

```
chcp 65001
.\prediction-next-token.exe -m "models\llama-3.2-1B-q4_k_m-128k-ru.gguf" -p "Здравствуйте!" -h "Привет"
```

* This ensures that UTF-8 characters are handled correctly both in the input arguments and in the console output.

---

## Notes on Model Behavior

* The `--hypothesis` argument is optional; it specifies the expected first word, and the tool reports the probability of every vocabulary token that is a prefix of that word.
* After fine-tuning on a dataset, the **perplexity** of the model on a test set should decrease over the training epochs (see the formula below).
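
For reference, the perplexity mentioned above is the usual quantity over a held-out token sequence $x_1, \dots, x_N$ (this example only prints next-token probabilities; it does not compute perplexity itself):

```math
\mathrm{PPL}(x_{1..N}) = \exp\left( -\frac{1}{N} \sum_{i=1}^{N} \log p(x_i \mid x_{<i}) \right)
```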

examples/prediction-next-token/main.cpp (new file)
Lines changed: 236 additions & 0 deletions
// NOTE: this example is currently Windows-specific (wmain, WideCharToMultiByte,
// console code-page handling); the llama.cpp calls themselves are portable.
#include <llama.h>

#include <windows.h>

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstdlib> // for malloc/free
#include <cstring> // for strlen/strcmp/memcpy
#include <string>
#include <vector>

struct TokenInfo {
    int id;
    float p;
    std::string piece;
};

// Convert a UTF-16 (wide) string to a newly allocated UTF-8 string.
const char * Utf8FromUtf16(const wchar_t * wstr) {
    if (!wstr) {
        return nullptr;
    }

    int size_needed = WideCharToMultiByte(CP_UTF8, 0, wstr, -1, nullptr, 0, nullptr, nullptr);

    char * buffer = (char *) malloc(size_needed);
    if (!buffer) {
        return nullptr;
    }

    WideCharToMultiByte(CP_UTF8, 0, wstr, -1, buffer, size_needed, nullptr, nullptr);

    return buffer; // the caller must call free()
}

int wmain(int argc, wchar_t * argv[]) {
    // make the Windows console use UTF-8 for both input and output
    SetConsoleOutputCP(CP_UTF8);
    SetConsoleCP(CP_UTF8);

    // default values
    const char * model_path = nullptr;
    const char * prompt = nullptr;
    const char * word = nullptr;

    // parse the command-line arguments
    for (int i = 1; i < argc; i++) {
        if ((wcscmp(argv[i], L"-m") == 0 || wcscmp(argv[i], L"--model") == 0) && i + 1 < argc) {
            model_path = Utf8FromUtf16(argv[++i]);
        } else if ((wcscmp(argv[i], L"-p") == 0 || wcscmp(argv[i], L"--prompt") == 0) && i + 1 < argc) {
            prompt = Utf8FromUtf16(argv[++i]);
        } else if ((wcscmp(argv[i], L"-h") == 0 || wcscmp(argv[i], L"--hypothesis") == 0) && i + 1 < argc) {
            word = Utf8FromUtf16(argv[++i]);
        } else if (i == 1 && argv[i][0] != L'-') {
            // positional form: <model_path> [<prompt>]
            model_path = Utf8FromUtf16(argv[i]);
            if (i + 1 < argc) {
                prompt = Utf8FromUtf16(argv[++i]);
            }
        }
    }

    // check the required arguments
    if (model_path == nullptr || prompt == nullptr) {
        fprintf(stderr,
                "Usage: %s -m|--model <model_path> -p|--prompt <prompt> [-h|--hypothesis <first_word>]\n",
                Utf8FromUtf16(argv[0]));
        return 1;
    }

    // 0) backend
    llama_backend_init();

    // 1) load model
    llama_model_params model_params = llama_model_default_params();
    llama_model * model = llama_model_load_from_file(model_path, model_params);
    if (!model) {
        fprintf(stderr, "failed to load model: %s\n", model_path);
        llama_backend_free();
        return 1;
    }

    // 2) context
    llama_context_params ctx_params = llama_context_default_params();
    ctx_params.n_ctx = 512;
    llama_context * ctx = llama_init_from_model(model, ctx_params);
    if (!ctx) {
        fprintf(stderr, "failed to create context\n");
        llama_model_free(model);
        llama_backend_free();
        return 1;
    }

    // 3) vocab
    const llama_vocab * vocab = llama_model_get_vocab(model);

    // 4) tokenize the full prompt
    int max_tokens = 256;
    std::vector<llama_token> tok(max_tokens);

    int n_tok = llama_tokenize(vocab,
                               prompt,
                               (int) strlen(prompt),
                               tok.data(),
                               (int) tok.size(),
                               /*add_special=*/true,
                               /*parse_special=*/true);
    if (n_tok < 0) {
        // a negative result is the required buffer size - retry with a larger buffer
        max_tokens = -n_tok;
        tok.resize(max_tokens);
        n_tok = llama_tokenize(vocab, prompt, (int) strlen(prompt), tok.data(), (int) tok.size(), true, true);
    }
    if (n_tok <= 0) {
        fprintf(stderr, "tokenization failed\n");
        llama_free(ctx);
        llama_model_free(model);
        llama_backend_free();
        return 1;
    }
    tok.resize(n_tok);

    // 5) build the batch correctly (do NOT allocate seq_id manually!)
    llama_batch batch = llama_batch_get_one(tok.data(), (int) tok.size());
    // batch.pos / batch.seq_id / batch.n_seq_id / batch.logits = nullptr
    // the runtime fills in the correct values itself and returns logits for the last token

    // 6) decode
    int ret = llama_decode(ctx, batch);
    if (ret != 0) {
        fprintf(stderr, "llama_decode failed, ret = %d\n", ret);
        llama_free(ctx);
        llama_model_free(model);
        llama_backend_free();
        return 1;
    }

    // 7) logits of the last token in the batch
    //    (this is the safe choice: these are the "latest" logits, corresponding to the marked last token)
    const float * logits = llama_get_logits(ctx);
    const int n_vocab = llama_vocab_n_tokens(vocab);

    // 8) softmax + top-10
    // find the maximum for a numerically stable softmax
    float max_logit = logits[0];
    for (int i = 1; i < n_vocab; ++i) {
        if (logits[i] > max_logit) {
            max_logit = logits[i];
        }
    }
    // compute the exponentials and their sum
    std::vector<float> probs(n_vocab);
    double sum = 0.0;
    for (int i = 0; i < n_vocab; ++i) {
        float e = std::exp(logits[i] - max_logit);
        probs[i] = e;
        sum += e;
    }
    for (int i = 0; i < n_vocab; ++i) {
        probs[i] = (float) (probs[i] / sum);
    }

    // collect the token indices and sort them by probability
    std::vector<int> ids(n_vocab);
    for (int i = 0; i < n_vocab; ++i) {
        ids[i] = i;
    }
    std::partial_sort(ids.begin(), ids.begin() + 10, ids.end(), [&](int a, int b) { return probs[a] > probs[b]; });

    // 9) print the top-10
    char piece[256];
    for (int r = 0; r < 10; ++r) {
        int id = ids[r];
        int n  = llama_token_to_piece(vocab,
                                      id,
                                      piece,
                                      sizeof(piece) - 1, // leave room for the null terminator
                                      /*lstrip=*/1,
                                      /*special=*/true);
        if (n < 0) {
            snprintf(piece, sizeof(piece), "<tok %d>", id);
        } else {
            piece[n] = '\0';
        }
        printf("%2d) id=%6d p=%.6f \"%s\"\n", r + 1, id, probs[id], piece);
    }

    if (word != nullptr) {
        // 10) also print the tokens of interest (prefixes of the hypothesis word)
        std::vector<TokenInfo> tokens_info;

        // collect all prefixes of the string
        std::vector<std::string> prefixes;
        size_t text_len = strlen(word);
        char buf[256];
        for (size_t len = 1; len <= text_len && len < sizeof(buf); len++) {
            memcpy(buf, word, len);
            buf[len] = '\0';
            prefixes.push_back(buf);
        }

        // walk the vocabulary and find every token that matches one of the prefixes
        for (int id = 0; id < llama_vocab_n_tokens(vocab); ++id) {
            char piece[256];
            int n = llama_token_to_piece(vocab, id, piece, sizeof(piece) - 1, 1, true);
            if (n <= 0) {
                continue;
            }
            piece[n] = '\0';

            // check for an exact match with one of the prefixes
            for (const auto & pref : prefixes) {
                if (strcmp(piece, pref.c_str()) == 0) {
                    tokens_info.push_back({ id, probs[id], piece });
                }
            }
        }

        // sort by decreasing probability
        std::sort(
            tokens_info.begin(), tokens_info.end(), [](const TokenInfo & a, const TokenInfo & b) { return a.p > b.p; });

        // output
        for (const auto & t : tokens_info) {
            if (t.p > 0.00000049f) {
                printf("id=%6d p=%.6f \"%s\"\n", t.id, t.p, t.piece.c_str());
            }
        }
    }

    // 11) cleanup
    llama_free(ctx);
    llama_model_free(model);
    llama_backend_free();
    return 0;
}
