Skip to content

Commit 2add8c0

Browse files
committed
talk-llama : sync llama.cpp
1 parent 59494c0 commit 2add8c0

File tree

6 files changed: +597 additions, −200 deletions

examples/talk-llama/llama-sampling.cpp

Lines changed: 35 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -1396,19 +1396,15 @@ struct llama_sampler * llama_sampler_init_grammar_impl(const struct llama_vocab
13961396
// penalties
13971397

13981398
struct llama_sampler_penalties {
1399-
const int32_t n_vocab;
1400-
const llama_token special_eos_id;
1401-
const llama_token linefeed_id;
1402-
14031399
const int32_t penalty_last_n;
14041400
const float penalty_repeat;
14051401
const float penalty_freq;
14061402
const float penalty_present;
14071403

1408-
const bool penalize_nl;
1409-
const bool ignore_eos;
1410-
14111404
ring_buffer<llama_token> prev;
1405+
1406+
// a frequency map to count token occurrences
1407+
std::unordered_map<llama_token, int> token_count;
14121408
};
14131409

14141410
static const char * llama_sampler_penalties_name(const struct llama_sampler * /*smpl*/) {
@@ -1421,76 +1417,50 @@ static void llama_sampler_penalties_accept(struct llama_sampler * smpl, llama_to
14211417
return;
14221418
}
14231419

1424-
ctx->prev.push_back(token);
1425-
}
1426-
1427-
static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
1428-
auto * ctx = (llama_sampler_penalties *) smpl->ctx;
1420+
ctx->token_count[token]++;
14291421

1430-
if (ctx->ignore_eos) {
1431-
assert(ctx->special_eos_id >= 0);
1422+
// if the ring buffer is full, remove the oldest token
1423+
if (ctx->prev.size() >= (size_t) ctx->penalty_last_n) {
1424+
const auto old = ctx->prev.front();
14321425

1433-
// optimistically check if the candidates are not yet sorted/shuffled/truncated
1434-
if (cur_p->size > (size_t) ctx->special_eos_id && cur_p->data[ctx->special_eos_id].id == ctx->special_eos_id) {
1435-
cur_p->data[ctx->special_eos_id].logit = -INFINITY;
1436-
} else {
1437-
// else, search for the special EOS token
1438-
for (size_t i = 0; i < cur_p->size; ++i) {
1439-
if (cur_p->data[i].id == ctx->special_eos_id) {
1440-
cur_p->data[i].logit = -INFINITY;
1441-
break;
1442-
}
1443-
}
1426+
ctx->token_count[old]--;
1427+
if (ctx->token_count[old] == 0) {
1428+
ctx->token_count.erase(old);
14441429
}
14451430
}
14461431

1447-
if ((ctx->penalty_last_n == 0) ||
1448-
(ctx->penalty_repeat == 1.0f && ctx->penalty_freq == 0.0f && ctx->penalty_present == 0.0f)) {
1449-
return;
1450-
}
1451-
1452-
bool nl_found = false;
1453-
size_t nl_idx = 0;
1454-
float nl_logit = -INFINITY;
1455-
if (!ctx->penalize_nl) {
1456-
assert(ctx->linefeed_id >= 0);
1432+
ctx->prev.push_back(token);
14571433

1458-
// optimistically check if the candidates are not yet sorted/shuffled/truncated
1459-
if (cur_p->size > (size_t) ctx->linefeed_id && cur_p->data[ctx->linefeed_id].id == ctx->linefeed_id) {
1460-
nl_found = true;
1461-
nl_idx = ctx->linefeed_id;
1462-
nl_logit = cur_p->data[ctx->linefeed_id].logit;
1463-
} else {
1464-
// else, search for the linefeed token
1465-
for (size_t i = 0; i < cur_p->size; ++i) {
1466-
if (cur_p->data[i].id == ctx->linefeed_id) {
1467-
nl_found = true;
1468-
nl_idx = i;
1469-
nl_logit = cur_p->data[i].logit;
1470-
break;
1471-
}
1472-
}
1473-
}
1434+
#if 0
1435+
// sanity check
1436+
std::unordered_map<llama_token, int> tmp;
1437+
for (int i = 0; i < std::min<int>(ctx->penalty_last_n, ctx->prev.size()); ++i) {
1438+
tmp[ctx->prev.rat(i)]++;
14741439
}
14751440

1476-
// Create a frequency map to count occurrences of each token in last_tokens
1477-
// TODO: optimize this by maintaining the token count in the sampler context
1478-
using llama_token_cnt = std::unordered_map<llama_token, int>;
1479-
llama_token_cnt token_count;
1441+
assert(ctx->token_count == tmp);
1442+
#endif
1443+
}
1444+
1445+
static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
1446+
auto * ctx = (llama_sampler_penalties *) smpl->ctx;
14801447

1481-
for (int i = 0; i < std::min<int>(ctx->penalty_last_n, ctx->prev.size()); ++i) {
1482-
token_count[ctx->prev.rat(i)]++;
1448+
if ((ctx->penalty_last_n == 0) ||
1449+
(ctx->penalty_repeat == 1.0f && ctx->penalty_freq == 0.0f && ctx->penalty_present == 0.0f)) {
1450+
return;
14831451
}
14841452

14851453
// Apply frequency and presence penalties to the cur_p
14861454
for (size_t i = 0; i < cur_p->size; ++i) {
1487-
const auto token_iter = token_count.find(cur_p->data[i].id);
1488-
if (token_iter == token_count.end()) {
1455+
const auto token_iter = ctx->token_count.find(cur_p->data[i].id);
1456+
if (token_iter == ctx->token_count.end()) {
14891457
continue;
14901458
}
14911459

14921460
const int count = token_iter->second;
14931461

1462+
assert(count > 0 && count <= ctx->penalty_last_n);
1463+
14941464
// The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong.
14951465
// This is common fix for this problem, which is to multiply by the penalty instead of dividing.
14961466
if (cur_p->data[i].logit <= 0) {
@@ -1503,30 +1473,21 @@ static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_tok
15031473
}
15041474

15051475
cur_p->sorted = false;
1506-
1507-
if (!ctx->penalize_nl && nl_found) {
1508-
// restore the logit of the newline token if it was penalized
1509-
cur_p->data[nl_idx].logit = nl_logit;
1510-
}
15111476
}
15121477

15131478
static void llama_sampler_penalties_reset(struct llama_sampler * smpl) {
15141479
auto * ctx = (llama_sampler_penalties *) smpl->ctx;
15151480
ctx->prev.clear();
1481+
ctx->token_count.clear();
15161482
}
15171483

15181484
static struct llama_sampler * llama_sampler_penalties_clone(const struct llama_sampler * smpl) {
15191485
const auto * ctx = (const llama_sampler_penalties *) smpl->ctx;
15201486
auto * result = llama_sampler_init_penalties(
1521-
ctx->n_vocab,
1522-
ctx->special_eos_id,
1523-
ctx->linefeed_id,
15241487
ctx->penalty_last_n,
15251488
ctx->penalty_repeat,
15261489
ctx->penalty_freq,
1527-
ctx->penalty_present,
1528-
ctx->penalize_nl,
1529-
ctx->ignore_eos);
1490+
ctx->penalty_present);
15301491

15311492
// copy the state
15321493
{
@@ -1552,38 +1513,21 @@ static struct llama_sampler_i llama_sampler_penalties_i = {
15521513
};
15531514

15541515
struct llama_sampler * llama_sampler_init_penalties(
1555-
int32_t n_vocab,
1556-
llama_token special_eos_id,
1557-
llama_token linefeed_id,
15581516
int32_t penalty_last_n,
15591517
float penalty_repeat,
15601518
float penalty_freq,
1561-
float penalty_present,
1562-
bool penalize_nl,
1563-
bool ignore_eos) {
1564-
if (linefeed_id == LLAMA_TOKEN_NULL) {
1565-
penalize_nl = true;
1566-
}
1567-
1568-
if (special_eos_id == LLAMA_TOKEN_NULL) {
1569-
ignore_eos = false;
1570-
}
1571-
1519+
float penalty_present) {
15721520
penalty_last_n = std::max(penalty_last_n, 0);
15731521

15741522
return new llama_sampler {
15751523
/* .iface = */ &llama_sampler_penalties_i,
15761524
/* .ctx = */ new llama_sampler_penalties {
1577-
/* .n_vocab = */ n_vocab,
1578-
/* .special_eos_id = */ special_eos_id,
1579-
/* .linefeed_id = */ linefeed_id,
15801525
/* .penalty_last_n = */ penalty_last_n,
15811526
/* .penalty_repeat = */ penalty_repeat,
15821527
/* .penalty_freq = */ penalty_freq,
15831528
/* .penalty_present = */ penalty_present,
1584-
/* .penalize_nl = */ penalize_nl,
1585-
/* .ignore_eos = */ ignore_eos,
15861529
/* .prev = */ ring_buffer<llama_token>(penalty_last_n),
1530+
/* .token_count = */ {},
15871531
},
15881532
};
15891533
}
@@ -1611,7 +1555,8 @@ static void get_overlapping_token_sequences(const llama_vocab & vocab, const std
16111555
if (word.find(str) != std::string::npos) {
16121556
token_sequences.emplace(token_id, std::vector<llama_token>());
16131557
} else {
1614-
size_t word_len = word.size(), str_len = str.size();
1558+
size_t word_len = word.size();
1559+
size_t str_len = str.size();
16151560
size_t pos = -1;
16161561
while ((pos = word.find(str[0], pos + 1)) != std::string::npos) {
16171562
bool match = true;

examples/talk-llama/llama-vocab.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,7 @@ struct llm_tokenizer_bpe : llm_tokenizer {
418418
case LLAMA_VOCAB_PRE_TYPE_SMOLLM:
419419
case LLAMA_VOCAB_PRE_TYPE_CODESHELL:
420420
case LLAMA_VOCAB_PRE_TYPE_EXAONE:
421+
case LLAMA_VOCAB_PRE_TYPE_MINERVA:
421422
regex_exprs = {
422423
"\\p{N}",
423424
"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
@@ -737,7 +738,7 @@ struct llm_tokenizer_wpm_session {
737738
std::vector<std::string> words(1, "");
738739

739740
for (const uint32_t cpt : cpts_nfd) {
740-
const auto flags = unicode_cpt_flags(cpt);
741+
const auto flags = unicode_cpt_flags_from_cpt(cpt);
741742

742743
if (flags.is_whitespace) {
743744
if (words.back().size()) { // finish previous word if any

Comments (0)