|
1 | 1 | #include "ggml.h"
|
2 | 2 | #include "llama.h"
|
3 |
| -#include "llama-sampling.h" |
4 | 3 |
|
5 | 4 | #ifdef NDEBUG
|
6 | 5 | #undef NDEBUG
|
@@ -249,6 +248,45 @@ static void test_sampler_queue(const size_t n_vocab, const std::string & sampler
|
249 | 248 | samplers_sequence.c_str(), n_vocab, top_k, top_p, min_p);
|
250 | 249 | }
|
251 | 250 |
|
| 251 | +static void bench(llama_sampler * cnstr, const char * cnstr_name, const std::vector<llama_token_data> & data, int n_iter) { |
| 252 | + std::vector<llama_token_data> cur(data.size()); |
| 253 | + std::copy(data.begin(), data.end(), cur.begin()); |
| 254 | + llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false }; |
| 255 | + llama_sampler_apply(cnstr, &cur_p); |
| 256 | + llama_sampler_reset(cnstr); |
| 257 | + const int64_t t_start = ggml_time_us(); |
| 258 | + for (int i = 0; i < n_iter; i++) { |
| 259 | + std::copy(data.begin(), data.end(), cur.begin()); |
| 260 | + llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false }; |
| 261 | + llama_sampler_apply(cnstr, &cur_p); |
| 262 | + llama_sampler_reset(cnstr); |
| 263 | + } |
| 264 | + const int64_t t_end = ggml_time_us(); |
| 265 | + llama_sampler_free(cnstr); |
| 266 | + printf("%-42s: %8.3f us/iter\n", cnstr_name, (t_end - t_start) / (float)n_iter); |
| 267 | +} |
| 268 | + |
| 269 | +#define BENCH(__cnstr, __data, __n_iter) bench((__cnstr), #__cnstr, (__data), (__n_iter)) |
| 270 | + |
| 271 | +static void test_perf() { |
| 272 | + const int n_vocab = 1 << 17; |
| 273 | + |
| 274 | + std::vector<llama_token_data> data; |
| 275 | + |
| 276 | + data.reserve(n_vocab); |
| 277 | + for (int i = 0; i < n_vocab; i++) { |
| 278 | + const float logit = 2.0f*((float)(rand())/RAND_MAX - 0.5f); |
| 279 | + data.emplace_back(llama_token_data{i, logit, 0.0f}); |
| 280 | + } |
| 281 | + |
| 282 | + BENCH(llama_sampler_init_top_k (40), data, 32); |
| 283 | + BENCH(llama_sampler_init_top_p (0.8f, 1), data, 32); |
| 284 | + BENCH(llama_sampler_init_min_p (0.2f, 1), data, 32); |
| 285 | + BENCH(llama_sampler_init_tail_free(0.5f, 1), data, 32); |
| 286 | + BENCH(llama_sampler_init_typical (0.5f, 1), data, 32); |
| 287 | + BENCH(llama_sampler_init_softmax (), data, 32); |
| 288 | +} |
| 289 | + |
252 | 290 | int main(void) {
|
253 | 291 | ggml_time_init();
|
254 | 292 |
|
@@ -316,5 +354,7 @@ int main(void) {
|
316 | 354 |
|
317 | 355 | printf("OK\n");
|
318 | 356 |
|
| 357 | + test_perf(); |
| 358 | + |
319 | 359 | return 0;
|
320 | 360 | }
|
0 commit comments