| 
1 | 1 | #include "ggml.h"  | 
2 | 2 | #include "llama.h"  | 
3 |  | -#include "llama-sampling.h"  | 
4 | 3 | 
 
  | 
5 | 4 | #ifdef NDEBUG  | 
6 | 5 | #undef NDEBUG  | 
@@ -249,6 +248,45 @@ static void test_sampler_queue(const size_t n_vocab, const std::string & sampler  | 
249 | 248 |            samplers_sequence.c_str(), n_vocab, top_k, top_p, min_p);  | 
250 | 249 | }  | 
251 | 250 | 
 
  | 
 | 251 | +static void bench(llama_sampler * cnstr, const char * cnstr_name, const std::vector<llama_token_data> & data, int n_iter) {  | 
 | 252 | +    std::vector<llama_token_data> cur(data.size());  | 
 | 253 | +    std::copy(data.begin(), data.end(), cur.begin());  | 
 | 254 | +    llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };  | 
 | 255 | +    llama_sampler_apply(cnstr, &cur_p);  | 
 | 256 | +    llama_sampler_reset(cnstr);  | 
 | 257 | +    const int64_t t_start = ggml_time_us();  | 
 | 258 | +    for (int i = 0; i < n_iter; i++) {  | 
 | 259 | +        std::copy(data.begin(), data.end(), cur.begin());  | 
 | 260 | +        llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };  | 
 | 261 | +        llama_sampler_apply(cnstr, &cur_p);  | 
 | 262 | +        llama_sampler_reset(cnstr);  | 
 | 263 | +    }  | 
 | 264 | +    const int64_t t_end = ggml_time_us();  | 
 | 265 | +    llama_sampler_free(cnstr);  | 
 | 266 | +    printf("%-42s: %8.3f us/iter\n", cnstr_name, (t_end - t_start) / (float)n_iter);  | 
 | 267 | +}  | 
 | 268 | + | 
// Stringizes the constructor expression (#__cnstr) so the printed report
// line shows exactly how each sampler was configured.
#define BENCH(__cnstr, __data, __n_iter) bench((__cnstr), #__cnstr, (__data), (__n_iter))
 | 270 | + | 
 | 271 | +static void test_perf() {  | 
 | 272 | +    const int n_vocab = 1 << 17;  | 
 | 273 | + | 
 | 274 | +    std::vector<llama_token_data> data;  | 
 | 275 | + | 
 | 276 | +    data.reserve(n_vocab);  | 
 | 277 | +    for (int i = 0; i < n_vocab; i++) {  | 
 | 278 | +        const float logit = 2.0f*((float)(rand())/RAND_MAX - 0.5f);  | 
 | 279 | +        data.emplace_back(llama_token_data{i, logit, 0.0f});  | 
 | 280 | +    }  | 
 | 281 | + | 
 | 282 | +    BENCH(llama_sampler_init_top_k    (40),      data, 32);  | 
 | 283 | +    BENCH(llama_sampler_init_top_p    (0.8f, 1), data, 32);  | 
 | 284 | +    BENCH(llama_sampler_init_min_p    (0.2f, 1), data, 32);  | 
 | 285 | +    BENCH(llama_sampler_init_tail_free(0.5f, 1), data, 32);  | 
 | 286 | +    BENCH(llama_sampler_init_typical  (0.5f, 1), data, 32);  | 
 | 287 | +    BENCH(llama_sampler_init_softmax  (),        data, 32);  | 
 | 288 | +}  | 
 | 289 | + | 
252 | 290 | int main(void) {  | 
253 | 291 |     ggml_time_init();  | 
254 | 292 | 
 
  | 
@@ -316,5 +354,7 @@ int main(void) {  | 
316 | 354 | 
 
  | 
317 | 355 |     printf("OK\n");  | 
318 | 356 | 
 
  | 
 | 357 | +    test_perf();  | 
 | 358 | + | 
319 | 359 |     return 0;  | 
320 | 360 | }  | 
0 commit comments