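// Inspect a llama.cpp model's next-token distribution.
//
// Loads a GGUF model, evaluates a prompt, prints the ten most likely next
// tokens, and optionally the probability of every vocabulary token that is
// a prefix of a given hypothesis word.
//
// A minimal build/run sketch (compiler flags, file names, and the library
// name are illustrative assumptions, not taken from this repository):
//   cl /EHsc /utf-8 main.cpp /I llama.cpp\include llama.lib
//   main.exe -m model.gguf -p "The capital of France is" -h " Paris"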
#include <llama.h>
#include <windows.h>

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstdlib>  // for malloc/free
#include <cstring>  // for strlen, strcmp, memcpy
#include <string>
#include <vector>

// A token id together with its softmax probability and decoded text.
struct TokenInfo {
    int         id;
    float       p;
    std::string piece;
};

// Convert a UTF-16 (wide) string to a newly allocated UTF-8 string.
// wmain() receives UTF-16 arguments, while llama.cpp expects UTF-8.
// Returns nullptr on failure; the caller must call free() on the result.
const char * Utf8FromUtf16(const wchar_t * wstr) {
    if (!wstr) {
        return nullptr;
    }

    int size_needed = WideCharToMultiByte(CP_UTF8, 0, wstr, -1, nullptr, 0, nullptr, nullptr);
    if (size_needed <= 0) {
        return nullptr;
    }

    char * buffer = (char *) malloc(size_needed);
    if (!buffer) {
        return nullptr;
    }

    WideCharToMultiByte(CP_UTF8, 0, wstr, -1, buffer, size_needed, nullptr, nullptr);

    return buffer;  // caller must call free()
}

int wmain(int argc, wchar_t * argv[]) {
    SetConsoleOutputCP(CP_UTF8);
    SetConsoleCP(CP_UTF8);

    // defaults
    const char * model_path = nullptr;
    const char * prompt     = nullptr;
    const char * word       = nullptr;

    // parse arguments
    for (int i = 1; i < argc; i++) {
        if ((wcscmp(argv[i], L"-m") == 0 || wcscmp(argv[i], L"--model") == 0) && i + 1 < argc) {
            model_path = Utf8FromUtf16(argv[++i]);
        } else if ((wcscmp(argv[i], L"-p") == 0 || wcscmp(argv[i], L"--prompt") == 0) && i + 1 < argc) {
            prompt = Utf8FromUtf16(argv[++i]);
        } else if ((wcscmp(argv[i], L"-h") == 0 || wcscmp(argv[i], L"--hypothesis") == 0) && i + 1 < argc) {
            word = Utf8FromUtf16(argv[++i]);
        } else if (i == 1 && argv[i][0] != L'-') {
            // positional form: <model_path> [prompt]
            model_path = Utf8FromUtf16(argv[i]);
            if (i + 1 < argc) {
                prompt = Utf8FromUtf16(argv[++i]);
            }
        }
    }

    // check required arguments
    if (model_path == nullptr || prompt == nullptr) {
        const char * prog = Utf8FromUtf16(argv[0]);
        fprintf(stderr,
                "Usage: %s -m|--model <model_path> -p|--prompt <prompt> [-h|--hypothesis <first_word>]\n",
                prog ? prog : "<program>");
        return 1;
    }

    // 0) backend
    llama_backend_init();

    // 1) load model
    llama_model_params model_params = llama_model_default_params();
    llama_model *      model        = llama_model_load_from_file(model_path, model_params);
    if (!model) {
        fprintf(stderr, "failed to load model: %s\n", model_path);
        llama_backend_free();
        return 1;
    }

    // 2) context
    llama_context_params ctx_params = llama_context_default_params();
    ctx_params.n_ctx                = 512;
    llama_context * ctx             = llama_init_from_model(model, ctx_params);
    if (!ctx) {
        fprintf(stderr, "failed to create context\n");
        llama_model_free(model);
        llama_backend_free();
        return 1;
    }

    // 3) vocab
    const llama_vocab * vocab = llama_model_get_vocab(model);

    // 4) tokenize the full prompt
    int                      max_tokens = 256;
    std::vector<llama_token> tok(max_tokens);

    int n_tok = llama_tokenize(vocab,
                               prompt,
                               (int) strlen(prompt),
                               tok.data(),
                               (int) tok.size(),
                               /*add_special=*/true,
                               /*parse_special=*/true);
    if (n_tok < 0) {
        // on overflow, llama_tokenize returns the negative of the required token count
        max_tokens = -n_tok;
        tok.resize(max_tokens);
        n_tok = llama_tokenize(vocab, prompt, (int) strlen(prompt), tok.data(), (int) tok.size(), true, true);
    }
    if (n_tok <= 0) {
        fprintf(stderr, "tokenization failed\n");
        llama_free(ctx);
        llama_model_free(model);
        llama_backend_free();
        return 1;
    }
    tok.resize(n_tok);

    // 5) build the batch correctly (do NOT allocate seq_id by hand!)
    llama_batch batch = llama_batch_get_one(tok.data(), (int) tok.size());
    // batch.pos / batch.seq_id / batch.n_seq_id / batch.logits == nullptr:
    // the runtime fills in correct values and returns logits for the last token
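    // (with batch.logits left as nullptr, llama_decode computes logits only
    // for the final token of the batch, which is exactly what
    // llama_get_logits() returns in step 7 below)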

    // 6) decode
    int ret = llama_decode(ctx, batch);
    if (ret != 0) {
        fprintf(stderr, "llama_decode failed, ret = %d\n", ret);
        llama_free(ctx);
        llama_model_free(model);
        llama_backend_free();
        return 1;
    }

    // 7) logits of the last token in the batch
    // (this is the safe choice: these are the "last" logits, belonging to the token marked as last)
    const float * logits  = llama_get_logits(ctx);
    const int     n_vocab = llama_vocab_n_tokens(vocab);

    // 8) softmax + top-10
    // find the maximum for a numerically stable softmax
    float max_logit = logits[0];
    for (int i = 1; i < n_vocab; ++i) {
        if (logits[i] > max_logit) {
            max_logit = logits[i];
        }
    }
    // compute the exponentials and their sum
    std::vector<float> probs(n_vocab);
    double             sum = 0.0;
    for (int i = 0; i < n_vocab; ++i) {
        float e  = std::exp(logits[i] - max_logit);
        probs[i] = e;
        sum += e;
    }
    for (int i = 0; i < n_vocab; ++i) {
        probs[i] = (float) (probs[i] / sum);
    }
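    // i.e. probs[i] = exp(logits[i] - max_logit) / sum_j exp(logits[j] - max_logit),
    // identical to softmax(logits)[i] because the exp(-max_logit) factors cancel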

    // collect indices and sort by probability
    std::vector<int> ids(n_vocab);
    for (int i = 0; i < n_vocab; ++i) {
        ids[i] = i;
    }
    const int top_k = std::min(10, n_vocab);
    std::partial_sort(ids.begin(), ids.begin() + top_k, ids.end(), [&](int a, int b) { return probs[a] > probs[b]; });

    // 9) print the top-10
    char piece[256];
    for (int r = 0; r < top_k; ++r) {
        int id = ids[r];
        int n  = llama_token_to_piece(vocab,
                                      id,
                                      piece,
                                      (int) sizeof(piece),
                                      /*lstrip=*/0,
                                      /*special=*/true);
        if (n < 0 || n >= (int) sizeof(piece)) {
            snprintf(piece, sizeof(piece), "<tok %d>", id);
        } else {
            piece[n] = '\0';
        }
        printf("%2d) id=%6d  p=%.6f  \"%s\"\n", r + 1, id, probs[id], piece);
    }

    if (word != nullptr) {
        // 10) also print the tokens of interest
        std::vector<TokenInfo> tokens_info;

        // collect every prefix of the hypothesis word
        std::vector<std::string> prefixes;
        size_t                   text_len = strlen(word);
        for (size_t len = 1; len <= text_len; len++) {
            prefixes.push_back(std::string(word, len));
        }
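        // note: prefixes are built byte by byte, so for a multi-byte UTF-8 word
        // some intermediate prefixes are not valid UTF-8; they simply never
        // match any token piece, which is harmless here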

        // scan the vocabulary for all tokens that exactly match one of the prefixes
        for (int id = 0; id < llama_vocab_n_tokens(vocab); ++id) {
            char piece[256];
            int  n = llama_token_to_piece(vocab, id, piece, (int) sizeof(piece), /*lstrip=*/0, /*special=*/true);
            if (n <= 0 || n >= (int) sizeof(piece)) {
                continue;
            }
            piece[n] = '\0';

            // check for an exact match with a prefix
            for (const auto & pref : prefixes) {
                if (strcmp(piece, pref.c_str()) == 0) {
                    tokens_info.push_back({ id, probs[id], piece });
                }
            }
        }

        // sort by descending probability
        std::sort(
            tokens_info.begin(), tokens_info.end(), [](const TokenInfo & a, const TokenInfo & b) { return a.p > b.p; });

        // print every match whose probability is still visible with %.6f
        for (const auto & t : tokens_info) {
            if (t.p > 0.00000049f) {
                printf("id=%6d  p=%.6f  \"%s\"\n", t.id, t.p, t.piece.c_str());
            }
        }
    }

    // 11) cleanup
    llama_free(ctx);
    llama_model_free(model);
    llama_backend_free();
    return 0;
}