Skip to content

Commit f07434f

Browse files
authored
streamline grammar sampler to speed up generation while using heavy grammar (LostRuins#1606)
1 parent ab29be5 commit f07434f

File tree

1 file changed

+10
-7
lines changed

1 file changed

+10
-7
lines changed

gpttype_adapter.cpp

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1572,32 +1572,35 @@ void sample_grammar(FileFormat file_format, int32_t n_vocab, llama_token_data_ar
15721572

15731573
std::vector<std::pair<std::vector<uint32_t>, llama_partial_utf8>> candidates_decoded;
15741574
std::vector<llama_grammar_candidate> candidates_grammar;
1575+
std::vector<uint8_t> rejects;
1576+
candidates_decoded.reserve(candidates->size);
1577+
candidates_grammar.reserve(candidates->size);
1578+
rejects.assign(candidates->size, false);
15751579

15761580
for (size_t i = 0; i < candidates->size; ++i) {
15771581
const llama_token id = candidates->data[i].id;
15781582
const std::string piece = FileFormatTokenizeID(id,file_format);
15791583
bool found_eog = std::find(eog_tokens.begin(), eog_tokens.end(), id) != eog_tokens.end();
15801584
if (found_eog) {
15811585
if (!allow_eos) {
1582-
candidates->data[i].logit = -INFINITY;
1586+
rejects[i] = true;
15831587
}
15841588
} else if (piece.empty() || piece[0] == 0) {
1585-
candidates->data[i].logit = -INFINITY;
1589+
rejects[i] = true;
15861590
} else {
15871591
candidates_decoded.push_back(decode_utf8(piece.c_str(), grammar->partial_utf8));
15881592
candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second });
15891593
}
15901594
}
15911595

1592-
const auto rejects = llama_grammar_reject_candidates(grammar->rules, grammar->stacks, candidates_grammar);
1593-
for (const auto & reject : rejects) {
1594-
candidates->data[reject.index].logit = -INFINITY;
1596+
for (auto reject: llama_grammar_reject_candidates(grammar->rules, grammar->stacks, candidates_grammar)) {
1597+
rejects[reject.index] = true;
15951598
}
1596-
1599+
15971600
auto first = candidates->data;
15981601
auto last = first + candidates->size;
15991602
last = std::remove_if(first, last,
1600-
[&](const llama_token_data & tk){ return tk.logit == -INFINITY; });
1603+
[&](const llama_token_data & tk){ return rejects[&tk - first]; }); // tk.logit == -INFINITY; });
16011604
candidates->size = last - first;
16021605
}
16031606

0 commit comments

Comments (0)