@@ -248,25 +248,25 @@ static void test_sampler_queue(const size_t n_vocab, const std::string & sampler
248248 samplers_sequence.c_str (), n_vocab, top_k, top_p, min_p);
249249}
250250
251- # define BENCH ( __cnstr, __data, __n_iter ) do { \
252- auto * cnstr = (__cnstr); \
253- std::vector<llama_token_data> cur ((__data). size ()); \
254- std::copy ((__data). begin (), (__data). end (), cur.begin ()); \
255- llama_token_data_array cur_p = { cur. data (), cur. size (), - 1 , false }; \
256- llama_sampler_apply (cnstr, &cur_p); \
257- llama_sampler_reset (cnstr); \
258- const int64_t t_start = ggml_time_us (); \
259- const int n_iter = (__n_iter); \
260- for ( int i = 0 ; i < n_iter; i++) { \
261- std::copy ((__data). begin (), (__data). end (), cur. begin ()); \
262- llama_token_data_array cur_p = { cur. data (), cur. size (), - 1 , false }; \
263- llama_sampler_apply (cnstr, &cur_p); \
264- llama_sampler_reset (cnstr); \
265- } \
266- const int64_t t_end = ggml_time_us (); \
267- llama_sampler_free (cnstr); \
268- printf ( " %-42s: %8.3f us/iter \n " , #__cnstr, (t_end - t_start) / ( float )n_iter); \
269- } while ( 0 )
251+ static void bench (llama_sampler * cnstr, const char * cnstr_name, const std::vector<llama_token_data> & data, int n_iter) {
252+ std::vector<llama_token_data> cur (data. size ());
253+ std::copy (data. begin (), data. end (), cur. begin ());
254+ llama_token_data_array cur_p = { cur. data (), cur.size (), - 1 , false };
255+ llama_sampler_apply (cnstr, &cur_p);
256+ llama_sampler_reset (cnstr);
257+ const int64_t t_start = ggml_time_us ();
258+ for ( int i = 0 ; i < n_iter; i++) {
259+ std::copy (data. begin (), data. end (), cur. begin ());
260+ llama_token_data_array cur_p = { cur. data (), cur. size (), - 1 , false };
261+ llama_sampler_apply (cnstr, &cur_p);
262+ llama_sampler_reset (cnstr);
263+ }
264+ const int64_t t_end = ggml_time_us ();
265+ llama_sampler_free (cnstr);
266+ printf ( " %-42s: %8.3f us/iter \n " , cnstr_name, ( t_end - t_start) / ( float )n_iter);
267+ }
268+
269+ # define BENCH ( __cnstr, __data, __n_iter ) bench((__cnstr), #__cnstr, (__data), (__n_iter) )
270270
271271static void test_perf () {
272272 const int n_vocab = 1 << 17 ;
0 commit comments