Skip to content

Commit 114ab63

Browse files
committed
sampling : fix off-by-one in tail-free sampling
ggml-ci
1 parent 37f8c7b commit 114ab63

File tree

3 files changed

+9
-7
lines changed

3 files changed

+9
-7
lines changed

common/arg.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -963,7 +963,7 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
963963
}
964964
).set_sparam());
965965
add_opt(llama_arg(
966-
{"--tfs"}, "N",
966+
{"--tfs", "--tfs-z"}, "Z",
967967
format("tail free sampling, parameter z (default: %.1f, 1.0 = disabled)", (double)params.sparams.tfs_z),
968968
[](gpt_params & params, const std::string & value) {
969969
params.sparams.tfs_z = std::stof(value);

src/llama-sampling.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -756,20 +756,22 @@ static void llama_sampler_tail_free_apply(struct llama_sampler * smpl, llama_tok
756756
}
757757
}
758758

759+
assert(cur_p->size > 0); // guaranteed earlier
760+
size_t last_idx = cur_p->size - 1;
761+
759762
float cum_sum = 0.0f;
760-
size_t last_idx = cur_p->size;
761763
for (size_t i = 0; i < second_derivatives.size(); ++i) {
762764
cum_sum += second_derivatives[i];
763765

764766
// Check if the running sum is greater than z or if we have kept at least min_keep tokens
765-
if (cum_sum > ctx->z && i >= ctx->min_keep) {
767+
if (cum_sum > ctx->z && (i + 1) >= ctx->min_keep) {
766768
last_idx = i;
767769
break;
768770
}
769771
}
770772

771773
// Resize the output vector to keep only the tokens above the tail location
772-
cur_p->size = last_idx;
774+
cur_p->size = last_idx + 1;
773775
}
774776

775777
static struct llama_sampler * llama_sampler_tail_free_clone(const struct llama_sampler * smpl) {

tests/test-sampling.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -271,9 +271,9 @@ int main(void) {
271271
test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.4f}, 0.76f);
272272
test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.4f}, 1.00f);
273273

274-
test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f}, 0.25f);
275-
test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f, 0.25f}, 0.75f);
276-
test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f, 0.25f}, 0.99f);
274+
test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f}, 0.25f);
275+
test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f, 0.25f}, 0.50f);
276+
test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f, 0.25f, 0.20f}, 0.80f);
277277

278278
test_typical({0.97f, 0.01f, 0.01f, 0.01f}, {0.97f}, 0.5f);
279279
test_typical({0.4f, 0.2f, 0.2f, 0.2f}, {0.2f, 0.2f, 0.2f}, 0.5f);

0 commit comments

Comments
 (0)