
Conversation


@amirai21 amirai21 commented Oct 8, 2025

The JambaModel implementation in convert_hf_to_gguf.py was incorrectly constructing its vocab with the gpt-2 tokenizer logic when no SentencePiece model was present (i.e., the tokenizer.json path). Jamba actually uses a llama tokenizer, not gpt-2.

This change updates the vocab build path to use the llama tokenizer for non-SentencePiece Jamba models. It also includes several small adjustments to the Jamba llama-based tokenizer construction.

No changes are expected for other model types.
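
For illustration, a minimal sketch of the kind of tokenizer.json-based llama vocab build described above, assuming the Hugging Face AutoTokenizer and gguf-py writer APIs; the function name and structure here are illustrative, not the exact code in this PR.

from transformers import AutoTokenizer
import gguf

def build_llama_style_vocab(dir_model: str, gguf_writer: gguf.GGUFWriter) -> None:
    # Illustrative sketch: build the token list from tokenizer.json via AutoTokenizer
    # and register it as a "llama" (SPM-style) tokenizer rather than "gpt-2" (BPE).
    tokenizer = AutoTokenizer.from_pretrained(dir_model)
    reverse_vocab = {tok_id: tok for tok, tok_id in tokenizer.get_vocab().items()}
    added_ids = set(tokenizer.get_added_vocab().values())

    tokens: list[str] = []
    toktypes: list[int] = []
    for tok_id in range(len(reverse_vocab)):  # assumes contiguous token ids
        tokens.append(reverse_vocab[tok_id])
        # Added/special tokens are marked USER_DEFINED, everything else NORMAL.
        toktypes.append(gguf.TokenType.USER_DEFINED if tok_id in added_ids
                        else gguf.TokenType.NORMAL)

    gguf_writer.add_tokenizer_model("llama")  # llama, not gpt-2
    gguf_writer.add_token_list(tokens)
    gguf_writer.add_token_types(toktypes)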

Testing
Verified with a local conversion of a Jamba model to GGUF (tokenizer.json mode) and confirmed that the generated vocab matches the llama tokenizer layout. The SentencePiece-mode GGUF conversion was also verified and remains unaffected.

@amirai21 amirai21 requested a review from CISC as a code owner October 8, 2025 10:20
@github-actions github-actions bot added the python python script changes label Oct 8, 2025
Collaborator

@CISC CISC left a comment

I don't think this works correctly; basically you're recreating an SPM vocab without scores. Have you checked that tokenization is identical to AutoTokenizer? (You can use convert_hf_to_gguf_update.py to generate test files and test with test-tokenizer-0.)
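
For reference, a rough way to produce comparison data for such a check, assuming the transformers AutoTokenizer API; the model path and sample strings below are placeholders, and convert_hf_to_gguf_update.py / test-tokenizer-0 remain the canonical test route.

from transformers import AutoTokenizer

# Placeholder model path and inputs; real test files come from convert_hf_to_gguf_update.py.
tokenizer = AutoTokenizer.from_pretrained("path/to/jamba-model")
samples = ["Hello world", " Hello world", "ied 4 ½ months", "\n\n\t\ttabs and newlines", "🦙 llama"]
for text in samples:
    ids = tokenizer.encode(text, add_special_tokens=False)
    # Diff these reference ids against llama.cpp's tokenization of the converted GGUF.
    print(f"{text!r} -> {ids}")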

Comment on lines 5915 to 5919
def get_vocab_base_pre(self, tokenizer) -> str:
    del tokenizer  # unused

-   return "gpt-2"
+   return "default"

Collaborator

@CISC CISC Oct 8, 2025

Remove this method, only for BPE.

Comment on lines +5934 to +5936
assert max(vocab.values()) < vocab_size

tokpre = self.get_vocab_base_pre(tokenizer)
Collaborator

Suggested change
- assert max(vocab.values()) < vocab_size
- tokpre = self.get_vocab_base_pre(tokenizer)
+ assert max(vocab.values()) < vocab_size

tokens.append(token)

self.gguf_writer.add_tokenizer_model("llama")
self.gguf_writer.add_tokenizer_pre(tokpre)
Collaborator

Suggested change
- self.gguf_writer.add_tokenizer_pre(tokpre)
+ self.gguf_writer.add_tokenizer_pre("default")

Comment on lines +5975 to +5976
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_types(toktypes)
Collaborator

No scores?
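
For context, the "llama" tokenizer model normally carries per-token scores; below is a minimal sketch of emitting placeholder scores next to the token list, assuming the tokens list from the surrounding code (whether placeholder scores are acceptable here is exactly the open question).

# Hypothetical fragment: with no SentencePiece model to read real scores from,
# emit neutral placeholder scores so the "llama" tokenizer metadata is complete.
scores = [0.0] * len(tokens)
self.gguf_writer.add_token_scores(scores)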

self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_types(toktypes)

special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
Collaborator

Suggested change
- special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
+ special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False)

Comment on lines +5953 to +5955
token = tokenizer.decode(
    tokenizer.encode(token, add_special_tokens=False)
)
Collaborator

Suggested change
- token = tokenizer.decode(
-     tokenizer.encode(token, add_special_tokens=False)
- )
+ token = tokenizer.decode(tokenizer.encode(token, add_special_tokens=False))

Comment on lines +5957 to +5959
logger.info(
    f"{repr(previous_token)} is encoded and decoded back to {repr(token)} using AutoTokenizer"
)
Collaborator

Suggested change
- logger.info(
-     f"{repr(previous_token)} is encoded and decoded back to {repr(token)} using AutoTokenizer"
- )
+ logger.info(f"{repr(previous_token)} is encoded and decoded back to {repr(token)} using AutoTokenizer")

Comment on lines +5961 to +5963
if added_tokens_decoder[i].special or self.does_token_look_special(
    token
):
Collaborator

Suggested change
- if added_tokens_decoder[i].special or self.does_token_look_special(
-     token
- ):
+ if added_tokens_decoder[i].special or self.does_token_look_special(token):

Comment on lines +5929 to +5931
tokenizer = AutoTokenizer.from_pretrained(
    self.dir_model, trust_remote_code=True
)
Collaborator

Suggested change
- tokenizer = AutoTokenizer.from_pretrained(
-     self.dir_model, trust_remote_code=True
- )
+ tokenizer = AutoTokenizer.from_pretrained(self.dir_model)
