@@ -182,6 +182,17 @@ static void test_dry(
182182 tester.check ();
183183}
184184
185+ static void test_top_n_sigma (const std::vector<float > & probs, const std::vector<float > & probs_expected, int n) {
186+ sampler_tester tester (probs, probs_expected);
187+
188+ DUMP (&tester.cur_p );
189+ tester.apply (llama_sampler_init_top_n_sigma (n));
190+ tester.apply (llama_sampler_init_dist (0 ));
191+ DUMP (&tester.cur_p );
192+
193+ tester.check ();
194+ }
195+
185196static void test_sampler_queue (const size_t n_vocab, const std::string & samplers_sequence, const int top_k, const float top_p, const float min_p
186197) {
187198 sampler_tester tester (n_vocab);
@@ -349,6 +360,14 @@ int main(void) {
349360 test_dry ({0 .2f , 0 .2f , 0 .2f , 0 .2f , 0 .2f }, {0 , 1 , 2 , 0 , 1 }, {0 .241818f , 0 .241818f , 0 .241818f , 0 .241818f , 0 .032727f }, 2 .0f , 1 .1f , 2 , 5 , {});
350361 test_dry ({0 .2f , 0 .2f , 0 .2f , 0 .2f , 0 .2f }, {0 , 1 , 2 , 3 , 4 , 0 , 1 }, {0 .2f , 0 .2f , 0 .2f , 0 .2f , 0 .2f }, 1 .0f , 1 .1f , 4 , 7 , {});
351362
363+ test_top_n_sigma ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {0 .571429f , 0 .428571f , 0 .0f , 0 .0f }, 1 );
364+ test_top_n_sigma ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {1 .0f , 0 .0f , 0 .0f , 0 .0f }, 0 );
365+ test_top_n_sigma ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {0 .4f , 0 .3f , 0 .2f , 0 .1f }, 3 );
366+
367+ // test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.44444f, 0.33333f, 0.22222f}, 3);
368+ // test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 4);
369+ // test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 0);
370+
352371 test_sampler_queue (10000 , " k" , 10000 , 1 .0f , 1 .0f );
353372 test_sampler_queue (10000 , " k" , 1 , 1 .0f , 1 .0f );
354373 test_sampler_queue (10000 , " p" , 10000 , 1 .0f , 1 .0f );
0 commit comments