@@ -111,6 +111,28 @@ static void test_min_p(const std::vector<float> & probs, const std::vector<float
111111 }
112112}
113113
114+ static void test_xtc (const std::vector<float > & probs, const std::vector<float > & expected_probs, float p, float t, float t_max) {
115+ const size_t n_vocab = probs.size ();
116+
117+ std::vector<llama_token_data> cur;
118+ cur.reserve (n_vocab);
119+ for (llama_token token_id = 0 ; token_id < (llama_token)n_vocab; token_id++) {
120+ const float logit = logf (probs[token_id]);
121+ cur.emplace_back (llama_token_data{token_id, logit, 0 .0f });
122+ }
123+
124+ llama_token_data_array cur_p = { cur.data (), cur.size (), -1 , false };
125+ APPLY (llama_sampler_init_softmax (), &cur_p);
126+ DUMP (&cur_p);
127+ APPLY (llama_sampler_init_xtc (p, t, t_max, 0 , 0 ), &cur_p);
128+ DUMP (&cur_p);
129+
130+ GGML_ASSERT (cur_p.size == expected_probs.size ());
131+ for (size_t i = 0 ; i < cur_p.size ; i++) {
132+ GGML_ASSERT (fabs (cur_p.data [i].p - expected_probs[i]) < 1e-5 );
133+ }
134+ }
135+
114136static void test_typical (const std::vector<float > & probs, const std::vector<float > & expected_probs, float p) {
115137 const size_t n_vocab = probs.size ();
116138
@@ -279,12 +301,13 @@ static void test_perf() {
279301 data.emplace_back (llama_token_data{i, logit, 0 .0f });
280302 }
281303
282- BENCH (llama_sampler_init_top_k (40 ), data, 32 );
283- BENCH (llama_sampler_init_top_p (0 .8f , 1 ), data, 32 );
284- BENCH (llama_sampler_init_min_p (0 .2f , 1 ), data, 32 );
285- BENCH (llama_sampler_init_tail_free (0 .5f , 1 ), data, 32 );
286- BENCH (llama_sampler_init_typical (0 .5f , 1 ), data, 32 );
287- BENCH (llama_sampler_init_softmax (), data, 32 );
304+ BENCH (llama_sampler_init_top_k (40 ), data, 32 );
305+ BENCH (llama_sampler_init_top_p (0 .8f , 1 ), data, 32 );
306+ BENCH (llama_sampler_init_min_p (0 .2f , 1 ), data, 32 );
307+ BENCH (llama_sampler_init_tail_free (0 .5f , 1 ), data, 32 );
308+ BENCH (llama_sampler_init_typical (0 .5f , 1 ), data, 32 );
309+ BENCH (llama_sampler_init_xtc (1 .0f , 0 .1f , 0 .8f , 1 , 1 ), data, 32 );
310+ BENCH (llama_sampler_init_softmax (), data, 32 );
288311}
289312
290313int main (void ) {
@@ -309,6 +332,11 @@ int main(void) {
309332 test_min_p ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {0 .4f /0 .4f }, 0 .76f );
310333 test_min_p ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {0 .4f /0 .4f }, 1 .00f );
311334
335+ test_xtc ({0 .4f , 0 .3f , 0 .2f , 0 .1f }, {0 .1f }, 0 .99f , 0 .10f , 1 .00f );
336+ test_xtc ({0 .4f , 0 .3f , 0 .2f , 0 .1f }, {0 .4f , 0 .1f }, 0 .99f , 0 .10f , 0 .35f );
337+ test_xtc ({0 .4f , 0 .3f , 0 .2f , 0 .1f }, {0 .4f , 0 .3f , 0 .1f }, 0 .99f , 0 .10f , 0 .25f );
338+ test_xtc ({0 .4f , 0 .3f , 0 .2f , 0 .1f }, {0 .4f , 0 .2f , 0 .1f }, 0 .99f , 0 .20f , 0 .35f );
339+
312340 test_tfs ({0 .1f , 0 .15f , 0 .2f , 0 .25f , 0 .3f }, {0 .3f }, 0 .25f );
313341 test_tfs ({0 .1f , 0 .15f , 0 .2f , 0 .25f , 0 .3f }, {0 .3f , 0 .25f }, 0 .75f );
314342 test_tfs ({0 .1f , 0 .15f , 0 .2f , 0 .25f , 0 .3f }, {0 .3f , 0 .25f }, 0 .99f );
0 commit comments