We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 37f8c7b commit 114ab63Copy full SHA for 114ab63
common/arg.cpp
@@ -963,7 +963,7 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
963
}
964
).set_sparam());
965
add_opt(llama_arg(
966
- {"--tfs"}, "N",
+ {"--tfs", "--tfs-z"}, "Z",
967
format("tail free sampling, parameter z (default: %.1f, 1.0 = disabled)", (double)params.sparams.tfs_z),
968
[](gpt_params & params, const std::string & value) {
969
params.sparams.tfs_z = std::stof(value);
src/llama-sampling.cpp
@@ -756,20 +756,22 @@ static void llama_sampler_tail_free_apply(struct llama_sampler * smpl, llama_tok
756
757
758
759
+ assert(cur_p->size > 0); // guaranteed earlier
760
+ size_t last_idx = cur_p->size - 1;
761
+
762
float cum_sum = 0.0f;
- size_t last_idx = cur_p->size;
763
for (size_t i = 0; i < second_derivatives.size(); ++i) {
764
cum_sum += second_derivatives[i];
765
766
// Check if the running sum is greater than z or if we have kept at least min_keep tokens
- if (cum_sum > ctx->z && i >= ctx->min_keep) {
767
+ if (cum_sum > ctx->z && (i + 1) >= ctx->min_keep) {
768
last_idx = i;
769
break;
770
771
772
773
// Resize the output vector to keep only the tokens above the tail location
- cur_p->size = last_idx;
774
+ cur_p->size = last_idx + 1;
775
776
777
static struct llama_sampler * llama_sampler_tail_free_clone(const struct llama_sampler * smpl) {
tests/test-sampling.cpp
@@ -271,9 +271,9 @@ int main(void) {
271
test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.4f}, 0.76f);
272
test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.4f}, 1.00f);
273
274
- test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f}, 0.25f);
275
- test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f, 0.25f}, 0.75f);
276
- test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f, 0.25f}, 0.99f);
+ test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f}, 0.25f);
+ test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f, 0.25f}, 0.50f);
+ test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f, 0.25f, 0.20f}, 0.80f);
277
278
test_typical({0.97f, 0.01f, 0.01f, 0.01f}, {0.97f}, 0.5f);
279
test_typical({0.4f, 0.2f, 0.2f, 0.2f}, {0.2f, 0.2f, 0.2f}, 0.5f);
0 commit comments