
Commit 304b179

add EOS token to dataset (#15)
* Adding dataset mapping prep to get EOS token added at end of sequence
* merge main
* minor fix
* fix: encoding
* remove hardcoded padding side
* code cleanup and addition of TODO comments

Signed-off-by: Sukriti-Sharma4 <[email protected]>
Co-authored-by: RAGHU KIRAN GANTI <[email protected]>
1 parent 89d43c8 commit 304b179

File tree

1 file changed (+18, -8 lines)


tuning/sft_trainer.py

Lines changed: 18 additions & 8 deletions
@@ -1,4 +1,4 @@
-from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer, LlamaTokenizerFast, GPTNeoXTokenizerFast
+from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer, LlamaTokenizerFast, GPTNeoXTokenizerFast, GPT2Tokenizer
 import fire
 from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
 import transformers
@@ -77,21 +77,29 @@ def train(
     tokenizer = transformers.AutoTokenizer.from_pretrained(
         model_args.model_name_or_path,
         cache_dir=train_args.cache_dir,
-        padding_side="right",
         use_fast = True
     )
+
+    # TODO: understand if we need to hardcode these here or just use defaults in model
     if isinstance(tokenizer, LlamaTokenizer) or isinstance(tokenizer, LlamaTokenizerFast):
         tokenizer.add_special_tokens({
             "bos_token": "<s>",
             "eos_token": "</s>",
             "unk_token": "<unk>",
             "pad_token": "<pad>",
         })
-    elif isinstance(tokenizer, GPTNeoXTokenizerFast):
+    elif isinstance(tokenizer, GPTNeoXTokenizerFast) or isinstance(tokenizer, GPT2Tokenizer):
         tokenizer.add_special_tokens({
             "pad_token": "<pad>",
         })
-
+
+    """TODO: near term - how response template ids are parsed out needs to be cleaned.
+    The [2:] here applies if response template has \n prefix, it is needed to strip \n, otherwise template is not found.
+    We will create issue to clean this out after we discuss data formats and collators we will support
+    """
+    response_template_ids = tokenizer.encode(data_args.response_template, add_special_tokens=False)[2:]
+    # TODO: This is actually max_seq_length and not model_max_length. we should not override model_max_length
+    # as in current main. We need to change name of this parameter we expose to users.
     model_max_length = min(train_args.model_max_length, tokenizer.model_max_length)
     logger.info(f"Model max length {model_max_length}")
     if train_args.model_max_length > tokenizer.model_max_length:
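Note on the [2:] slice in this hunk: when a response template that starts with "\n" is encoded on its own, some tokenizers emit prefix tokens that never appear when the same template occurs inside a longer training example, so the completion-only collator would fail to find it. Below is a minimal sketch of that effect, assuming a Llama-style SentencePiece tokenizer and an illustrative "\n### Response:" template (the checkpoint name is only an example, and the exact number of prefix tokens to strip depends on the tokenizer):

from transformers import AutoTokenizer

# Illustrative checkpoint; any Llama-family tokenizer shows similar behavior.
tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b", use_fast=True)

response_template = "\n### Response:"

# Encoded standalone, the leading "\n" yields extra prefix token ids that are not
# produced when the template sits in the middle of a longer sequence.
standalone_ids = tokenizer.encode(response_template, add_special_tokens=False)

# Dropping the first two ids approximates the mid-sequence tokenization of the template.
response_template_ids = standalone_ids[2:]
print(standalone_ids, response_template_ids)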
@@ -112,6 +120,8 @@ def train(
         logger.warning("UNK token set to default, missing in tokenizer")
         special_tokens_dict["unk_token"] = configs.DEFAULT_UNK_TOKEN

+    # TODO: lower priority but understand if resizing impacts inference quality and why its needed.
+    # It makes sense if we manipulate tokenizer that we also save it and provide it to inference.
     tokenizer_data_utils.tokenizer_and_embedding_resize(
         special_tokens_dict=special_tokens_dict,
         tokenizer=tokenizer,
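The resizing the TODO above refers to is the usual add-special-tokens-then-resize pattern. Below is a generic sketch of that pattern using standard transformers APIs; the function name is illustrative, the repo's tokenizer_data_utils.tokenizer_and_embedding_resize helper may differ in its details, and the sketch assumes the model exposes separate (untied) input and output embeddings:

def resize_for_special_tokens(special_tokens_dict, tokenizer, model):
    # Register the new special tokens and grow the embedding matrix to match
    # the enlarged vocabulary.
    num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
    model.resize_token_embeddings(len(tokenizer))
    if num_new_tokens > 0:
        input_embeddings = model.get_input_embeddings().weight.data
        output_embeddings = model.get_output_embeddings().weight.data
        # Initialize the new rows to the mean of the existing embeddings, a common
        # heuristic so freshly added tokens start from a reasonable point.
        input_embeddings[-num_new_tokens:] = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
        output_embeddings[-num_new_tokens:] = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)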
@@ -120,7 +130,8 @@ def train(

     # load the data by parsing JSON
     json_dataset = datasets.load_dataset('json', data_files=data_args.data_path)
-    logger.info(f"Dataset length is {len(json_dataset['train'])}")
+    formatted_dataset = json_dataset['train'].map(lambda example : {f"{data_args.dataset_text_field}" : example[f"{data_args.dataset_text_field}"] + tokenizer.eos_token})
+    logger.info(f"Dataset length is {len(formatted_dataset)}")

     aim_callback = get_aimstack_callback()
     callbacks=[aim_callback,PeftSavingCallback()]
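The .map() call in this hunk is the core of the commit: every example gets tokenizer.eos_token appended to its text field before training, so the model sees an explicit end-of-sequence marker after each completion. A self-contained sketch of the same mapping with an in-memory dataset; the field name "output" and the "</s>" string are illustrative stand-ins for data_args.dataset_text_field and tokenizer.eos_token:

from datasets import Dataset

dataset = Dataset.from_dict({"output": ["### Input: hi\n### Response: hello"]})
eos_token = "</s>"  # stand-in for tokenizer.eos_token

# Append EOS to the text field of every example, mirroring the .map() call above.
formatted_dataset = dataset.map(lambda example: {"output": example["output"] + eos_token})
print(formatted_dataset[0]["output"])  # ...### Response: hello</s>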
@@ -138,15 +149,14 @@ def train(
     if data_args.dataset_text_field is None:
         logger.error("Error, dataset_text_field is None, needs to be set for training")
         exit(-1)
-
-    response_template_ids = tokenizer.encode(data_args.response_template, add_special_tokens=False)[2:]
+
     data_collator = DataCollatorForCompletionOnlyLM(response_template_ids, tokenizer=tokenizer, ignore_index=configs.IGNORE_INDEX)
     packing = False

     trainer = SFTTrainer(
         model=model,
         tokenizer=tokenizer,
-        train_dataset=json_dataset['train'],
+        train_dataset=formatted_dataset,
         packing=packing,
         data_collator=data_collator,
         dataset_text_field=data_args.dataset_text_field,
