123 changes: 31 additions & 92 deletions genlm_backend/tokenization/bytes.py
@@ -46,13 +46,6 @@ def get_byte_vocab(tokenizer):
if hasattr(tokenizer, "sp_model"):
return get_byte_tokens_from_sp(tokenizer)

# Try through token encoding.
try:
return get_byte_tokens_by_encoding_token_strings(tokenizer)
except Exception:
# warnings.warn(f"Could not decode vocabulary through string encoding: {e!r}")
pass

# Try using GPT2 byte decoder.
try:
byte_decoder = _get_default_byte_decoder()
@@ -67,83 +60,25 @@ def get_byte_vocab(tokenizer):
def get_byte_tokens_from_byte_decoder(tokenizer, byte_decoder):
"""Convert tokens to bytes using a byte decoder mapping.

Special tokens are handled by directly encoding their string representation.

Args:
tokenizer: A Hugging Face tokenizer instance
byte_decoder (dict): Dictionary mapping characters to bytes

Returns:
byte_tokens (list[bytes]): List of byte representations for each token
"""
special_tokens_map = {v: k for k, v in tokenizer.get_added_vocab().items()}
byte_tokens = [
bytes([byte_decoder[b] for b in tokenizer.convert_ids_to_tokens(i)])
if i not in special_tokens_map
else special_tokens_map[i].encode()
for i in range(len(tokenizer))
]
return byte_tokens
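
For reference, a minimal sketch of the mapping this function applies, assuming the slow (use_fast=False) GPT-2 tokenizer, which exposes a byte_decoder attribute:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2", use_fast=False)
byte_decoder = tokenizer.byte_decoder  # char -> byte value, inverse of GPT-2's byte encoder
token = tokenizer.convert_ids_to_tokens(tokenizer.encode(" hello")[0])  # 'Ġhello'
assert bytes([byte_decoder[c] for c in token]) == b" hello"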


def get_byte_tokens_by_encoding_token_strings(tokenizer):
"""Convert tokens to bytes by encoding token strings directly.

This function attempts to convert each token in the vocabulary to its byte representation
by directly encoding the token strings. It handles special tokens separately and has
multiple fallback strategies for encoding regular tokens:

1. For special tokens, uses the string representation from the tokenizer's added vocab
2. For regular tokens:
a. If the token is already bytes, uses it directly
b. If the token is a string and the tokenizer has convert_tokens_to_string:
- Converts single token to string
- Verifies roundtrip encoding matches original token ID
- Falls back to byte decoder if roundtrip fails
c. If the token is a string without convert_tokens_to_string:
- Directly encodes the token string

Args:
tokenizer: A Hugging Face tokenizer instance.

Returns:
byte_tokens (list[bytes]): List of byte representations for each token in the vocabulary.

Raises:
ValueError: If token encoding fails (roundtrip produces multiple tokens), or if
a token has an unexpected type (not str or bytes).
"""
byte_tokens = [b""] * len(tokenizer)
special_tokens_map = {
id: token for token, id in tokenizer.get_added_vocab().items()
}
byte_encoder = _bytes_to_unicode()
byte_decoder = {v: k for k, v in byte_encoder.items()}

for i in range(len(tokenizer)):
if i in special_tokens_map:
byte_coded = special_tokens_map[i].encode()
else:
token = tokenizer.convert_ids_to_tokens(i)
if isinstance(token, bytes):
byte_coded = token
elif isinstance(token, str):
if hasattr(tokenizer, "convert_tokens_to_string"):
token_str = tokenizer.convert_tokens_to_string([token])
encoded_str = tokenizer.encode(token_str)
if len(encoded_str) != 1:
raise ValueError(
f"Round-trip encoding of tokens [{token}] failed! Got {encoded_str}"
)
roundtrip_id = encoded_str[0]
if roundtrip_id == i:
byte_coded = token_str.encode()
else:
byte_coded = bytes([byte_decoder[c] for c in token])
else:
byte_coded = token.encode()
else:
raise ValueError(f"Unexpected token type: {type(token)}")
byte_tokens[i] = byte_coded

return byte_tokens
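
In isolation, the roundtrip check this deleted function relied on looks roughly like this (a sketch, not part of the diff):

def roundtrip_ok(tokenizer, token_id):
    # A decoded token string is trusted only if encoding it reproduces the same single id.
    token = tokenizer.convert_ids_to_tokens(token_id)
    token_str = tokenizer.convert_tokens_to_string([token])
    encoded = tokenizer.encode(token_str)
    return len(encoded) == 1 and encoded[0] == token_id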


def get_byte_tokens_from_sp(tokenizer):
"""Convert tokens to their byte representations using a SentencePiece model.

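The collapsed body follows standard SentencePiece conventions; roughly (a sketch, assuming a tokenizer that exposes sp_model):

piece = tokenizer.sp_model.id_to_piece(token_id)
if piece.startswith("<0x") and piece.endswith(">"):
    byte_token = bytes([int(piece[3:-1], 16)])  # raw-byte pieces such as "<0x0A>"
else:
    byte_token = piece.replace("\u2581", " ").encode()  # '▁' marks a leading space
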
@@ -195,7 +130,8 @@ def check_byte_decoder(tokenizer, byte_decoder):


def _check_byte_decoder_has_all_bytes(tokenizer, byte_decoder):
"""Verify byte decoder contains mappings for all bytes in vocabulary.
"""Verify byte decoder contains mappings for all bytes in vocabulary,
excluding special tokens.

Args:
tokenizer: A Hugging Face tokenizer instance
@@ -204,8 +140,11 @@ def _check_byte_decoder_has_all_bytes(tokenizer, byte_decoder):
Raises:
ByteDecoderError: If byte decoder is missing required bytes
"""
special_tokens = tokenizer.get_added_vocab().keys()
all_bytes = set()
for x in tokenizer.get_vocab().keys():
if x in special_tokens:
continue
for y in x:
all_bytes.add(y)
if not set(byte_decoder.keys()) >= all_bytes:
@@ -227,7 +166,7 @@ def _check_complex_roundtrip(tokenizer, byte_decoder):
s = "’•¶∂ƒ˙∆£Ħ爨ൠᅘ∰፨"
reconstructed = b""
try:
input_ids = tokenizer(s)["input_ids"]
input_ids = tokenizer(s, add_special_tokens=False)["input_ids"]
for i in input_ids:
nxt_bytes = []
token_str = tokenizer.convert_ids_to_tokens(i)
@@ -252,26 +191,26 @@
)
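
A note on the add_special_tokens=False change above: tokenizers that prepend special tokens would otherwise inject extra ids whose bytes corrupt the reconstruction compared against s.encode(). Illustrative sketch, assuming a Llama-3-style tokenizer:

tokenizer("hi")["input_ids"]                            # may start with a BOS id such as <|begin_of_text|>
tokenizer("hi", add_special_tokens=False)["input_ids"]  # only the ids for "hi"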


def _bytes_to_unicode():
"""Create a mapping from bytes to Unicode characters.

Returns:
(dict): Mapping from byte values to Unicode characters
"""
bs = (
list(range(ord("!"), ord("~") + 1))
+ list(range(ord("¡"), ord("¬") + 1))
+ list(range(ord("®"), ord("ÿ") + 1))
)
cs = bs[:]
n = 0
for b in range(256):
if b not in bs:
bs.append(b)
cs.append(256 + n)
n += 1
cs = [chr(n) for n in cs]
return dict(zip(bs, cs))
# def _bytes_to_unicode():
# """Create a mapping from bytes to Unicode characters.
#
# Returns:
# (dict): Mapping from byte values to Unicode characters
# """
# bs = (
# list(range(ord("!"), ord("~") + 1))
# + list(range(ord("¡"), ord("¬") + 1))
# + list(range(ord("®"), ord("ÿ") + 1))
# )
# cs = bs[:]
# n = 0
# for b in range(256):
# if b not in bs:
# bs.append(b)
# cs.append(256 + n)
# n += 1
# cs = [chr(n) for n in cs]
# return dict(zip(bs, cs))
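
If this helper is ever revived, its defining property is easy to verify: the table is a bijection over all 256 byte values, which is exactly what lets bytes([byte_decoder[c] for c in token]) invert it. A quick check (sketch):

enc = _bytes_to_unicode()
assert len(enc) == 256 and len(set(enc.values())) == 256  # 256 distinct printable characters
dec = {v: k for k, v in enc.items()}
assert all(dec[enc[b]] == b for b in range(256))          # fully invertible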


def _get_default_byte_decoder():
3 changes: 1 addition & 2 deletions requirements-dev.txt
@@ -3,6 +3,5 @@ pytest
pytest-asyncio
pytest-benchmark
arsenal @ git+https://github.com/timvieira/arsenal
datasets
viztracer
hypothesis
IPython
124 changes: 54 additions & 70 deletions tests/test_vocabulary.py
@@ -1,10 +1,10 @@
import pytest
from functools import wraps
from datasets import load_dataset
from transformers import AutoTokenizer

from genlm_backend.tokenization import decode_vocab
from genlm_backend.tokenization.vocab import assert_roundtrip_bytes
from hypothesis import given, strategies as st, settings


def skip_if_gated(f):
@@ -18,91 +18,75 @@ def wrapper(*args, **kwargs):
return wrapper
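
The body of skip_if_gated is collapsed above; a plausible sketch of what it does (the actual implementation may differ):

def skip_if_gated(f):
    @wraps(f)
    def wrapper(*args, **kwargs):
        try:
            return f(*args, **kwargs)
        except OSError as e:  # from_pretrained surfaces gated-repo failures as OSError
            if "gated" in str(e).lower():
                pytest.skip(f"gated repository: {e}")
            raise
    return wrapper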


@pytest.fixture
def test_text():
text = "\n".join(load_dataset("wikitext", "wikitext-2-raw-v1")["test"]["text"])
return text[:5000]
tokenizer_cache = {}


@skip_if_gated
def test_gpt2(test_text):
# Uses byte decoder
tokenizer = AutoTokenizer.from_pretrained("gpt2", use_fast=True)
(byte_vocab, str_vocab) = decode_vocab(tokenizer)
assert_roundtrip_bytes(test_text, tokenizer, byte_vocab)

tokenizer = AutoTokenizer.from_pretrained("gpt2", use_fast=False)
(byte_vocab, str_vocab) = decode_vocab(tokenizer)
assert_roundtrip_bytes(test_text, tokenizer, byte_vocab)
def load_tokenizer(name, use_fast):
if (name, use_fast) in tokenizer_cache:
return tokenizer_cache[(name, use_fast)]
tokenizer = AutoTokenizer.from_pretrained(name, use_fast=use_fast)
tokenizer_cache[(name, use_fast)] = tokenizer
return tokenizer
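
Since hypothesis invokes each test many times, the cache keeps from_pretrained to a single call per (name, use_fast) configuration:

tok_a = load_tokenizer("gpt2", True)
tok_b = load_tokenizer("gpt2", True)  # second lookup hits tokenizer_cache
assert tok_a is tok_b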


@skip_if_gated
def test_llama3(test_text):
# Uses GPT2 byte decoder
tokenizer = AutoTokenizer.from_pretrained(
"meta-llama/Meta-Llama-3-8B", use_fast=True
)
(byte_vocab, str_vocab) = decode_vocab(tokenizer)
assert_roundtrip_bytes(test_text, tokenizer, byte_vocab)

tokenizer = AutoTokenizer.from_pretrained(
"meta-llama/Meta-Llama-3-8B", use_fast=False
)
(byte_vocab, str_vocab) = decode_vocab(tokenizer)
assert_roundtrip_bytes(test_text, tokenizer, byte_vocab)
@settings(deadline=None)
@given(text=st.text(min_size=1, max_size=500), is_fast=st.booleans())
def test_gpt2(text, is_fast):
tokenizer = load_tokenizer("gpt2", is_fast)
byte_vocab, _ = decode_vocab(tokenizer)
assert_roundtrip_bytes(text, tokenizer, byte_vocab)
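
assert_roundtrip_bytes lives in genlm_backend.tokenization.vocab; the property it enforces is roughly the following (a sketch, not the actual helper):

def roundtrip(text, tokenizer, byte_vocab):
    ids = tokenizer.encode(text, add_special_tokens=False)
    return b"".join(byte_vocab[i] for i in ids) == text.encode("utf-8")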


@skip_if_gated
def test_codellama(test_text):
# Uses SentencePiece method
tokenizer = AutoTokenizer.from_pretrained(
"codellama/CodeLlama-7b-Instruct-hf", use_fast=True
)
(byte_vocab, str_vocab) = decode_vocab(tokenizer)
assert_roundtrip_bytes(test_text, tokenizer, byte_vocab)

tokenizer = AutoTokenizer.from_pretrained(
"codellama/CodeLlama-7b-Instruct-hf", use_fast=False
)
(byte_vocab, str_vocab) = decode_vocab(tokenizer)
assert_roundtrip_bytes(test_text, tokenizer, byte_vocab)
@settings(deadline=None)
@given(text=st.text(min_size=1, max_size=500), is_fast=st.booleans())
def test_llama3(text, is_fast):
tokenizer = load_tokenizer("meta-llama/Meta-Llama-3-8B", is_fast)
byte_vocab, _ = decode_vocab(tokenizer)
assert_roundtrip_bytes(text, tokenizer, byte_vocab)


@skip_if_gated
def test_gemma(test_text):
# Uses SentencePiece method
tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b", use_fast=True)
(byte_vocab, str_vocab) = decode_vocab(tokenizer)
assert_roundtrip_bytes(test_text, tokenizer, byte_vocab)
@settings(deadline=None)
@given(text=st.text(min_size=1, max_size=500), is_fast=st.booleans())
def test_codellama(text, is_fast):
tokenizer = load_tokenizer("codellama/CodeLlama-7b-Instruct-hf", is_fast)
byte_vocab, _ = decode_vocab(tokenizer)
assert_roundtrip_bytes(text, tokenizer, byte_vocab)

tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b", use_fast=False)
(byte_vocab, str_vocab) = decode_vocab(tokenizer)
assert_roundtrip_bytes(test_text, tokenizer, byte_vocab)

@skip_if_gated
@settings(deadline=None)
@given(text=st.text(min_size=1, max_size=500), is_fast=st.booleans())
def test_gemma(text, is_fast):
tokenizer = load_tokenizer("google/gemma-7b", is_fast)
byte_vocab, _ = decode_vocab(tokenizer)
assert_roundtrip_bytes(text, tokenizer, byte_vocab)


@skip_if_gated
def _test_phi(test_text): # Currently fails.
# Has a byte decoder, but it is missing bytes. Uses GPT2 byte decoder.
tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", use_fast=True)
(byte_vocab, str_vocab) = decode_vocab(tokenizer)
assert_roundtrip_bytes(test_text, tokenizer, byte_vocab)
@settings(deadline=None)
@given(text=st.text(min_size=1, max_size=500), is_fast=st.booleans())
def test_phi(text, is_fast):
tokenizer = load_tokenizer("microsoft/phi-2", is_fast)
byte_vocab, _ = decode_vocab(tokenizer)
assert_roundtrip_bytes(text, tokenizer, byte_vocab)

tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", use_fast=False)
(byte_vocab, str_vocab) = decode_vocab(tokenizer)
assert_roundtrip_bytes(test_text, tokenizer, byte_vocab)

@skip_if_gated
@settings(deadline=None)
@given(text=st.text(min_size=1, max_size=500), is_fast=st.booleans())
def test_mistral(text, is_fast):
tokenizer = load_tokenizer("mistralai/Mistral-7B-Instruct-v0.3", is_fast)
byte_vocab, _ = decode_vocab(tokenizer)
assert_roundtrip_bytes(text, tokenizer, byte_vocab)


@skip_if_gated
def test_mistral(test_text):
# Uses SentencePiece method
tokenizer = AutoTokenizer.from_pretrained(
"mistralai/Mistral-7B-Instruct-v0.3", use_fast=True
)
(byte_vocab, str_vocab) = decode_vocab(tokenizer)
assert_roundtrip_bytes(test_text, tokenizer, byte_vocab)

tokenizer = AutoTokenizer.from_pretrained(
"mistralai/Mistral-7B-Instruct-v0.3", use_fast=False
)
(byte_vocab, str_vocab) = decode_vocab(tokenizer)
assert_roundtrip_bytes(test_text, tokenizer, byte_vocab)
@settings(deadline=None)
@given(text=st.text(min_size=1, max_size=500), is_fast=st.booleans())
def test_deepseek_r1_unsloth(text, is_fast):
tokenizer = load_tokenizer("unsloth/DeepSeek-R1-Distill-Llama-8B", is_fast)
byte_vocab, _ = decode_vocab(tokenizer)
assert_roundtrip_bytes(text, tokenizer, byte_vocab)
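
To reproduce any failure locally (with pytest and hypothesis installed), hypothesis will shrink the failing text to a minimal counterexample:

import pytest
pytest.main(["tests/test_vocabulary.py", "-k", "test_gpt2"])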