
Commit 24d7067

talk-llama : sync llama.cpp
1 parent 5089ab2 commit 24d7067

File tree: 3 files changed, +2014 -1910 lines changed

examples/talk-llama/llama-sampling.cpp

Lines changed: 6 additions & 98 deletions
@@ -113,7 +113,7 @@ static void llama_sampler_softmax_impl(llama_token_data_array * cur_p) {
 }
 
 static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k) {
-    // TODO: move bucket sort to separate function so that top_p/tail_free/typical/softmax first is equally fast
+    // TODO: move bucket sort to separate function so that top_p/typical/softmax first is equally fast
     // if (k >= (int32_t)cur_p->size) {
     //     return;
     // }
@@ -733,101 +733,6 @@ struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) {
     };
 }
 
-// tail-free
-
-struct llama_sampler_tail_free {
-    const float z;
-    const size_t min_keep;
-};
-
-static const char * llama_sampler_tail_free_name(const struct llama_sampler * /*smpl*/) {
-    return "tail-free";
-}
-
-static void llama_sampler_tail_free_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
-    const auto * ctx = (llama_sampler_tail_free *) smpl->ctx;
-
-    if (ctx->z >= 1.0f || cur_p->size <= 2) {
-        return;
-    }
-
-    llama_sampler_softmax_impl(cur_p);
-
-    // Compute the first and second derivatives
-    std::vector<float> first_derivatives(cur_p->size - 1);
-    std::vector<float> second_derivatives(cur_p->size - 2);
-
-    for (size_t i = 0; i < first_derivatives.size(); ++i) {
-        first_derivatives[i] = cur_p->data[i].p - cur_p->data[i + 1].p;
-    }
-    for (size_t i = 0; i < second_derivatives.size(); ++i) {
-        second_derivatives[i] = first_derivatives[i] - first_derivatives[i + 1];
-    }
-
-    // Calculate absolute value of second derivatives
-    for (size_t i = 0; i < second_derivatives.size(); ++i) {
-        second_derivatives[i] = std::abs(second_derivatives[i]);
-    }
-
-    // Normalize the second derivatives
-    {
-        const float second_derivatives_sum = std::accumulate(second_derivatives.begin(), second_derivatives.end(), 0.0f);
-
-        if (second_derivatives_sum > 1e-6f) {
-            for (float & value : second_derivatives) {
-                value /= second_derivatives_sum;
-            }
-        } else {
-            for (float & value : second_derivatives) {
-                value = 1.0f / second_derivatives.size();
-            }
-        }
-    }
-
-    float cum_sum = 0.0f;
-    size_t last_idx = cur_p->size;
-    for (size_t i = 0; i < second_derivatives.size(); ++i) {
-        cum_sum += second_derivatives[i];
-
-        // Check if the running sum is greater than z or if we have kept at least min_keep tokens
-        if (cum_sum > ctx->z && i >= ctx->min_keep) {
-            last_idx = i;
-            break;
-        }
-    }
-
-    // Resize the output vector to keep only the tokens above the tail location
-    cur_p->size = last_idx;
-}
-
-static struct llama_sampler * llama_sampler_tail_free_clone(const struct llama_sampler * smpl) {
-    const auto * ctx = (const llama_sampler_tail_free *) smpl->ctx;
-    return llama_sampler_init_tail_free(ctx->z, ctx->min_keep);
-}
-
-static void llama_sampler_tail_free_free(struct llama_sampler * smpl) {
-    delete (llama_sampler_tail_free *) smpl->ctx;
-}
-
-static struct llama_sampler_i llama_sampler_tail_free_i = {
-    /* .name = */ llama_sampler_tail_free_name,
-    /* .accept = */ nullptr,
-    /* .apply = */ llama_sampler_tail_free_apply,
-    /* .reset = */ nullptr,
-    /* .clone = */ llama_sampler_tail_free_clone,
-    /* .free = */ llama_sampler_tail_free_free,
-};
-
-struct llama_sampler * llama_sampler_init_tail_free(float z, size_t min_keep) {
-    return new llama_sampler {
-        /* .iface = */ &llama_sampler_tail_free_i,
-        /* .ctx = */ new llama_sampler_tail_free {
-            /* .z = */ z,
-            /*. min_keep = */ min_keep,
-        },
-    };
-}
-
 // typical
 
 struct llama_sampler_typical {
@@ -1971,8 +1876,11 @@ static void llama_sampler_dry_reset(struct llama_sampler * smpl) {
 static struct llama_sampler * llama_sampler_dry_clone(const struct llama_sampler * smpl) {
     const auto * ctx = (llama_sampler_dry *) smpl->ctx;
 
-    // nullptr is passed as vocab because it is only needed for raw sequence breaker processing, which we have already done and will be copying
-    auto * result = llama_sampler_init_dry(nullptr, ctx->dry_multiplier, ctx->dry_base, ctx->dry_allowed_length, ctx->dry_penalty_last_n, NULL, 0);
+    llama_vocab dummy_vocab;
+
+    // dummy vocab is passed because it is only needed for raw sequence breaker processing, which we have already done and will simply be copying
+    auto * result = llama_sampler_init_dry_impl(dummy_vocab, ctx->total_context_size, ctx->dry_multiplier, ctx->dry_base, ctx->dry_allowed_length, ctx->dry_penalty_last_n, NULL, 0);
+
     // Copy the state, including the processed breakers
     {
         auto * result_ctx = (llama_sampler_dry *) result->ctx;
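For reference, the tail-free sampler deleted in the hunk above computed a cutoff index from the normalized absolute second derivatives of the sorted token probabilities and truncated the candidate list there. Below is a minimal standalone sketch of just that cutoff logic, assuming a plain vector of probabilities already sorted in descending order; the function and variable names are illustrative only and are not part of llama.cpp.

// Sketch of the removed tail-free cutoff (illustrative, not llama.cpp API).
#include <cmath>
#include <cstdio>
#include <numeric>
#include <vector>

// Returns how many of the sorted probabilities to keep for parameters z and min_keep.
static size_t tail_free_cutoff(const std::vector<float> & probs, float z, size_t min_keep) {
    if (z >= 1.0f || probs.size() <= 2) {
        return probs.size();
    }

    // Absolute second derivative of the sorted probability curve.
    std::vector<float> d2(probs.size() - 2);
    for (size_t i = 0; i < d2.size(); ++i) {
        const float d1_a = probs[i]     - probs[i + 1];
        const float d1_b = probs[i + 1] - probs[i + 2];
        d2[i] = std::abs(d1_a - d1_b);
    }

    // Normalize so the values form a distribution over positions.
    const float sum = std::accumulate(d2.begin(), d2.end(), 0.0f);
    for (float & v : d2) {
        v = sum > 1e-6f ? v / sum : 1.0f / d2.size();
    }

    // Keep tokens until the cumulative mass exceeds z, keeping at least min_keep.
    float cum = 0.0f;
    for (size_t i = 0; i < d2.size(); ++i) {
        cum += d2[i];
        if (cum > z && i >= min_keep) {
            return i; // truncate here
        }
    }
    return probs.size();
}

int main() {
    const std::vector<float> probs = {0.5f, 0.2f, 0.1f, 0.08f, 0.06f, 0.04f, 0.02f};
    printf("keep %zu of %zu tokens\n", tail_free_cutoff(probs, 0.95f, 1), probs.size());
}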

0 commit comments
