@@ -113,7 +113,7 @@ static void llama_sampler_softmax_impl(llama_token_data_array * cur_p) {
113113}
114114
115115static void llama_sampler_top_k_impl (llama_token_data_array * cur_p, int32_t k) {
116- // TODO: move bucket sort to separate function so that top_p/tail_free/typical/softmax first is equally fast
116+ // TODO: move bucket sort to separate function so that top_p/typical/softmax first is equally fast
117117 // if (k >= (int32_t)cur_p->size) {
118118 // return;
119119 // }
@@ -733,101 +733,6 @@ struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) {
733733 };
734734}
735735
736- // tail-free
737-
// Parameter block of the tail-free sampler. An instance is allocated by
// llama_sampler_init_tail_free and reached through the sampler's `ctx` pointer.
738- struct llama_sampler_tail_free {
// Threshold compared against the running sum of normalized absolute second
// derivatives in the apply callback; a value >= 1.0 makes apply return early.
739- const float z;
// Lower bound on the index at which the apply callback may cut the candidate list.
740- const size_t min_keep;
741- };
742-
// Name callback installed in the `.name` slot of llama_sampler_tail_free_i.
// Ignores its argument and returns a static string literal.
// NOTE(review): the literal below starts with a space; this looks like an
// extraction artifact rather than an intended name - confirm before relying on it.
743- static const char * llama_sampler_tail_free_name (const struct llama_sampler * /* smpl*/ ) {
744- return " tail-free" ;
745- }
746-
747- static void llama_sampler_tail_free_apply (struct llama_sampler * smpl, llama_token_data_array * cur_p) {
748- const auto * ctx = (llama_sampler_tail_free *) smpl->ctx ;
749-
750- if (ctx->z >= 1 .0f || cur_p->size <= 2 ) {
751- return ;
752- }
753-
754- llama_sampler_softmax_impl (cur_p);
755-
756- // Compute the first and second derivatives
757- std::vector<float > first_derivatives (cur_p->size - 1 );
758- std::vector<float > second_derivatives (cur_p->size - 2 );
759-
760- for (size_t i = 0 ; i < first_derivatives.size (); ++i) {
761- first_derivatives[i] = cur_p->data [i].p - cur_p->data [i + 1 ].p ;
762- }
763- for (size_t i = 0 ; i < second_derivatives.size (); ++i) {
764- second_derivatives[i] = first_derivatives[i] - first_derivatives[i + 1 ];
765- }
766-
767- // Calculate absolute value of second derivatives
768- for (size_t i = 0 ; i < second_derivatives.size (); ++i) {
769- second_derivatives[i] = std::abs (second_derivatives[i]);
770- }
771-
772- // Normalize the second derivatives
773- {
774- const float second_derivatives_sum = std::accumulate (second_derivatives.begin (), second_derivatives.end (), 0 .0f );
775-
776- if (second_derivatives_sum > 1e-6f ) {
777- for (float & value : second_derivatives) {
778- value /= second_derivatives_sum;
779- }
780- } else {
781- for (float & value : second_derivatives) {
782- value = 1 .0f / second_derivatives.size ();
783- }
784- }
785- }
786-
787- float cum_sum = 0 .0f ;
788- size_t last_idx = cur_p->size ;
789- for (size_t i = 0 ; i < second_derivatives.size (); ++i) {
790- cum_sum += second_derivatives[i];
791-
792- // Check if the running sum is greater than z or if we have kept at least min_keep tokens
793- if (cum_sum > ctx->z && i >= ctx->min_keep ) {
794- last_idx = i;
795- break ;
796- }
797- }
798-
799- // Resize the output vector to keep only the tokens above the tail location
800- cur_p->size = last_idx;
801- }
802-
803- static struct llama_sampler * llama_sampler_tail_free_clone (const struct llama_sampler * smpl) {
804- const auto * ctx = (const llama_sampler_tail_free *) smpl->ctx ;
805- return llama_sampler_init_tail_free (ctx->z , ctx->min_keep );
806- }
807-
808- static void llama_sampler_tail_free_free (struct llama_sampler * smpl) {
809- delete (llama_sampler_tail_free *) smpl->ctx ;
810- }
811-
// Callback table for the tail-free sampler; its address is stored as the
// `iface` of every sampler created by llama_sampler_init_tail_free.
// The initializer is positional, so the entries must stay in this order.
// The `accept` and `reset` slots are left null.
812- static struct llama_sampler_i llama_sampler_tail_free_i = {
813- /* .name = */ llama_sampler_tail_free_name,
814- /* .accept = */ nullptr ,
815- /* .apply = */ llama_sampler_tail_free_apply,
816- /* .reset = */ nullptr ,
817- /* .clone = */ llama_sampler_tail_free_clone,
818- /* .free = */ llama_sampler_tail_free_free,
819- };
820-
821- struct llama_sampler * llama_sampler_init_tail_free (float z, size_t min_keep) {
822- return new llama_sampler {
823- /* .iface = */ &llama_sampler_tail_free_i,
824- /* .ctx = */ new llama_sampler_tail_free {
825- /* .z = */ z,
826- /* . min_keep = */ min_keep,
827- },
828- };
829- }
830-
831736// typical
832737
833738struct llama_sampler_typical {
@@ -1971,8 +1876,11 @@ static void llama_sampler_dry_reset(struct llama_sampler * smpl) {
19711876static struct llama_sampler * llama_sampler_dry_clone (const struct llama_sampler * smpl) {
19721877 const auto * ctx = (llama_sampler_dry *) smpl->ctx ;
19731878
1974- // nullptr is passed as vocab because it is only needed for raw sequence breaker processing, which we have already done and will be copying
1975- auto * result = llama_sampler_init_dry (nullptr , ctx->dry_multiplier , ctx->dry_base , ctx->dry_allowed_length , ctx->dry_penalty_last_n , NULL , 0 );
1879+ llama_vocab dummy_vocab;
1880+
1881+ // dummy vocab is passed because it is only needed for raw sequence breaker processing, which we have already done and will simply be copying
1882+ auto * result = llama_sampler_init_dry_impl (dummy_vocab, ctx->total_context_size , ctx->dry_multiplier , ctx->dry_base , ctx->dry_allowed_length , ctx->dry_penalty_last_n , NULL , 0 );
1883+
19761884 // Copy the state, including the processed breakers
19771885 {
19781886 auto * result_ctx = (llama_sampler_dry *) result->ctx ;
0 commit comments