
Commit 263264a

fix: additional special tokens being replaced (#517)
* fix additional special tokens being replaced

Signed-off-by: Dushyant Behl <[email protected]>

* fix lint

Signed-off-by: Dushyant Behl <[email protected]>

---------

Signed-off-by: Dushyant Behl <[email protected]>
1 parent cc7a9e8 commit 263264a
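
The fix passes replace_additional_special_tokens=False to Hugging Face's PreTrainedTokenizer.add_special_tokens, so tokens already registered under additional_special_tokens are kept rather than overwritten. A minimal sketch of the difference, assuming a recent transformers release; the tokenizer name is a placeholder and not part of this commit:

# Minimal sketch: preserve existing additional_special_tokens while adding new ones.
# "org/tokenizer-with-extra-tokens" is a hypothetical checkpoint whose
# special_tokens_map already carries additional_special_tokens.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("org/tokenizer-with-extra-tokens")
before = tokenizer.special_tokens_map.get("additional_special_tokens", [])

# The default call (replace_additional_special_tokens=True) would overwrite the
# existing additional_special_tokens list with just ["<NEW>"].
# This commit switches to the non-replacing form:
tokenizer.add_special_tokens(
    special_tokens_dict={"additional_special_tokens": ["<NEW>"]},
    replace_additional_special_tokens=False,
)

after = tokenizer.special_tokens_map.get("additional_special_tokens", [])
assert set(before).issubset(set(after))  # previously registered tokens survive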

File tree: 3 files changed, 64 additions and 2 deletions

scripts/offline_data_processing.py

Lines changed: 4 additions & 1 deletion
@@ -132,7 +132,10 @@ def get_processed_dataset(

     if special_tokens_dict:
         logger.info("Adding special tokens: %s", special_tokens_dict)
-        tokenizer.add_special_tokens(special_tokens_dict)
+        tokenizer.add_special_tokens(
+            special_tokens_dict=special_tokens_dict,
+            replace_additional_special_tokens=False,
+        )

     # Process data using the provided arguments and tokenizer
     logger.info("Calling process_dataargs to format datasets.")

tests/utils/test_embedding_resize.py

Lines changed: 57 additions & 0 deletions
@@ -19,6 +19,9 @@
 from transformers import AutoModelForCausalLM, AutoTokenizer
 import torch

+# First Party
+from tests.artifacts.testdata import CUSTOM_TOKENIZER_TINYLLAMA
+
 # Local
 from tuning.utils.tokenizer_data_utils import tokenizer_and_embedding_resize

@@ -106,6 +109,60 @@ def test_resize_with_special_tokens():
     assert output is not None


+def test_special_tokens_before_and_after():
+    """Test if additional special tokens added do not replace existing tokens"""
+    input_text = INPUT_TEXT
+    tokenizer = AutoTokenizer.from_pretrained(CUSTOM_TOKENIZER_TINYLLAMA)
+    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
+
+    input_tokenizer_len = len(tokenizer.get_vocab())
+    addn_spl_tokens_before = tokenizer.special_tokens_map.get(
+        "additional_special_tokens"
+    )
+    assert (
+        len(addn_spl_tokens_before) > 0
+    ), "this test needs tokenizer special tokens to not be empty before testing"
+
+    special_tokens_dict = {"sep_token": "<SEP>", "pad_token": "<PAD>"}
+    addn_spl_tokens_added = ["<NotSeenTokenA>", "<NotSeenTokenB>", "<NotSeenTokenC>"]
+    special_tokens_dict["additional_special_tokens"] = addn_spl_tokens_added
+
+    resize_result = tokenizer_and_embedding_resize(
+        special_tokens_dict=special_tokens_dict,
+        tokenizer=tokenizer,
+        model=model,
+        multiple_of=1,
+    )
+
+    output_tokenizer_len = len(tokenizer.get_vocab())
+    addn_spl_tokens_before.extend(addn_spl_tokens_added)
+    expected_addn_special_tokens = addn_spl_tokens_before
+    expected_embedding_size = input_tokenizer_len + len(addn_spl_tokens_added) + 2
+    addn_spl_tokens_after = tokenizer.special_tokens_map.get(
+        "additional_special_tokens"
+    )
+
+    assert "<SEP>" in tokenizer.get_vocab()
+    assert "<PAD>" in tokenizer.get_vocab()
+    assert output_tokenizer_len == expected_embedding_size
+    assert resize_result["num_new_tokens"] == output_tokenizer_len - input_tokenizer_len
+    assert resize_result["new_embedding_size"] == expected_embedding_size
+
+    assert len(addn_spl_tokens_after) == len(
+        expected_addn_special_tokens
+    ), "length of the additional special tokens after must equal length before plus added tokens"
+
+    for tok in expected_addn_special_tokens:
+        assert (
+            tok in addn_spl_tokens_after
+        ), "additional special tokens added are not in tokenizer"
+
+    output = _inference(
+        tokenizer=tokenizer, model=model, input_text=input_text, max_new_tokens=20
+    )
+    assert output is not None
+
+
 def test_no_resize_when_no_special_tokens():
     input_text = INPUT_TEXT
     tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
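
A quick sanity check of the expected_embedding_size arithmetic in the test above: the three unseen additional special tokens plus the new <SEP> and <PAD> tokens account for the "+ 2". The starting vocabulary size below is illustrative, not the real size of the TinyLlama custom tokenizer:

# Illustrative numbers only; mirrors the bookkeeping in the test above.
input_tokenizer_len = 32005  # hypothetical starting vocab size
addn_spl_tokens_added = ["<NotSeenTokenA>", "<NotSeenTokenB>", "<NotSeenTokenC>"]
num_other_new_tokens = 2  # <SEP> and <PAD>
expected_embedding_size = input_tokenizer_len + len(addn_spl_tokens_added) + num_other_new_tokens
assert expected_embedding_size == 32010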

tuning/utils/tokenizer_data_utils.py

Lines changed: 3 additions & 1 deletion
@@ -35,7 +35,9 @@ def tokenizer_and_embedding_resize(
     Return:
         dict: Metadata on number of added tokens
     """
-    num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
+    num_new_tokens = tokenizer.add_special_tokens(
+        special_tokens_dict=special_tokens_dict, replace_additional_special_tokens=False
+    )
     embedding_size = int(multiple_of * math.ceil(len(tokenizer) / multiple_of))
     num_new_tokens = num_new_tokens + embedding_size - len(tokenizer)
     model.resize_token_embeddings(embedding_size)
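
For reference, a hedged usage sketch of the patched helper based on the call pattern in the new test; the model name is a placeholder, and multiple_of=1 means the embedding matrix is not rounded up:

# Sketch only: exercises tokenizer_and_embedding_resize the way the test above does.
from transformers import AutoModelForCausalLM, AutoTokenizer

from tuning.utils.tokenizer_data_utils import tokenizer_and_embedding_resize

tokenizer = AutoTokenizer.from_pretrained("tiny-causal-lm-placeholder")
model = AutoModelForCausalLM.from_pretrained("tiny-causal-lm-placeholder")

result = tokenizer_and_embedding_resize(
    special_tokens_dict={"pad_token": "<PAD>"},
    tokenizer=tokenizer,
    model=model,
    multiple_of=1,  # no rounding of the embedding size
)

# The helper reports how many tokens were added and the resized embedding size;
# with this commit, pre-existing additional_special_tokens are no longer dropped.
print(result["num_new_tokens"], result["new_embedding_size"])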
