@@ -72,11 +72,11 @@ static void test_temp(const std::vector<float> & probs, const std::vector<float>
7272 tester.check ();
7373}
7474
75- static void test_temp_ext (const std::vector<float > & probs, const std::vector<float > & probs_expected, float temp, float delta, float exponent) {
75+ static void test_temp_ext (const std::vector<float > & probs, const std::vector<float > & probs_expected, float temp, float delta, float exponent, float smoothing_factor, float smoothing_curve ) {
7676 sampler_tester tester (probs, probs_expected);
7777
7878 DUMP (&tester.cur_p );
79- tester.apply (llama_sampler_init_temp_ext (temp, delta, exponent));
79+ tester.apply (llama_sampler_init_temp_ext (temp, delta, exponent, smoothing_factor, smoothing_curve ));
8080 tester.apply (llama_sampler_init_dist (0 ));
8181 DUMP (&tester.cur_p );
8282
@@ -311,8 +311,11 @@ int main(void) {
311311 test_temp ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {0 .4f , 0 .3f , 0 .2f , 0 .1f }, 1 .0f );
312312 test_temp ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {1 .0f , 0 .0f , 0 .0f , 0 .0f }, 0 .0f );
313313
314- test_temp_ext ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {0 .4f , 0 .3f , 0 .2f , 0 .1f }, 1 .0f , 0 .0f , 1 .0f );
315- test_temp_ext ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {1 .0f , 0 .0f , 0 .0f , 0 .0f }, 0 .0f , 0 .0f , 1 .0f );
314+ test_temp_ext ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {0 .4f , 0 .3f , 0 .2f , 0 .1f }, 1 .0f , 0 .0f , 1 .0f , 0 .0f , 1 .0f );
315+ test_temp_ext ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {1 .0f , 0 .0f , 0 .0f , 0 .0f }, 0 .0f , 0 .0f , 1 .0f , 0 .0f , 1 .0f );
316+
317+ test_temp_ext ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {0 .372382f , 0 .342804f , 0 .230319f , 0 .054495f }, 1 .0f , 0 .0f , 1 .0f , 1 .0f , 1 .0f );
318+ test_temp_ext ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {0 .368339f , 0 .349226f , 0 .245247f , 0 .037188f }, 1 .0f , 0 .0f , 1 .0f , 1 .0f , 2 .0f );
316319
317320 test_top_k ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {1 .0f }, 1 );
318321 test_top_k ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {0 .44444f , 0 .33333f , 0 .22222f }, 3 );
0 commit comments