Skip to content

Commit dbe9ef7

Browse files
authored
Added XTC to test-sampling
1 parent 4c44e3d commit dbe9ef7

File tree

1 file changed

+34
-6
lines changed

1 file changed

+34
-6
lines changed

tests/test-sampling.cpp

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,28 @@ static void test_min_p(const std::vector<float> & probs, const std::vector<float
111111
}
112112
}
113113

114+
static void test_xtc(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p, float t, float t_max) {
115+
const size_t n_vocab = probs.size();
116+
117+
std::vector<llama_token_data> cur;
118+
cur.reserve(n_vocab);
119+
for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
120+
const float logit = logf(probs[token_id]);
121+
cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
122+
}
123+
124+
llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
125+
APPLY(llama_sampler_init_softmax(), &cur_p);
126+
DUMP(&cur_p);
127+
APPLY(llama_sampler_init_xtc(p, t, t_max, 0, 0), &cur_p);
128+
DUMP(&cur_p);
129+
130+
GGML_ASSERT(cur_p.size == expected_probs.size());
131+
for (size_t i = 0; i < cur_p.size; i++) {
132+
GGML_ASSERT(fabs(cur_p.data[i].p - expected_probs[i]) < 1e-5);
133+
}
134+
}
135+
114136
static void test_typical(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
115137
const size_t n_vocab = probs.size();
116138

@@ -279,12 +301,13 @@ static void test_perf() {
279301
data.emplace_back(llama_token_data{i, logit, 0.0f});
280302
}
281303

282-
BENCH(llama_sampler_init_top_k (40), data, 32);
283-
BENCH(llama_sampler_init_top_p (0.8f, 1), data, 32);
284-
BENCH(llama_sampler_init_min_p (0.2f, 1), data, 32);
285-
BENCH(llama_sampler_init_tail_free(0.5f, 1), data, 32);
286-
BENCH(llama_sampler_init_typical (0.5f, 1), data, 32);
287-
BENCH(llama_sampler_init_softmax (), data, 32);
304+
BENCH(llama_sampler_init_top_k (40), data, 32);
305+
BENCH(llama_sampler_init_top_p (0.8f, 1), data, 32);
306+
BENCH(llama_sampler_init_min_p (0.2f, 1), data, 32);
307+
BENCH(llama_sampler_init_tail_free(0.5f, 1), data, 32);
308+
BENCH(llama_sampler_init_typical (0.5f, 1), data, 32);
309+
BENCH(llama_sampler_init_xtc (1.0f, 0.1f, 0.8f, 1, 1), data, 32);
310+
BENCH(llama_sampler_init_softmax (), data, 32);
288311
}
289312

290313
int main(void) {
@@ -309,6 +332,11 @@ int main(void) {
309332
test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.4f}, 0.76f);
310333
test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.4f}, 1.00f);
311334

335+
test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.1f}, 0.99f, 0.10f, 1.00f);
336+
test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.4f, 0.1f}, 0.99f, 0.10f, 0.35f);
337+
test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.4f, 0.3f, 0.1f}, 0.99f, 0.10f, 0.25f);
338+
test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.4f, 0.2f, 0.1f}, 0.99f, 0.20f, 0.35f);
339+
312340
test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f}, 0.25f);
313341
test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f, 0.25f}, 0.75f);
314342
test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f, 0.25f}, 0.99f);

0 commit comments

Comments
 (0)