@@ -1572,32 +1572,35 @@ void sample_grammar(FileFormat file_format, int32_t n_vocab, llama_token_data_ar
15721572
15731573 std::vector<std::pair<std::vector<uint32_t >, llama_partial_utf8>> candidates_decoded;
15741574 std::vector<llama_grammar_candidate> candidates_grammar;
1575+ std::vector<uint8_t > rejects;
1576+ candidates_decoded.reserve (candidates->size );
1577+ candidates_grammar.reserve (candidates->size );
1578+ rejects.assign (candidates->size , false );
15751579
15761580 for (size_t i = 0 ; i < candidates->size ; ++i) {
15771581 const llama_token id = candidates->data [i].id ;
15781582 const std::string piece = FileFormatTokenizeID (id,file_format);
15791583 bool found_eog = std::find (eog_tokens.begin (), eog_tokens.end (), id) != eog_tokens.end ();
15801584 if (found_eog) {
15811585 if (!allow_eos) {
1582- candidates-> data [i]. logit = -INFINITY ;
1586+ rejects [i] = true ;
15831587 }
15841588 } else if (piece.empty () || piece[0 ] == 0 ) {
1585- candidates-> data [i]. logit = -INFINITY ;
1589+ rejects [i] = true ;
15861590 } else {
15871591 candidates_decoded.push_back (decode_utf8 (piece.c_str (), grammar->partial_utf8 ));
15881592 candidates_grammar.push_back ({ i, candidates_decoded.back ().first .data (), candidates_decoded.back ().second });
15891593 }
15901594 }
15911595
1592- const auto rejects = llama_grammar_reject_candidates (grammar->rules , grammar->stacks , candidates_grammar);
1593- for (const auto & reject : rejects) {
1594- candidates->data [reject.index ].logit = -INFINITY;
1596+ for (auto reject: llama_grammar_reject_candidates (grammar->rules , grammar->stacks , candidates_grammar)) {
1597+ rejects[reject.index ] = true ;
15951598 }
1596-
1599+
15971600 auto first = candidates->data ;
15981601 auto last = first + candidates->size ;
15991602 last = std::remove_if (first, last,
1600- [&](const llama_token_data & tk){ return tk.logit == -INFINITY; });
1603+ [&](const llama_token_data & tk){ return rejects[&tk - first]; }); // tk.logit == -INFINITY; });
16011604 candidates->size = last - first;
16021605}
16031606
0 commit comments