Skip to content

Commit 9494710

Browse files
authored
fix tokenizer overrides w gemma3 (axolotl-ai-cloud#2488)
* fix tokenizer overrides w gemma3 * fix offline wrapping
1 parent de451f9 commit 9494710

File tree

4 files changed

+48
-4
lines changed

4 files changed

+48
-4
lines changed

src/axolotl/utils/models.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,13 @@ def modify_tokenizer_files(
283283
raise ValueError(
284284
f"Token ID {token_id} not found in added_tokens"
285285
)
286+
if "model" in tokenizer_data and "vocab" in tokenizer_data["model"]:
287+
for token_id, new_value in token_id_mappings.items():
288+
for entry_val, entry_id in tokenizer_data["model"]["vocab"].items():
289+
if entry_id == token_id:
290+
del tokenizer_data["model"]["vocab"][entry_val]
291+
tokenizer_data["model"]["vocab"][new_value] = token_id
292+
break
286293

287294
# Write the updated tokenizer data back
288295
with open(tokenizer_path, "w", encoding="utf-8") as f:

tests/conftest.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
from transformers import AutoTokenizer
2020

2121
from tests.hf_offline_utils import (
22-
disable_hf_offline,
2322
enable_hf_offline,
2423
hf_offline_context,
2524
)
@@ -50,7 +49,6 @@ def wrapper(*args, **kwargs): # pylint: disable=inconsistent-return-statements
5049

5150

5251
@retry_on_request_exceptions(max_retries=3, delay=5)
53-
@disable_hf_offline
5452
def snapshot_download_w_retry(*args, **kwargs):
5553
"""
5654
download a model or dataset from HF Hub, retrying in requests failures. We also try to fetch it from the local
@@ -62,7 +60,8 @@ def snapshot_download_w_retry(*args, **kwargs):
6260
return snapshot_download(*args, **kwargs)
6361
except LocalEntryNotFoundError:
6462
pass
65-
return snapshot_download(*args, **kwargs)
63+
with hf_offline_context(False):
64+
return snapshot_download(*args, **kwargs)
6665

6766

6867
@pytest.fixture(scope="session", autouse=True)
@@ -265,6 +264,16 @@ def download_mistral_7b_model_fixture():
265264
)
266265

267266

267+
@pytest.fixture(scope="session", autouse=True)
268+
def download_gemma3_4b_model_fixture():
269+
# download the tokenizer only
270+
snapshot_download_w_retry(
271+
"mlx-community/gemma-3-4b-it-8bit",
272+
repo_type="model",
273+
allow_patterns=["*token*", "config.json"],
274+
)
275+
276+
268277
@pytest.fixture(scope="session", autouse=True)
269278
def download_gemma_2b_model_fixture():
270279
# download the tokenizer only

tests/hf_offline_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def hf_offline_context(hf_hub_offline):
9595
"""
9696
original_hf_offline = os.getenv("HF_HUB_OFFLINE")
9797
os.environ["HF_HUB_OFFLINE"] = str(hf_hub_offline)
98-
reload_modules(True)
98+
reload_modules(bool(hf_hub_offline))
9999
yield
100100
# Restore the original value of HF_HUB_OFFLINE environment variable
101101
if original_hf_offline is not None:

tests/test_tokenizers.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,34 @@ def test_added_tokens_overrides(self, temp_dir):
110110
assert tokenizer.encode("RANDOM_OVERRIDE_2", add_special_tokens=False) == [
111111
128042
112112
]
113+
assert (
114+
tokenizer.decode([128041, 128042]) == "RANDOM_OVERRIDE_1RANDOM_OVERRIDE_2"
115+
)
116+
117+
@enable_hf_offline
118+
def test_added_tokens_overrides_gemma3(self, temp_dir):
119+
cfg = DictDefault(
120+
{
121+
# use with tokenizer that has reserved_tokens in added_tokens
122+
"tokenizer_config": "mlx-community/gemma-3-4b-it-8bit",
123+
"added_tokens_overrides": {
124+
256001: "RANDOM_OVERRIDE_1",
125+
256002: "RANDOM_OVERRIDE_2",
126+
},
127+
"output_dir": temp_dir,
128+
}
129+
)
130+
131+
tokenizer = load_tokenizer(cfg)
132+
assert tokenizer.encode("RANDOM_OVERRIDE_1", add_special_tokens=False) == [
133+
256001
134+
]
135+
assert tokenizer.encode("RANDOM_OVERRIDE_2", add_special_tokens=False) == [
136+
256002
137+
]
138+
assert (
139+
tokenizer.decode([256001, 256002]) == "RANDOM_OVERRIDE_1RANDOM_OVERRIDE_2"
140+
)
113141

114142
@enable_hf_offline
115143
def test_added_tokens_overrides_with_toolargeid(self, temp_dir):

0 commit comments

Comments (0)