@@ -1593,7 +1593,12 @@ void sample_grammar(FileFormat file_format, int32_t n_vocab, llama_token_data_ar
     for (const auto & reject : rejects) {
         candidates->data[reject.index].logit = -INFINITY;
     }
-
+
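+    // compact in place via the erase-remove idiom: std::remove_if shifts the
+    // surviving tokens to the front and returns the new logical end, so
+    // shrinking size discards every candidate the grammar rejected above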
+    auto first = candidates->data;
+    auto last = first + candidates->size;
+    last = std::remove_if(first, last,
+        [&](const llama_token_data & tk) { return tk.logit == -INFINITY; });
+    candidates->size = last - first;
 }

 void sample_guidance(struct llama_context * ctx, struct llama_context * guidance_ctx, int n_vocab, float scale)
@@ -1643,15 +1648,30 @@ const std::vector<samplers> & sampler_order, llama_grammar * grammar, float dyna

     llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };

-    if (grammar != nullptr) {
-        sample_grammar(file_format, n_vocab, &candidates_p, grammar);
-    }
-
     // dry always first as logits cannot be resorted
     sample_dry(n_ctx, dry_penalty_last_n, dry_multiplier, dry_base, dry_allowed_length, dry_sequence_breakers, &candidates_p);
-
+
     // prefilter to top 3k tokens for improved speed
+    bool use_grammar = grammar != nullptr;
+    size_t n_pre_cull = candidates_p.size;
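+    // remember the candidate count before the top-k cull so the grammar pass
+    // below can fall back to the full set if the cull leaves nothing valid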
+
     sample_top_k(&candidates_p, 3000);
+
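+    // grammar filtering now runs after the cull: matching ~3000 candidates
+    // against the grammar is much cheaper than matching the full vocabulary
+    // (it previously ran before sample_dry, over every token)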
+    if (use_grammar) {
+
+        (debugmode == 1 && printf("\nGrammar sampling %zu candidates.\n", candidates_p.size));
+        sample_grammar(file_format, n_vocab, &candidates_p, grammar);
+        (debugmode == 1 && printf("\nGrammar returned %zu candidates.\n", candidates_p.size));
+
+        // if top_k 3000 doesn't contain a valid candidate for this grammar, try again pre-cull
+        if (candidates_p.size <= 0) {
+            candidates_p.size = n_pre_cull;
+            (debugmode == 1 && printf("\nRe-sampling grammar with %zu pre-cull tokens.\n", candidates_p.size));
+            sample_grammar(file_format, n_vocab, &candidates_p, grammar);
+            (debugmode == 1 && printf("\nGrammar returned %zu candidates.\n", candidates_p.size));
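+            // re-apply the cull so downstream samplers still see at most 3000 tokens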
+            sample_top_k(&candidates_p, 3000);
+        }
+    }

     if (mirostat == 1 || mirostat == 2)
     {
@@ -1745,7 +1765,6 @@ static void grammar_accept_token(FileFormat file_format, int32_t n_vocab, struct
     const auto decoded = decode_utf8(piece.c_str(), grammar->partial_utf8);
     const auto & code_points = decoded.first;
     for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
-        auto prev_stacks = grammar->stacks;
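+        // llama_grammar_accept advances grammar->stacks in place, so the
+        // per-codepoint snapshot of the previous stacks was never read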
         llama_grammar_accept(grammar, *it);
     }
     grammar->partial_utf8 = decoded.second;
@@ -3941,6 +3960,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
     }

     if (grammar != nullptr) {
+        (debugmode == 1 && printf("\nGrammar attempting to accept token...\n"));
         grammar_accept_token(file_format, n_vocab, grammar, id);
     }
