@@ -98,7 +98,7 @@ static void test_top_p(const std::vector<float> & probs, const std::vector<float
9898 sampler_tester tester (probs, probs_expected);
9999
100100 DUMP (&tester.cur_p );
101- tester.apply (llama_sampler_init_top_p (p, 1 ));
101+ tester.apply (llama_sampler_init_top_p (p, 0 ));
102102 tester.apply (llama_sampler_init_dist (0 ));
103103 DUMP (&tester.cur_p );
104104
@@ -109,7 +109,7 @@ static void test_min_p(const std::vector<float> & probs, const std::vector<float
109109 sampler_tester tester (probs, probs_expected);
110110
111111 DUMP (&tester.cur_p );
112- tester.apply (llama_sampler_init_min_p (p, 1 ));
112+ tester.apply (llama_sampler_init_min_p (p, 0 ));
113113 tester.apply (llama_sampler_init_dist (0 ));
114114 DUMP (&tester.cur_p );
115115
@@ -130,7 +130,7 @@ static void test_typical(const std::vector<float> & probs, const std::vector<flo
130130 sampler_tester tester (probs, probs_expected);
131131
132132 DUMP (&tester.cur_p );
133- tester.apply (llama_sampler_init_typical (p, 1 ));
133+ tester.apply (llama_sampler_init_typical (p, 0 ));
134134 DUMP (&tester.cur_p );
135135
136136 tester.check ();
@@ -332,6 +332,7 @@ int main(void) {
332332 test_min_p ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {0 .4f /0 .7f , 0 .3f /0 .7f }, 0 .74f );
333333 test_min_p ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {0 .4f /0 .4f }, 0 .76f );
334334 test_min_p ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {0 .4f /0 .4f }, 1 .00f );
335+ test_min_p ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {0 .4f /0 .4f }, 1 .05f );
335336
336337 printf (" XTC should:\n " );
337338 test_xtc ({0 .4f , 0 .3f , 0 .2f , 0 .1f }, {0 .1f }, 0 .99f , 0 .09f );
@@ -341,8 +342,8 @@ int main(void) {
341342 printf (" XTC should not:\n " );
342343 test_xtc ({0 .4f , 0 .3f , 0 .2f , 0 .1f }, {0 .4f , 0 .3f , 0 .2f , 0 .1f }, 0 .99f , 0 .39f );
343344
344- test_typical ({0 .97f , 0 .01f , 0 .01f , 0 .01f }, {0 .97f }, 0 .5f );
345- test_typical ({0 .4f , 0 .2f , 0 .2f , 0 .2f }, {0 .2f , 0 .2f , 0 .2f }, 0 .5f );
345+ test_typical ({0 .97f , 0 .01f , 0 .01f , 0 .01f }, {0 .97f }, 0 .5f );
346+ test_typical ({0 .4f , 0 .2f , 0 .2f , 0 .2f }, {0 .2f , 0 .2f , 0 .2f }, 0 .5f );
346347
347348 test_penalties ({0 .2f , 0 .2f , 0 .2f , 0 .2f , 0 .2f }, {0 }, {0 .25f , 0 .25f , 0 .25f , 0 .25f , 0 }, 50 .0f , 0 .0f , 0 .0f );
348349 test_penalties ({0 .2f , 0 .2f , 0 .2f , 0 .2f , 0 .2f }, {0 , 1 , 2 }, {0 .5f , 0 .5f , 0 , 0 , 0 }, 50 .0f , 0 .0f , 0 .0f );
0 commit comments