Skip to content

Commit 0180f53

Browse files
committed
more xxxx_penalty options
1 parent af1f95b commit 0180f53

File tree

3 files changed

+133
-41
lines changed

3 files changed

+133
-41
lines changed

src/chat.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -791,9 +791,12 @@ namespace chatllm
791791
bool do_sample;
792792
bool reversed_role;
793793
int top_k;
794+
int penalty_window;
794795
float top_p;
795796
float temperature;
796797
float presence_penalty;
798+
float repeat_penalty;
799+
float frequency_penalty;
797800
float tfs_z;
798801
std::string sampling;
799802
std::string ai_prefix;

src/main.cpp

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,9 @@ struct Args
6666
float top_p = 0.7f;
6767
float temp = 0.7f;
6868
float tfs_z = 0.95f;
69-
float presence_penalty = 1.0f;
69+
float presence_penalty = 0.0f;
70+
float repeat_penalty = 1.0f;
71+
float frequency_penalty = 0.0f;
7072
int num_threads = 0;
7173
bool multi_line = false;
7274
int seed;
@@ -90,6 +92,7 @@ struct Args
9092
bool moe_on_cpu = false;
9193
int batch_size = 4096;
9294
bool detect_thoughts = false;
95+
int penalty_window = 256;
9396
};
9497

9598
#define MULTI_LINE_END_MARKER_W L"\\."
@@ -153,6 +156,7 @@ bool is_same_command_option(const std::string &a, const std::string &b)
153156

