@@ -721,20 +721,113 @@ namespace chatllm
721721 }
722722 }
723723
724+ class LogitsPenalty
725+ {
726+ public:
727+ LogitsPenalty ()
728+ : repeat_penalty_en(false ),
729+ freq_penalty_en (false ),
730+ inv_repeat_penalty(0 .0f ), repeat_penalty(0 .0f ), freq_penalty(0 .0f ), presence_penalty(0 .0f )
731+ {}
732+
733+ LogitsPenalty (const GenerationConfig &gen_config)
734+ : repeat_penalty_en((gen_config.penalty_window > 0 ) && (gen_config.repeat_penalty != 1 .0f ) && (gen_config.repeat_penalty > 0 .0f )),
735+ freq_penalty_en((gen_config.penalty_window > 0 ) && (gen_config.frequency_penalty != 0 .0f ) || (gen_config.presence_penalty != 0 .0f )),
736+ inv_repeat_penalty(repeat_penalty_en ? 1 / gen_config.repeat_penalty : 0 .0f ),
737+ repeat_penalty(gen_config.repeat_penalty),
738+ freq_penalty(freq_penalty_en ? gen_config.frequency_penalty / gen_config.penalty_window : 0 .0f ),
739+ presence_penalty(gen_config.presence_penalty)
740+ {
741+ if (gen_config.penalty_window > 0 )
742+ {
743+ token_history.resize (gen_config.penalty_window );
744+ }
745+ reset ();
746+ }
747+
748+ virtual void skip_this (int token_id)
749+ {
750+ skip_tokens.emplace (token_id);
751+ }
752+
753+ virtual void reset ()
754+ {
755+ for (size_t i = 0 ; i < token_history.size (); i++)
756+ token_history[i] = -1 ;
757+ hist_write = 0 ;
758+ memset (token_count.data (), 0 , token_count.size () * sizeof (token_count[0 ]));
759+ }
760+
761+ virtual void accept_choice (int token_id)
762+ {
763+ if (token_history.size () < 1 ) return ;
764+ int id = token_history[hist_write];
765+ if ((0 <= id) && (id < (int )token_count.size ()))
766+ token_count[id]--;
767+ token_history[hist_write++] = token_id;
768+ if (hist_write >= token_history.size ()) hist_write = 0 ;
769+ if ((0 <= token_id) && (token_id < (int )token_count.size ()))
770+ token_count[token_id]++;
771+ }
772+
773+ virtual void process (float *logits, const int vocab_size)
774+ {
775+ if (token_history.size () < 1 ) return ;
776+
777+ if (vocab_size != (int )token_count.size ())
778+ {
779+ token_count.resize (vocab_size);
780+ }
781+
782+ for (int i = 0 ; i < vocab_size; i++)
783+ {
784+ if (repeat_penalty_en)
785+ {
786+ if (token_count[i] > 0 )
787+ logits[i] *= logits[i] > 0 ? inv_repeat_penalty : repeat_penalty;
788+ }
789+
790+ if (freq_penalty_en)
791+ logits[i] -= float (token_count[i]) * freq_penalty + float (token_count[i] > 0 ) * presence_penalty;
792+ }
793+ }
794+
795+ protected:
796+ const bool repeat_penalty_en;
797+ const bool freq_penalty_en;
798+ const float inv_repeat_penalty;
799+ const float repeat_penalty;
800+ const float freq_penalty;
801+ const float presence_penalty;
802+ std::vector<int > token_history;
803+ std::vector<int > token_count;
804+ size_t hist_write;
805+ std::set<int > skip_tokens;
806+ };
807+
724808 class Sampler
725809 {
726810 public:
727811 static const int ABORT = -1 ;
812+ Sampler () : penalty() {}
728813
814+ Sampler (const GenerationConfig &gen_config)
815+ : penalty(gen_config)
816+ {}
729817 public:
730818 virtual void seed (int x)
731819 {
732820 gen.seed ((unsigned int )x);
733821 }
734822
735- virtual void reset () {}
823+ virtual void reset ()
824+ {
825+ penalty.reset ();
826+ }
736827
737828 virtual int sampling (float *logits, const int vocab_size) = 0;
829+ public:
830+ LogitsPenalty penalty;
738831 protected:
739832 std::mt19937 gen;
740833 };
@@ -751,40 +844,26 @@ namespace chatllm
751844 class NonGreedySampler : public Sampler
752845 {
753846 public:
754- NonGreedySampler (float temperature, float presence_penalty, int top_k)
755- : inv_temp(0 .0f ), inv_presence_penalty(0 .0f ), presence_penalty(presence_penalty), top_k(top_k)
847+ NonGreedySampler (const GenerationConfig &gen_config, float temperature, int top_k)
848+ : Sampler(gen_config),
849+ inv_temp (0 .0f ), top_k(top_k)
756850 {
757851 temp_en = fabs (temperature - 1 .0f ) > 1e-5f ;
758852 if (temp_en) inv_temp = 1 .f / temperature;
759-
760- presence_penalty_en = fabs (presence_penalty - 1 .0f ) > 1e-5f ;
761- if (presence_penalty_en) inv_presence_penalty = 1 .0f / presence_penalty;
762853 }
763854
764- void reset () override
765- {
766- g.clear ();
767- }
768855
769856 int sampling (float *logits, const int vocab_size) override
770857 {
771- g.resize (vocab_size, 0 );
772- token_scores.resize (vocab_size);
773-
774858 if (temp_en)
775859 {
776860 for (int i = 0 ; i < vocab_size; i++)
777861 logits[i] *= inv_temp;
778862 }
779863
780- if (presence_penalty_en)
781- {
782- for (int i = 0 ; i < vocab_size; i++)
783- {
784- if (g[i] > 0 )
785- logits[i] *= logits[i] > 0 ? inv_presence_penalty : presence_penalty;
786- }
787- }
864+ penalty.process (logits, vocab_size);
865+
866+ token_scores.resize (vocab_size);
788867
789868 for (int i = 0 ; i < vocab_size; i++)
790869 {
@@ -813,7 +892,8 @@ namespace chatllm
813892 std::discrete_distribution<> dist (logits, logits + token_scores.size ());
814893 int next_token_id = token_scores[dist (gen)].id ;
815894
816- g[next_token_id] += 1 ;
895+ penalty.accept_choice (next_token_id);
896+
817897 return next_token_id;
818898 }
819899
@@ -846,20 +926,16 @@ namespace chatllm
846926
847927 virtual void do_sampling (float *logits, const int vocab_size) = 0;
848928 bool temp_en;
849- bool presence_penalty_en;
850929 float inv_temp;
851- float inv_presence_penalty;
852- float presence_penalty;
853930 int top_k;
854931 std::vector<TokenIdScore> token_scores;
855- std::vector<int > g;
856932 };
857933
858934 class TopPSampler : public NonGreedySampler
859935 {
860936 public:
861- TopPSampler (float temperature , float presence_penalty , int top_k, float top_p)
862- : NonGreedySampler(temperature, presence_penalty , top_k), top_p(top_p)
937+ TopPSampler (const GenerationConfig &gen_config , float temperature , int top_k, float top_p)
938+ : NonGreedySampler(gen_config, temperature , top_k), top_p(top_p)
863939 {}
864940
865941 protected:
@@ -895,8 +971,8 @@ namespace chatllm
895971 class FreeTailSampler : public NonGreedySampler
896972 {
897973 public:
898- FreeTailSampler (float temperature , float presence_penalty , int top_k, float z)
899- : NonGreedySampler(temperature, presence_penalty , top_k), z(z)
974+ FreeTailSampler (const GenerationConfig &gen_config , float temperature , int top_k, float z)
975+ : NonGreedySampler(gen_config, temperature , top_k), z(z)
900976 {}
901977
902978 protected:
@@ -952,9 +1028,9 @@ namespace chatllm
9521028 if (gen_config.do_sample )
9531029 {
9541030 if (gen_config.sampling == " top_p" )
955- r = new TopPSampler (gen_config. temperature , gen_config.presence_penalty , gen_config.top_k , gen_config.top_p );
1031+ r = new TopPSampler (gen_config, gen_config.temperature , gen_config.top_k , gen_config.top_p );
9561032 else if (gen_config.sampling == " tfs" )
957- r = new FreeTailSampler (gen_config. temperature , gen_config.presence_penalty , gen_config.top_k , gen_config.tfs_z );
1033+ r = new FreeTailSampler (gen_config, gen_config.temperature , gen_config.top_k , gen_config.tfs_z );
9581034 else if (gen_config.sampling != " greedy" )
9591035 CHATLLM_CHECK (false ) << " unknown sampling algorithm: " << gen_config.sampling ;
9601036 }
0 commit comments