
Add backwards-compatible support for multiple EOS tokens #305

Open
hudson-ai wants to merge 29 commits into guidance-ai:main from hudson-ai:multi_eos

Conversation

@hudson-ai
Contributor

hudson-ai commented Mar 6, 2026

Motivation

Models like Qwen 3/3.5 define multiple EOS token IDs in their GenerationConfig (e.g. [151645, 151643]), but llguidance only supported a single EOS token. Qwen 3 uses <|im_end|> (151645) to end turns in chat mode and <|endoftext|> (151643) as a general end-of-text marker. When llguidance is configured with only one of these (e.g. the tokenizer's default 151645), the other gets masked out. If the model tries to emit the masked EOS, it's forced to pick garbage tokens and enters an infinite repetition loop, never terminating.
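The failure mode can be sketched with a toy greedy sampler (a minimal illustration with made-up logits; the token ids are Qwen 3's, everything else here is hypothetical, not llguidance internals):

```python
# Toy illustration of why masking one of two EOS ids forces garbage output.
EOS_TURN, EOS_TEXT = 151645, 151643  # Qwen 3 <|im_end|>, <|endoftext|>

def sample(logits, allowed):
    # Greedy: highest-logit token among those the mask allows.
    return max(allowed, key=lambda t: logits[t])

# Hypothetical logits: the model strongly wants to emit <|endoftext|>.
logits = {EOS_TURN: 0.1, EOS_TEXT: 9.0, 42: 1.5}

single_eos_allowed = {EOS_TURN, 42}           # EOS_TEXT masked out
multi_eos_allowed = {EOS_TURN, EOS_TEXT, 42}  # both EOS ids allowed

print(sample(logits, single_eos_allowed))  # 42: a garbage token, loop continues
print(sample(logits, multi_eos_allowed))   # 151643: generation can terminate
```

With the single-EOS mask the sampler can never pick the token the model actually prefers, so generation keeps going; with both ids unmasked it terminates normally.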

Closes #253
Related: #304

Changes

  • Add eos_tokens: Vec<TokenId> to TokTrie with accessors, with_eos_tokens() builder, and validation (asserts IDs are within vocab range)
  • Update TokenParser to check the full EOS set for mask computation, token consumption, rollback, and stop detection
  • C API: LlgTokenizerInit is unchanged. New LlgTokenizerInitV2 struct (flat, with struct_size for forward compatibility) + llg_new_tokenizer_v2() function for multi-EOS support
    • struct_size enables forward compatibility: the FFI function takes a raw pointer, reads only struct_size bytes, and zero-fills any new fields the caller's header doesn't know about
    • llg_new_tokenizer_v2() validates EOS token IDs against vocab size and returns an error (not a panic) for out-of-range IDs
  • Python eos_token parameter now accepts int | list[int] across all entry points
  • Add eos_tokens getter property to Python LLTokenizer
  • Update type stubs and all Python helper modules (hf, tiktoken, llamacpp)
  • C sample tests both v1 and v2 APIs end-to-end
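The struct_size mechanism can be sketched in Python with ctypes (field names and layout below are illustrative stand-ins, not the real LlgTokenizerInitV2):

```python
import ctypes
import sys

# Newer library's view of the init struct; `future_field` stands in for a
# field appended after the caller's header was compiled. Illustrative only.
class InitV2(ctypes.Structure):
    _fields_ = [
        ("struct_size", ctypes.c_size_t),
        ("vocab_size", ctypes.c_uint32),
        ("tok_eos", ctypes.c_uint32),
        ("future_field", ctypes.c_uint32),
    ]

def read_init(caller_bytes: bytes) -> InitV2:
    # Zero-fill a full-size local struct, then copy only the caller's
    # struct_size bytes; fields the caller's older header didn't know
    # about read back as zero/default.
    local = InitV2()
    size_field = ctypes.sizeof(ctypes.c_size_t)
    struct_size = int.from_bytes(caller_bytes[:size_field], sys.byteorder)
    n = min(struct_size, ctypes.sizeof(InitV2))
    ctypes.memmove(ctypes.addressof(local), caller_bytes, n)
    return local

# A caller compiled against an older header (no future_field):
class InitV1ish(ctypes.Structure):
    _fields_ = [
        ("struct_size", ctypes.c_size_t),
        ("vocab_size", ctypes.c_uint32),
        ("tok_eos", ctypes.c_uint32),
    ]

old = InitV1ish(struct_size=ctypes.sizeof(InitV1ish),
                vocab_size=152064, tok_eos=151645)
init = read_init(bytes(old))
print(init.vocab_size, init.tok_eos, init.future_field)  # 152064 151645 0
```

The known fields survive the copy and the newer field defaults to zero, which is what lets older callers keep working against a newer library without a v3 struct.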

Usage

Python:

from llguidance.hf import from_tokenizer
# Pass multiple EOS tokens
tok = from_tokenizer(hf_tokenizer, eos_token=[151645, 151643])
# Single int still works as before
tok = from_tokenizer(hf_tokenizer, eos_token=151645)

C (v2 API):

LlgTokenizerInitV2 init = {};
init.struct_size = sizeof(init);
init.vocab_size = vocab_size;
init.tok_eos = 151645;
init.tokenize_fn = my_tokenize_fn;
// ...set other fields...
LlgToken extra_eos[] = {151643};
init.tok_eos_extra = extra_eos;
init.tok_eos_extra_count = 1;
LlgTokenizer *tok = llg_new_tokenizer_v2(&init, err, sizeof(err));

Rust:

let mut byte_tok = ByteTokenizer::from_json_str(&tokenizer_json)?;
byte_tok.set_eos_tokens(&[151645, 151643]);
let tok_env = byte_tok.into_tok_env(None)?;

API compatibility

Python: Fully backwards compatible. eos_token still accepts a single int everywhere. The only additions are that it also accepts list[int], and there's a new eos_tokens property.
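The widened parameter type amounts to a small normalization step; a sketch with a hypothetical helper (not llguidance's actual code):

```python
from typing import List, Optional, Union

def normalize_eos(eos_token: Optional[Union[int, List[int]]]) -> List[int]:
    """Hypothetical helper: widen an int-or-list eos_token to a list."""
    if eos_token is None:
        return []
    if isinstance(eos_token, int):
        return [eos_token]   # existing single-int callers: behavior unchanged
    return list(eos_token)   # new multi-EOS callers

print(normalize_eos(151645))            # [151645]
print(normalize_eos([151645, 151643]))  # [151645, 151643]
print(normalize_eos(None))              # []
```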

C API: Fully backwards compatible. LlgTokenizerInit is identical to its pre-PR layout — zero fields added or removed. llg_new_tokenizer() is unchanged. Multi-EOS requires the new LlgTokenizerInitV2 + llg_new_tokenizer_v2(), which are purely additive.

Rust (published crates — toktrie, toktrie_hf_tokenizers, toktrie_tiktoken): Only additive changes — new methods like with_eos_tokens(), eos_tokens(), set_eos_tokens(). No existing signatures changed.

Rust (python_ext — not published): tokenv_from_llamacpp changed from eos_token: u32 to eos_tokens: &[u32]. This is technically a breaking signature change, but python_ext is only consumed internally to build the Python wheel, so no external Rust consumers are affected.

Known limitations

  • TokRxInfo.tok_eos still holds only the first (primary) EOS token. Code that reads tok_eos directly rather than going through TokTrie::eos_tokens() will only see the primary one.
  • with_info() resets eos_tokens back to vec![info.tok_eos], silently dropping extra EOS tokens. Callers that replace TokRxInfo after setting multi-EOS must re-apply with_eos_tokens().

Models like Qwen 3/3.5 define multiple EOS token IDs in their
GenerationConfig (e.g. [151645, 151643]), but llguidance only
supported a single EOS token. This caused models to enter infinite
loops when they tried to emit an EOS token that was masked out.

Changes:
- Add eos_tokens: Vec<TokenId> to TokTrie with accessors and
  with_eos_tokens() builder
- Update TokenParser to check full EOS set for mask computation,
  token consumption, rollback, and stop detection
- Add tok_eos_extra/tok_eos_extra_count to C API LlgTokenizerInit
- Python eos_token parameter now accepts int | list[int]
- Add eos_tokens getter property to Python LLTokenizer
- Update type stubs and all Python helper modules (hf, tiktoken, llamacpp)

All existing APIs remain unchanged; single EOS usage is unaffected.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Copilot AI left a comment

Pull request overview

Adds multi-EOS-token support across the tokenizer/trie layers and the parser so models that define multiple EOS IDs (e.g., Qwen/GLM) can terminate correctly without masking alternative EOS tokens.

Changes:

  • Extend TokTrie/tokenizer wrappers to carry a list of EOS token IDs (while keeping a primary EOS for compatibility).
  • Update TokenParser to treat any configured EOS token as valid for mask computation and stop detection.
  • Plumb multi-EOS through the C API and Python bindings/helpers (including typing updates).

Reviewed changes

Copilot reviewed 13 out of 13 changed files in this pull request and generated 4 comments.

File Description
toktrie_tiktoken/src/lib.rs Adds a setter to apply multiple EOS tokens to the internal TokTrie.
toktrie_hf_tokenizers/src/lib.rs Tracks extra EOS tokens and applies them when constructing the TokTrie.
toktrie/src/toktree.rs Adds eos_tokens storage + builder/accessor; adds unit tests for new behavior.
parser/src/tokenparser.rs Switches EOS handling from single token to a token set for masking/stop logic.
parser/src/ffi.rs Extends C tokenizer initialization to accept additional EOS token IDs.
parser/llguidance.h Exposes the new C init fields for extra EOS token IDs.
c_sample/c_sample.cpp Documents how to pass multiple EOS tokens via the C API.
python_ext/src/py.rs Accepts eos_token as `int | list[int]`.
python_ext/src/llamatokenizer.rs Updates llama.cpp bridge to accept multiple EOS token IDs and apply them to TokTrie.
python/llguidance/hf.py Updates helper typing/docs to allow eos_token as `int | list[int]`.
python/llguidance/tiktoken.py Updates helper typing/docs to allow eos_token as `int | list[int]`.
python/llguidance/llamacpp.py Updates helper typing/docs to allow eos_token as `int | list[int]`.
python/llguidance/_lib.pyi Updates stubs for new eos_tokens property and widened eos_token parameter type.


@sempervictus

sempervictus commented Mar 6, 2026

@hudson-ai what would the Rust mechanism be for those when creating an llg factory from a tokenizer?

EDIT: also does this mean i should hold off on hacking-together that text_or_eos bit and just pull the next microversion or are release cycles a bit longer around here? vllm.rs is moving quick so i can always undo a hack later

hudson-ai and others added 2 commits March 5, 2026 18:14
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Document that LlgTokenizerInit must be zero-initialized before setting
fields, as new fields may be appended in future versions.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
@hudson-ai
Copy link
Contributor Author

@hudson-ai what would the Rust mechanism be for those when creating an llg factory from a tokenizer?

EDIT: also does this mean i should hold off on hacking-together that text_or_eos bit and just pull the next microversion or are release cycles a bit longer around here? vllm.rs is moving quick so i can always undo a hack later

If you're working from a tokenizer JSON, you'll want to do something like this:

use toktrie_hf_tokenizers::ByteTokenizer;
use llguidance::{ParserFactory, api::TopLevelGrammar};
use llguidance::toktrie::InferenceCapabilities;
use llguidance::earley::SlicedBiasComputer;

let mut byte_tok = ByteTokenizer::from_json_str(&tokenizer_json)?;
// Probably get this from the generation_config.json?
byte_tok.set_eos_tokens(&[151645, 151643]);
let tok_env = byte_tok.into_tok_env(None)?;

let factory = ParserFactory::new(
    &tok_env,
    InferenceCapabilities::default(),
    &SlicedBiasComputer::general_slices(),
)?;

// factory now produces parsers that allow both EOS tokens in masks
let grammar = TopLevelGrammar::from_lark(r#"start: /[a-z]+/"#.to_string());
let parser = factory.create_parser(grammar)?;

If you already have a TokTrie, you can call trie.with_eos_tokens(&[151645, 151643]) instead and create a TokEnv from there, but the ByteTokenizer path above is probably closest to what you're doing.

RE: holding off on the "hack" -- I think that I can reasonably get a release out in the next week or so, but no strong guarantee. Need to await some review on this PR too. But if you don't mind doing and undoing a hack, go for it 😉

sempervictus pushed a commit to sempervictus/vllm.rs that referenced this pull request Mar 6, 2026
Handle runaway model output in "normal grammar" modality masking
possible EOS tokens and producing nonsensical output once the model
has completed its normal tool-calls and chat stream:

-  guidance-ai/llguidance#304
-  guidance-ai/llguidance#305
hudson-ai and others added 5 commits March 6, 2026 10:53
Move tok_eos_extra/tok_eos_extra_count out of LlgTokenizerInit into a
new LlgTokenizerInitV2 struct that embeds the original as its 'base'
field. This keeps LlgTokenizerInit identical to its pre-multi-EOS
layout, avoiding any ABI break for existing C consumers.

Add llg_new_tokenizer_v2() which accepts the v2 struct. The original
llg_new_tokenizer() continues to work unchanged with single-EOS.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
The leading struct_size field (set to sizeof(LlgTokenizerInitV2) by
callers) lets the library detect which fields are present. Future
fields can be appended to the struct without a v3 — callers compiled
against an older header will simply have a smaller struct_size, and
new fields will be treated as zero/default.

llg_new_tokenizer_v2() validates struct_size >= the minimum expected
size and returns an error if it's unset or too small.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Replace the nested 'base: LlgTokenizerInit' member with flat copies of
all fields so C consumers write init.vocab_size instead of
init.base.vocab_size. Since v2 is the recommended struct going forward,
this avoids a permanent ergonomic tax.

Internally, from_init_v2() builds a temporary LlgTokenizerInit to
delegate to from_init(), keeping the code DRY.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Add create_tokenizer_v2() and create_byte_tokenizer_v2() that exercise
LlgTokenizerInitV2 with struct_size, flat fields, and an extra EOS
token. Extract run_constraint_test() helper and run the full constraint
test with both v1 and v2 tokenizers.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Copilot AI left a comment

Pull request overview

Copilot reviewed 13 out of 13 changed files in this pull request and generated 7 comments.



hudson-ai and others added 4 commits March 6, 2026 11:49
with_eos_tokens() now asserts all token IDs are within vocab_size,
preventing out-of-bounds panics during mask computation. This covers
all paths (C API, Python bindings, Rust API).

from_init_v2() now accepts smaller struct_size values from callers
compiled against older headers. Fields beyond what struct_size covers
are treated as zero/default. The minimum accepted size is the base
fields through slices (matching v1 + struct_size).

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Revert struct_size to strict check (require >= sizeof) since the
function takes &LlgTokenizerInitV2 — Rust assumes the full struct is
readable, so accepting smaller sizes would be UB. Update docs to note
struct_size is reserved for future forward compatibility.

Validate EOS token IDs in from_init_v2(), before calling with_eos_tokens().
This gives C callers a graceful error instead of a panic across FFI.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Change llg_new_tokenizer_v2() to take a raw pointer instead of a Rust
reference. The function reads struct_size first, then copies only
min(struct_size, sizeof) bytes into a local zeroed struct. This means
callers compiled against an older (smaller) header genuinely work with
newer library versions — new fields default to zero.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Copilot AI left a comment

Pull request overview

Copilot reviewed 13 out of 13 changed files in this pull request and generated 4 comments.



hudson-ai and others added 3 commits March 6, 2026 13:49
- Fix TokenizerWrapper path in py_new to apply eos_token override
- Add TokTrie::eos_token_set() that includes all EOS tokens
- Fix LLMatcher::eos_token_set() to use all EOS tokens (was singleton)
- Fix LLMatcher::consume_token_inner() to accept any EOS token
- Fix Matcher::compute_mask_or_eos() to use all EOS tokens
- Add Python tests for multi-EOS via TokenizerWrapper mock

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
- Add test_eos_token_set_single and test_eos_token_set_multiple in toktrie
- Add test_multi_eos_mask_when_stopped in sample_parser (Matcher level)
- Simplify Python mock test to only verify TokenizerWrapper override applies

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
hudson-ai and others added 3 commits March 6, 2026 14:07
Add validate_eos_tokens() that raises PyValueError for out-of-range IDs.
Called in all Python paths (py_new, from_tiktoken, from_llamacpp) before
with_eos_tokens/set_eos_tokens to give clean errors instead of panics.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Copilot AI left a comment

Pull request overview

Copilot reviewed 17 out of 17 changed files in this pull request and generated 4 comments.



- Fix min_size check to include full tok_eos field, not just its offset
- Update doc comment for llg_new_tokenizer_v2 accordingly
- Add vocab_size validation to ByteTokenizer::set_eos_token(s)
- Free token_lens/token_bytes allocations in c_sample after use

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Copilot AI left a comment

Pull request overview

Copilot reviewed 17 out of 17 changed files in this pull request and generated 1 comment.



hudson-ai and others added 2 commits March 6, 2026 14:34
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Copilot AI left a comment

Pull request overview

Copilot reviewed 17 out of 17 changed files in this pull request and generated 1 comment.

Comments suppressed due to low confidence (1)

python_ext/src/py.rs:75

  • In the LLTokenizer("byte") path, any provided eos_token override is currently ignored (it always returns ApproximateTokEnv::single_byte_env()). Since the Python API now accepts int | list[int] for eos_token, this should either be applied here (e.g., by cloning the underlying trie and calling with_eos_tokens() before wrapping it in an ApproximateTokEnv) or explicitly rejected with a clear error to avoid silent misconfiguration.
        let tok_env: TokEnv = if let Ok(tokenizer_str) = tokenizer.extract::<String>() {
            if tokenizer_str == "byte" {
                ApproximateTokEnv::single_byte_env()
            } else {


Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
@hudson-ai
Contributor Author

Ok after that tremendously long back and forth... @riedgar-ms would you mind reading the C part of this and giving me a sanity check?

@sempervictus

That multitoken ID grammar is looking pretty elegant right about now 😉

@hudson-ai
Contributor Author

That multitoken ID grammar is looking pretty elegant right about now 😉

Ugh I know. The actual core changes are pretty small -- just lots of API entry-points to be careful of (including "nice" error handling). I've actually considered making this PR more than once over the last few months and have balked every time because I didn't want to make breaking FFI changes that necessitated a major version bump. I think I've got a handle on that now though...

sempervictus pushed a commit to sempervictus/vllm.rs that referenced this pull request Mar 7, 2026
Reliance on single/guessed special tokens only goes so far when we
use constrained outputs because masking-out a potential EOS token
results in infinite generation: we have to account for all possible
candidates in the mask which could normally end generation.

Add SpecialTokens idiomatic extractor as a starting point for this
work and utilize it to feed all EOS tokens to the grammar-building
routines. Add the binary for this library element to examples/ for
@guoqingbao and other developers to have rapid access to what the
SpecialTokens struct actually extracts from any tokenizer.json
provided in ARGV0 or from ./tokenizer.json if none are provided.

Improve XML tool-sled generation. Remaining issue is potential of
XML content within the XML envelope and no ability to mask possibly
infinite strings as anything but infinite due to look-ahead and lazy
regex tricks from interpreted languages not actually compiling to a
finite mask. Use a simple matcher for now, enable env-override by
the user while this gets sorted out (if possible) and critically
enable the grammar generator to honor tool parser override at the
CLI such that `--enforce-parser qwen` produces JSON-constrained
schemas which the parser can then consume.

XML finite masking tracked under:
- guidance-ai/llguidance#306

Multiple EOS token concerns (handled in grammar) under:
- guidance-ai/llguidance#304
- guidance-ai/llguidance#305
Contributor

Why not make this a wrapper for create_tokenizer_v2()? Isn't the code largely the same?

Contributor Author

Under the hood, each function is using a different init struct, and consolidating would effectively remove the "living documentation" (slash test) of the old API. I think that removing the commented-out block is a good idea though.

Thoughts?

hudson-ai and others added 5 commits March 10, 2026 09:16
Replace raw new[]/delete[] allocations for token_lens and token_bytes
with std::vector in both create_tokenizer_v2() and create_tokenizer().
This is exception-safe and avoids manual memory management.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Instead of separate tok_eos + extra_eos_tokens parameters, accept a
single std::vector<uint32_t> where [0] is the primary EOS and any
remaining entries are extra EOS tokens. Cleaner C++ API while still
mapping naturally to the underlying C struct fields.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
The v2 API now has a real working example in create_tokenizer_v2()
above, so the inline commented-out snippet is redundant.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Development

Successfully merging this pull request may close these issues.

[Feature Request] Support Multiple EOS Tokens for Accurate Termination in Structured Output Generation

4 participants