154157
void usage(const std::string &prog)
155158
{
159+
Args args;
156160
std::cout << "Usage: " << prog << " [options]\n"
157161
<< "\n"
158162
<< "Basic options:\n"
@@ -176,7 +180,7 @@ void usage(const std::string &prog)
176180
<< " --layer_spec 0:3,1:4 (3 + 3 = 6 layers are selected, layer #1/2 are used twice)\n"
177181
<< " layer structure: 0->1->2->1->2->3\n"
178182
<< " -c, --max_context_length N\n"
179-
<< " max context length (default: 512)\n"
183+
<< " max context length (default: " << args.max_context_length << ")\n"
180184
<< " --extending EXT context extending method (EXT = restart | shift | none)\n"
181185
<< " (default: none if `--load_session` is specified, otherwise restart)\n"
182186
<< " --multi enabled multiple lines of input [*]\n"
@@ -197,18 +201,21 @@ void usage(const std::string &prog)
197201
<< " --rpc_endpoints EP.. RPC endpoints (i.e. servers) for distributed inference (default: empty)\n"
198202
<< " EP1;EP2, where EP ::= host:port\n"
199203
<< " --cache_dtype T cache data type, T ::= f32 | f16 (default: f16)\n"
200-
<< " --batch_size N batch size (default: 4096)\n"
204+
<< " --batch_size N batch size (default: " << args.batch_size << ")\n"
201205
<< " note: trade-off between prompt throughput and memory usage.\n"
202206
<< " --re_quantize Q re-quantize model weights during loading (Q ::= q8_0 | q4_0 | q4_1 | q4_k | ...) (default: no re-quantization)\n"
203207
<< " note: it does not make sense to re-quantize to a larger size.\n"
204208
<< "Sampling options:\n"
205209
<< " --sampling ALG sampling algorithm (ALG = greedy | top_p | tfs) (default: top_p) \n"
206210
<< " where, tfs = Tail Free Sampling\n"
207-
<< " -t, --temp T temperature (default: 0.7) (Note: `-t 0` also sets sampling algorithm to greedy)\n"
208-
<< " --top_k N top-k sampling (default: 20)\n"
209-
<< " --top_p N top-p sampling (default: 0.7)\n"
210-
<< " --tfs_z Z Z param for TFS (default: 0.95)\n"
211-
<< " --presence_penalty N presence repetition penalty (default: 1.0, no penalty)\n"
211+
<< " -t, --temp T temperature (default: " << args.temp << ") (Note: `-t 0` also sets sampling algorithm to greedy)\n"
212+
<< " --top_k N top-k sampling (default: " << args.top_k << ")\n"
213+
<< " --top_p N top-p sampling (default: " << args.top_p << ")\n"
214+
<< " --tfs_z Z Z param for TFS (default: " << args.tfs_z << ")\n"
215+
<< " --repeat_penalty N repetition penalty (default: " << args.repeat_penalty << ", 1.0=no penalty)\n"
216+
<< " --presence_penalty N penalty alpha for presence (default: " << args.presence_penalty << ", 0.0=disabled)\n"
217+
<< " --frequency_penalty N penalty alpha for probability (default: " << args.frequency_penalty << ", 0.0=disabled)\n"
218+
<< " --penalty_window N last N tokens to consider for penalize (default: " << args.penalty_window << ", 0=disable all)\n"
212219
<< " --seed N seed for random generator (default: random)\n"
213220
<< " --beam_size N beam size for generation (default: -1, disabled)\n"
214221
<< " functionality of beam search limited.\n"
@@ -465,6 +472,9 @@ static size_t parse_args(Args &args, const std::vector<std::string> &argv)
465472
handle_para0("--tfs_z", tfs_z, std::stof)
466473
handle_param("--temp", "-t", temp, std::stof)
467474
handle_para0("--presence_penalty", presence_penalty, std::stof)
475+
handle_para0("--repeat_penalty", repeat_penalty, std::stof)
476+
handle_para0("--frequency_penalty", frequency_penalty, std::stof)
477+
handle_para0("--penalty_window", penalty_window, std::stoi)
468478
handle_param("--threads", "-n", num_threads, std::stoi)
469479
handle_para0("--seed", seed, std::stoi)
470480
handle_para0("--test", test_fn, std::string)
@@ -852,7 +862,10 @@ static void run_qa_ranker(Args &args, chatllm::Pipeline &pipeline, TextStreamer
852862
// Declares a chatllm::GenerationConfig named `gen_config`, populated from `args`.
// Note: `args.temp > 0` doubles as the do_sample flag (temp == 0 selects greedy).
#define DEF_GenerationConfig(gen_config, args)                                                    \
    chatllm::GenerationConfig gen_config(args.max_length, args.max_context_length, args.temp > 0, \
                                         args.reversed_role, args.top_k, args.top_p, args.temp,   \
                                         args.num_threads, args.sampling, args.presence_penalty,  \
                                         args.tfs_z);                                             \
    gen_config.set_ai_prefix(args.ai_prefix);                                                     \
    gen_config.dump_dot           = args.dump_dot;                                                \
    gen_config.emb_rank_query_sep = args.emb_rank_query_sep;                                      \
    gen_config.repeat_penalty     = args.repeat_penalty;                                          \
    gen_config.frequency_penalty  = args.frequency_penalty;                                       \
    gen_config.penalty_window     = args.penalty_window;
856869

857870
#define DEF_ExtraArgs(pipe_args, args) \
858871
chatllm::ModelObject::extra_args pipe_args(args.max_length, args.layer_spec, args.moe_on_cpu, args.num_threads, args.batch_size, args.cache_dtype, args.re_quantize);\

src/models.cpp

Lines changed: 108 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -721,20 +721,113 @@ namespace chatllm
721721
}
722722
}
723723

