@@ -18,155 +18,165 @@ static void dump(const llama_token_data_array * cur_p) {
1818
1919#define DUMP (__cur_p ) do { printf (" %s:%d (%s)\n " , __FILE__, __LINE__, __func__); dump ((__cur_p)); printf (" -\n " ); } while (0 )
2020
21- #define APPLY (__cnstr, __cur_p ) do { \
22- auto * cnstr = (__cnstr); \
23- llama_sampler_apply (cnstr, (__cur_p)); \
24- llama_sampler_free (cnstr); \
25- } while (0 )
26-
27- #define CUR_P_FROM_PROBS () \
28- const size_t n_vocab = probs.size(); \
29- std::vector<llama_token_data> cur; \
30- cur.reserve(n_vocab); \
31- for (llama_token token_id = 0 ; token_id < (llama_token)n_vocab; token_id++) { \
32- const float logit = logf (probs[token_id]); \
33- cur.emplace_back (llama_token_data{token_id, logit, 0 .0f }); \
34- } \
35- llama_token_data_array cur_p = { cur.data (), cur.size (), -1 , false }
36-
37- static void test_top_k (const std::vector<float > & probs, const std::vector<float > & expected_probs, int k) {
38- CUR_P_FROM_PROBS ();
39-
40- DUMP (&cur_p);
41- APPLY (llama_sampler_init_top_k (k), &cur_p);
42- APPLY (llama_sampler_init_dist (0 ), &cur_p);
43- DUMP (&cur_p);
44-
45- GGML_ASSERT (cur_p.size == expected_probs.size ());
46- for (size_t i = 0 ; i < cur_p.size ; i++) {
47- GGML_ASSERT (fabs (cur_p.data [i].p - expected_probs[i]) < 1e-5 );
21+ struct sampler_tester {
22+ sampler_tester (size_t n_vocab) {
23+ cur.reserve (n_vocab);
24+ for (llama_token token_id = 0 ; token_id < (llama_token)n_vocab; token_id++) {
25+ const float logit = logf (token_id);
26+ cur.emplace_back (llama_token_data{token_id, logit, 0 .0f });
27+ }
28+
29+ cur_p = llama_token_data_array { cur.data (), cur.size (), -1 , false };
4830 }
49- }
5031
51- static void test_top_p (const std::vector<float > & probs, const std::vector<float > & expected_probs, float p) {
52- CUR_P_FROM_PROBS ();
32+ sampler_tester (const std::vector<float > & probs, const std::vector<float > & probs_expected) : probs_expected(probs_expected) {
33+ cur.reserve (probs.size ());
34+ for (llama_token token_id = 0 ; token_id < (llama_token)probs.size (); token_id++) {
35+ const float logit = logf (probs[token_id]);
36+ cur.emplace_back (llama_token_data{token_id, logit, 0 .0f });
37+ }
5338
54- DUMP (&cur_p);
55- APPLY (llama_sampler_init_top_p (p, 1 ), &cur_p);
56- APPLY (llama_sampler_init_dist (0 ), &cur_p);
57- DUMP (&cur_p);
58- DUMP (&cur_p);
39+ cur_p = llama_token_data_array { cur.data (), cur.size (), -1 , false };
40+ }
5941
60- GGML_ASSERT (cur_p. size == expected_probs. size ());
61- for ( size_t i = 0 ; i < cur_p. size ; i++) {
62- GGML_ASSERT ( fabs (cur_p. data [i]. p - expected_probs[i]) < 1e-3 );
42+ void apply (llama_sampler * sampler) {
43+ llama_sampler_apply (sampler, & cur_p);
44+ llama_sampler_free (sampler );
6345 }
46+
47+ void check () {
48+ GGML_ASSERT (cur_p.size == probs_expected.size ());
49+ for (size_t i = 0 ; i < cur_p.size ; i++) {
50+ GGML_ASSERT (fabs (cur_p.data [i].p - probs_expected[i]) < 1e-5 );
51+ }
52+ }
53+
54+ llama_token_data_array cur_p;
55+
56+ private:
57+ const std::vector<float > probs_expected;
58+
59+ std::vector<llama_token_data> cur;
60+ };
61+
62+ static void test_temp (const std::vector<float > & probs, const std::vector<float > & probs_expected, float temp) {
63+ sampler_tester tester (probs, probs_expected);
64+
65+ DUMP (&tester.cur_p );
66+ tester.apply (llama_sampler_init_temp (temp));
67+ tester.apply (llama_sampler_init_dist (0 ));
68+ DUMP (&tester.cur_p );
69+
70+ tester.check ();
6471}
6572
66- static void test_tfs (const std::vector<float > & probs, const std::vector<float > & expected_probs, float z ) {
67- CUR_P_FROM_PROBS ( );
73+ static void test_top_k (const std::vector<float > & probs, const std::vector<float > & probs_expected, int k ) {
74+ sampler_tester tester (probs, probs_expected );
6875
69- DUMP (&cur_p);
70- APPLY (llama_sampler_init_tail_free (z, 1 ), &cur_p);
71- DUMP (&cur_p);
76+ DUMP (&tester.cur_p );
77+ tester.apply (llama_sampler_init_top_k (k));
78+ tester.apply (llama_sampler_init_dist (0 ));
79+ DUMP (&tester.cur_p );
7280
73- GGML_ASSERT (cur_p.size == expected_probs.size ());
74- for (size_t i = 0 ; i < cur_p.size ; i++) {
75- GGML_ASSERT (fabs (cur_p.data [i].p - expected_probs[i]) < 1e-3 );
76- }
81+ tester.check ();
7782}
7883
79- static void test_min_p (const std::vector<float > & probs, const std::vector<float > & expected_probs , float p) {
80- CUR_P_FROM_PROBS ( );
84+ static void test_top_p (const std::vector<float > & probs, const std::vector<float > & probs_expected , float p) {
85+ sampler_tester tester (probs, probs_expected );
8186
82- DUMP (&cur_p);
83- APPLY ( llama_sampler_init_min_p (p, 1 ), &cur_p );
84- APPLY (llama_sampler_init_dist (0 ), &cur_p );
85- DUMP (&cur_p);
87+ DUMP (&tester. cur_p );
88+ tester. apply ( llama_sampler_init_top_p (p, 1 ));
89+ tester. apply (llama_sampler_init_dist (0 ));
90+ DUMP (&tester. cur_p );
8691
87- GGML_ASSERT (cur_p.size == expected_probs.size ());
88- for (size_t i = 0 ; i < cur_p.size ; i++) {
89- GGML_ASSERT (fabs (cur_p.data [i].p - expected_probs[i]) < 1e-3 );
90- }
92+ tester.check ();
9193}
9294
93- static void test_xtc (const std::vector<float > & probs, const std::vector<float > & expected_probs , float p, float t ) {
94- CUR_P_FROM_PROBS ( );
95+ static void test_tfs (const std::vector<float > & probs, const std::vector<float > & probs_expected , float z ) {
96+ sampler_tester tester (probs, probs_expected );
9597
96- DUMP (&cur_p);
97- APPLY ( llama_sampler_init_xtc (p, t, 0 , 0 ), &cur_p );
98- DUMP (&cur_p);
98+ DUMP (&tester. cur_p );
99+ tester. apply ( llama_sampler_init_tail_free (z, 1 ) );
100+ DUMP (&tester. cur_p );
99101
100- GGML_ASSERT (cur_p.size == expected_probs.size ());
101- for (size_t i = 0 ; i < cur_p.size ; i++) {
102- GGML_ASSERT (fabs (cur_p.data [i].p - expected_probs[i]) < 1e-5 );
103- }
102+ tester.check ();
104103}
105104
106- static void test_typical (const std::vector<float > & probs, const std::vector<float > & expected_probs , float p) {
107- CUR_P_FROM_PROBS ( );
105+ static void test_min_p (const std::vector<float > & probs, const std::vector<float > & probs_expected , float p) {
106+ sampler_tester tester (probs, probs_expected );
108107
109- DUMP (&cur_p);
110- APPLY (llama_sampler_init_typical (p, 1 ), &cur_p);
111- DUMP (&cur_p);
108+ DUMP (&tester.cur_p );
109+ tester.apply (llama_sampler_init_min_p (p, 1 ));
110+ tester.apply (llama_sampler_init_dist (0 ));
111+ DUMP (&tester.cur_p );
112112
113- GGML_ASSERT (cur_p.size == expected_probs.size ());
114- for (size_t i = 0 ; i < cur_p.size ; i++) {
115- GGML_ASSERT (fabs (cur_p.data [i].p - expected_probs[i]) < 1e-3 );
116- }
113+ tester.check ();
114+ }
115+
116+ static void test_xtc (const std::vector<float > & probs, const std::vector<float > & probs_expected, float p, float t) {
117+ sampler_tester tester (probs, probs_expected);
118+
119+ DUMP (&tester.cur_p );
120+ tester.apply (llama_sampler_init_xtc (p, t, 0 , 0 ));
121+ DUMP (&tester.cur_p );
122+
123+ tester.check ();
124+ }
125+
126+ static void test_typical (const std::vector<float > & probs, const std::vector<float > & probs_expected, float p) {
127+ sampler_tester tester (probs, probs_expected);
128+
129+ DUMP (&tester.cur_p );
130+ tester.apply (llama_sampler_init_typical (p, 1 ));
131+ DUMP (&tester.cur_p );
132+
133+ tester.check ();
117134}
118135
119136static void test_penalties (
120137 const std::vector<float > & probs, const std::vector<llama_token> & last_tokens,
121- const std::vector<float > & expected_probs , float repeat_penalty, float alpha_frequency, float alpha_presence
138+ const std::vector<float > & probs_expected , float repeat_penalty, float alpha_frequency, float alpha_presence
122139) {
123- GGML_ASSERT (probs.size () == expected_probs .size ());
140+ GGML_ASSERT (probs.size () == probs_expected .size ());
124141
125- CUR_P_FROM_PROBS ( );
142+ sampler_tester tester (probs, probs_expected );
126143
144+ const size_t n_vocab = probs.size ();
127145 auto * sampler = llama_sampler_init_penalties (n_vocab, LLAMA_TOKEN_NULL, LLAMA_TOKEN_NULL, last_tokens.size (), repeat_penalty, alpha_frequency, alpha_presence, false , false );
128146
129147 for (size_t i = 0 ; i < last_tokens.size (); i++) {
130148 llama_sampler_accept (sampler, last_tokens[i]);
131149 }
132150
133- DUMP (&cur_p);
134- APPLY (sampler, &cur_p );
135- APPLY (llama_sampler_init_dist (0 ), &cur_p );
136- DUMP (&cur_p);
151+ DUMP (&tester. cur_p );
152+ tester. apply (sampler);
153+ tester. apply (llama_sampler_init_dist (0 ));
154+ DUMP (&tester. cur_p );
137155
138- GGML_ASSERT (cur_p.size == expected_probs.size ());
139- for (size_t i = 0 ; i < cur_p.size ; i++) {
140- GGML_ASSERT (fabs (cur_p.data [i].p - expected_probs[i]) < 1e-3 );
141- }
156+ tester.check ();
142157}
143158
144159static void test_sampler_queue (const size_t n_vocab, const std::string & samplers_sequence, const int top_k, const float top_p, const float min_p
145160) {
146- std::vector<llama_token_data> cur;
147- cur.reserve (n_vocab);
148- for (llama_token token_id = 0 ; token_id < (llama_token)n_vocab; token_id++) {
149- const float logit = logf (token_id);
150- cur.emplace_back (llama_token_data{token_id, logit, 0 .0f });
151- }
152-
153- llama_token_data_array cur_p = { cur.data (), cur.size (), -1 , false };
161+ sampler_tester tester (n_vocab);
154162
155163 llama_token min_token_id = 0 ;
156164 const llama_token max_token_id = n_vocab-1 ;
157165
158166 for (auto s : samplers_sequence) {
159167 switch (s){
160- case ' k' : APPLY (llama_sampler_init_top_k (top_k), &cur_p ); break ;
168+ case ' k' : tester. apply (llama_sampler_init_top_k (top_k)); break ;
161169 case ' f' : GGML_ABORT (" tail_free test not implemented" );
162170 case ' y' : GGML_ABORT (" typical test not implemented" );
163- case ' p' : APPLY (llama_sampler_init_top_p (top_p, 1 ), &cur_p ); break ;
164- case ' m' : APPLY (llama_sampler_init_min_p (min_p, 1 ), &cur_p ); break ;
171+ case ' p' : tester. apply (llama_sampler_init_top_p (top_p, 1 )); break ;
172+ case ' m' : tester. apply (llama_sampler_init_min_p (min_p, 1 )); break ;
165173 case ' t' : GGML_ABORT (" temperature test not implemented" );
166174 default : GGML_ABORT (" Unknown sampler" );
167175 }
168176
169- APPLY (llama_sampler_init_dist (0 ), &cur_p);
177+ tester.apply (llama_sampler_init_dist (0 ));
178+
179+ auto & cur_p = tester.cur_p ;
170180
171181 const int size = cur_p.size ;
172182
0 commit comments