|
1 | 1 | #include "ggml.h" |
2 | 2 | #include "llama.h" |
3 | | -#include "llama-sampling.h" |
4 | 3 |
|
5 | 4 | #ifdef NDEBUG |
6 | 5 | #undef NDEBUG |
@@ -249,6 +248,45 @@ static void test_sampler_queue(const size_t n_vocab, const std::string & sampler |
249 | 248 | samplers_sequence.c_str(), n_vocab, top_k, top_p, min_p); |
250 | 249 | } |
251 | 250 |
|
| 251 | +#define BENCH(__cnstr, __data, __n_iter) do { \ |
| 252 | + auto * cnstr = (__cnstr); \ |
| 253 | + std::vector<llama_token_data> cur((__data).size()); \ |
| 254 | + std::copy((__data).begin(), (__data).end(), cur.begin()); \ |
| 255 | + llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false }; \ |
| 256 | + llama_sampler_apply(cnstr, &cur_p); \ |
| 257 | + llama_sampler_reset(cnstr); \ |
| 258 | + const int64_t t_start = ggml_time_us(); \ |
| 259 | + const int n_iter = (__n_iter); \ |
| 260 | + for (int i = 0; i < n_iter; i++) { \ |
| 261 | + std::copy((__data).begin(), (__data).end(), cur.begin()); \ |
| 262 | + llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false }; \ |
| 263 | + llama_sampler_apply(cnstr, &cur_p); \ |
| 264 | + llama_sampler_reset(cnstr); \ |
| 265 | + } \ |
| 266 | + const int64_t t_end = ggml_time_us(); \ |
| 267 | + llama_sampler_free(cnstr); \ |
| 268 | + printf("%-42s: %8.3f us/iter\n", #__cnstr, (t_end - t_start) / (float)n_iter); \ |
| 269 | +} while(0) |
| 270 | + |
| 271 | +static void test_perf() { |
| 272 | + const int n_vocab = 1 << 17; |
| 273 | + |
| 274 | + std::vector<llama_token_data> data; |
| 275 | + |
| 276 | + data.reserve(n_vocab); |
| 277 | + for (int i = 0; i < n_vocab; i++) { |
| 278 | + const float logit = 2.0f*((float)(rand())/RAND_MAX - 0.5f); |
| 279 | + data.emplace_back(llama_token_data{i, logit, 0.0f}); |
| 280 | + } |
| 281 | + |
| 282 | + BENCH(llama_sampler_init_top_k (40), data, 32); |
| 283 | + BENCH(llama_sampler_init_top_p (0.8f, 1), data, 32); |
| 284 | + BENCH(llama_sampler_init_min_p (0.2f, 1), data, 32); |
| 285 | + BENCH(llama_sampler_init_tail_free(0.5f, 1), data, 32); |
| 286 | + BENCH(llama_sampler_init_typical (0.5f, 1), data, 32); |
| 287 | + BENCH(llama_sampler_init_softmax (), data, 32); |
| 288 | +} |
| 289 | + |
252 | 290 | int main(void) { |
253 | 291 | ggml_time_init(); |
254 | 292 |
|
@@ -316,5 +354,7 @@ int main(void) { |
316 | 354 |
|
317 | 355 | printf("OK\n"); |
318 | 356 |
|
| 357 | + test_perf(); |
| 358 | + |
319 | 359 | return 0; |
320 | 360 | } |
0 commit comments