724+
class LogitsPenalty
725+
{
726+
public:
727+
LogitsPenalty()
728+
: repeat_penalty_en(false),
729+
freq_penalty_en(false),
730+
inv_repeat_penalty(0.0f), repeat_penalty(0.0f), freq_penalty(0.0f), presence_penalty(0.0f)
731+
{}
732+
733+
LogitsPenalty(const GenerationConfig &gen_config)
734+
: repeat_penalty_en((gen_config.penalty_window > 0) && (gen_config.repeat_penalty != 1.0f) && (gen_config.repeat_penalty > 0.0f)),
735+
freq_penalty_en((gen_config.penalty_window > 0) && (gen_config.frequency_penalty != 0.0f) || (gen_config.presence_penalty != 0.0f)),
736+
inv_repeat_penalty(repeat_penalty_en ? 1 / gen_config.repeat_penalty : 0.0f),
737+
repeat_penalty(gen_config.repeat_penalty),
738+
freq_penalty(freq_penalty_en ? gen_config.frequency_penalty / gen_config.penalty_window : 0.0f),
739+
presence_penalty(gen_config.presence_penalty)
740+
{
741+
if (gen_config.penalty_window > 0)
742+
{
743+
token_history.resize(gen_config.penalty_window);
744+
}
745+
reset();
746+
}
747+
748+
virtual void skip_this(int token_id)
749+
{
750+
skip_tokens.emplace(token_id);
751+
}
752+
753+
virtual void reset()
754+
{
755+
for (size_t i = 0; i < token_history.size(); i++)
756+
token_history[i] = -1;
757+
hist_write = 0;
758+
memset(token_count.data(), 0, token_count.size() * sizeof(token_count[0]));
759+
}
760+
761+
virtual void accept_choice(int token_id)
762+
{
763+
if (token_history.size() < 1) return;
764+
int id = token_history[hist_write];
765+
if ((0 <= id) && (id < (int)token_count.size()))
766+
token_count[id]--;
767+
token_history[hist_write++] = token_id;
768+
if (hist_write >= token_history.size()) hist_write = 0;
769+
if ((0 <= token_id) && (token_id < (int)token_count.size()))
770+
token_count[token_id]++;
771+
}
772+
773+
virtual void process(float *logits, const int vocab_size)
774+
{
775+
if (token_history.size() < 1) return;
776+
777+
if (vocab_size != (int)token_count.size())
778+
{
779+
token_count.resize(vocab_size);
780+
}
781+
782+
for (int i = 0; i < vocab_size; i++)
783+
{
784+
if (repeat_penalty_en)
785+
{
786+
if (token_count[i] > 0)
787+
logits[i] *= logits[i] > 0 ? inv_repeat_penalty : repeat_penalty;
788+
}
789+
790+
if (freq_penalty_en)
791+
logits[i] -= float(token_count[i]) * freq_penalty + float(token_count[i] > 0) * presence_penalty;
792+
}
793+
}
794+
795+
protected:
796+
const bool repeat_penalty_en;
797+
const bool freq_penalty_en;
798+
const float inv_repeat_penalty;
799+
const float repeat_penalty;
800+
const float freq_penalty;
801+
const float presence_penalty;
802+
std::vector<int> token_history;
803+
std::vector<int> token_count;
804+
size_t hist_write;
805+
std::set<int> skip_tokens;
806+
};
807+
724808
class Sampler
725809
{
726810
public:
727811
static const int ABORT = -1;
812+
Sampler() : penalty() {}
728813

814+
Sampler(const GenerationConfig &gen_config)
815+
: penalty(gen_config)
816+
{}
729817
public:
730818
virtual void seed(int x)
731819
{
732820
gen.seed((unsigned int)x);
733821
}
734822

735-
virtual void reset() {}
823+
virtual void reset()
824+
{
825+
penalty.reset();
826+
}
736827

737828
virtual int sampling(float *logits, const int vocab_size) = 0;
829+
public:
830+
LogitsPenalty penalty;
738831
protected:
739832
std::mt19937 gen;
740833
};
@@ -751,40 +844,26 @@ namespace chatllm
751844
class NonGreedySampler: public Sampler
752845
{
753846
public:
754-
NonGreedySampler(float temperature, float presence_penalty, int top_k)
755-
: inv_temp(0.0f), inv_presence_penalty(0.0f), presence_penalty(presence_penalty), top_k(top_k)
847+
NonGreedySampler(const GenerationConfig &gen_config, float temperature, int top_k)
848+
: Sampler(gen_config),
849+
inv_temp(0.0f), top_k(top_k)
756850
{
757851
temp_en = fabs(temperature - 1.0f) > 1e-5f;
758852
if (temp_en) inv_temp = 1.f / temperature;
759-
760-
presence_penalty_en = fabs(presence_penalty - 1.0f) > 1e-5f;
761-
if (presence_penalty_en) inv_presence_penalty = 1.0f / presence_penalty;
762853
}
763854

764-
void reset() override
765-
{
766-
g.clear();
767-
}
768855

769856
int sampling(float *logits, const int vocab_size) override
770857
{
771-
g.resize(vocab_size, 0);
772-
token_scores.resize(vocab_size);
773-
774858
if (temp_en)
775859
{
776860
for (int i = 0; i < vocab_size; i++)
777861
logits[i] *= inv_temp;
778862
}
779863

780-
if (presence_penalty_en)
781-
{
782-
for (int i = 0; i < vocab_size; i++)
783-
{
784-
if (g[i] > 0)
785-
logits[i] *= logits[i] > 0 ? inv_presence_penalty : presence_penalty;
786-
}
787-
}
864+
penalty.process(logits, vocab_size);
865+
866+
token_scores.resize(vocab_size);
788867

789868
for (int i = 0; i < vocab_size; i++)
790869
{
@@ -813,7 +892,8 @@ namespace chatllm
813892
std::discrete_distribution<> dist(logits, logits + token_scores.size());
814893
int next_token_id = token_scores[dist(gen)].id;
815894

816-
g[next_token_id] += 1;
895+
penalty.accept_choice(next_token_id);
896+
817897
return next_token_id;
818898
}
819899

@@ -846,20 +926,16 @@ namespace chatllm
846926

847927
virtual void do_sampling(float *logits, const int vocab_size) = 0;
848928
bool temp_en;
849-
bool presence_penalty_en;
850929
float inv_temp;
851-
float inv_presence_penalty;
852-
float presence_penalty;
853930
int top_k;
854931
std::vector<TokenIdScore> token_scores;
855-
std::vector<int> g;
856932
};
857933

858934
class TopPSampler : public NonGreedySampler
859935
{
860936
public:
861-
TopPSampler(float temperature, float presence_penalty, int top_k, float top_p)
862-
: NonGreedySampler(temperature, presence_penalty, top_k), top_p(top_p)
937+
TopPSampler(const GenerationConfig &gen_config, float temperature, int top_k, float top_p)
938+
: NonGreedySampler(gen_config, temperature, top_k), top_p(top_p)
863939
{}
864940

865941
protected:
@@ -895,8 +971,8 @@ namespace chatllm
895971
class FreeTailSampler : public NonGreedySampler
896972
{
897973
public:
898-
FreeTailSampler(float temperature, float presence_penalty, int top_k, float z)
899-
: NonGreedySampler(temperature, presence_penalty, top_k), z(z)
974+
FreeTailSampler(const GenerationConfig &gen_config, float temperature, int top_k, float z)
975+
: NonGreedySampler(gen_config, temperature, top_k), z(z)
900976
{}
901977

902978
protected:
@@ -952,9 +1028,9 @@ namespace chatllm
9521028
if (gen_config.do_sample)
9531029
{
9541030
if (gen_config.sampling == "top_p")
955-
r = new TopPSampler(gen_config.temperature, gen_config.presence_penalty, gen_config.top_k, gen_config.top_p);
1031+
r = new TopPSampler(gen_config, gen_config.temperature, gen_config.top_k, gen_config.top_p);
9561032
else if (gen_config.sampling == "tfs")
957-
r = new FreeTailSampler(gen_config.temperature, gen_config.presence_penalty, gen_config.top_k, gen_config.tfs_z);
1033+
r = new FreeTailSampler(gen_config, gen_config.temperature, gen_config.top_k, gen_config.tfs_z);
9581034
else if (gen_config.sampling != "greedy")
9591035
CHATLLM_CHECK(false) << "unknown sampling algorithm: " << gen_config.sampling;
9601036
}

0 commit comments

Comments
 (0)