
Commit 82113e4

add chat template data handler
Signed-off-by: Dushyant Behl <[email protected]>
1 parent 1ab11e7 commit 82113e4

File tree

9 files changed (+92, -36 lines)


tests/artifacts/testdata/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -60,8 +60,8 @@
 TWITTER_COMPLAINTS_TOKENIZED_PARQUET = os.path.join(
     PARQUET_DATA_DIR, "twitter_complaints_tokenized_with_maykeye_tinyllama_v0.parquet"
 )
-CHAT_DATA_SINGLE_TURN = os.path.join(JSON_DATA_DIR, "single_turn_chat.jsonl")
-CHAT_DATA_MULTI_TURN = os.path.join(JSON_DATA_DIR, "multi_turn_chat.jsonl")
+CHAT_DATA_SINGLE_TURN = os.path.join(JSONL_DATA_DIR, "single_turn_chat.jsonl")
+CHAT_DATA_MULTI_TURN = os.path.join(JSONL_DATA_DIR, "multi_turn_chat.jsonl")
 EMPTY_DATA = os.path.join(JSON_DATA_DIR, "empty_data.json")
 MALFORMATTED_DATA = os.path.join(JSON_DATA_DIR, "malformatted_data.json")
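For context, each line of these chat fixtures holds one conversation. A plausible shape for a single-turn record is sketched below; the field names are an assumption for illustration, not copied from the fixture files:

# Hypothetical shape of one line in single_turn_chat.jsonl (schema assumed):
sample = {
    "messages": [
        {"role": "user", "content": "My flight was delayed three hours."},
        {"role": "assistant", "content": "Sorry to hear that, let me check the status."},
    ]
}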

tests/artifacts/testdata/tinyllama_tokenizer_special_tokens/added_tokens.json

Lines changed: 1 addition & 0 deletions
@@ -1,4 +1,5 @@
 {
+  "<pad>": 32003,
   "<|assistant|>": 32001,
   "<|system|>": 32002,
   "<|user|>": 32000

tests/artifacts/testdata/tinyllama_tokenizer_special_tokens/special_tokens_map.json

Lines changed: 10 additions & 21 deletions
@@ -1,26 +1,8 @@
 {
   "additional_special_tokens": [
-    {
-      "content": "<|user|>",
-      "lstrip": false,
-      "normalized": false,
-      "rstrip": false,
-      "single_word": false
-    },
-    {
-      "content": "<|assistant|>",
-      "lstrip": false,
-      "normalized": false,
-      "rstrip": false,
-      "single_word": false
-    },
-    {
-      "content": "<|system|>",
-      "lstrip": false,
-      "normalized": false,
-      "rstrip": false,
-      "single_word": false
-    }
+    "<|user|>",
+    "<|assistant|>",
+    "<|system|>"
   ],
   "bos_token": {
     "content": "<s>",
@@ -36,6 +18,13 @@
     "rstrip": false,
     "single_word": false
   },
+  "pad_token": {
+    "content": "<pad>",
+    "lstrip": false,
+    "normalized": false,
+    "rstrip": false,
+    "single_word": false
+  },
   "unk_token": {
     "content": "<unk>",
     "lstrip": false,

tests/artifacts/testdata/tinyllama_tokenizer_special_tokens/tokenizer.json

Lines changed: 9 additions & 0 deletions
@@ -56,6 +56,15 @@
       "rstrip": false,
       "normalized": false,
       "special": true
+    },
+    {
+      "id": 32003,
+      "content": "<pad>",
+      "single_word": false,
+      "lstrip": false,
+      "rstrip": false,
+      "normalized": false,
+      "special": true
     }
   ],
   "normalizer": {

tests/artifacts/testdata/tinyllama_tokenizer_special_tokens/tokenizer_config.json

Lines changed: 9 additions & 1 deletion
@@ -50,6 +50,14 @@
       "rstrip": false,
       "single_word": false,
       "special": true
+    },
+    "32003": {
+      "content": "<pad>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
     }
   },
   "additional_special_tokens": [
@@ -62,7 +70,7 @@
   "eos_token": "</s>",
   "legacy": true,
   "model_max_length": 2048,
-  "pad_token": null,
+  "pad_token": "<pad>",
   "sp_model_kwargs": {},
   "tokenizer_class": "LlamaTokenizer",
   "unk_token": "<unk>",

tests/test_sft_trainer.py

Lines changed: 17 additions & 4 deletions
@@ -888,7 +888,7 @@ def test_run_chat_style_ft(dataset_path):
         train_args = copy.deepcopy(TRAIN_ARGS)
         train_args.output_dir = tempdir

-        sft_trainer.train(MODEL_ARGS, DATA_ARGS, train_args)
+        sft_trainer.train(model_args, data_args, train_args)

         # validate full ft configs
         _validate_training(tempdir)
@@ -917,7 +917,20 @@ def test_run_chat_style_ft_using_dataconfig(datafiles, dataconfigfile):
     {% if loop.last and add_generation_prompt %}{{ '<|assistant|>' }}\
     {% endif %}\
     {% endfor %}"
+    data_args.response_template = "<|assistant|>"
     data_args.instruction_template = "<|user|>"
+    data_args.dataset_text_field = "new_formatted_field"
+
+    handler_kwargs = {"dataset_text_field": data_args.dataset_text_field}
+    kwargs = {
+        "fn_kwargs": handler_kwargs,
+        "batched": False,
+        "remove_columns": "all",
+    }
+
+    handler_config = DataHandlerConfig(
+        name="apply_tokenizer_chat_template", arguments=kwargs
+    )

     model_args = copy.deepcopy(MODEL_ARGS)
     model_args.tokenizer_name_or_path = CUSTOM_TOKENIZER_TINYLLAMA
@@ -932,13 +945,13 @@ def test_run_chat_style_ft_using_dataconfig(datafiles, dataconfigfile):
         data = yaml.safe_load(f)
         datasets = data["datasets"]
         for i, d in enumerate(datasets):
-            d["data_paths"][0] = datafiles[i]
+            d["data_paths"] = [datafiles[i]]
             # Basic chat datasets don't need data handling
-            del d["data_handlers"]
+            d["data_handlers"] = [asdict(handler_config)]
         yaml.dump(data, temp_yaml_file)
         data_args.data_config_path = temp_yaml_file.name

-        sft_trainer.train(MODEL_ARGS, DATA_ARGS, train_args)
+        sft_trainer.train(model_args, data_args, train_args)

         # validate full ft configs
         _validate_training(tempdir)
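After this change, each dataset entry the test writes back to YAML carries an explicit handler. The entry would deserialize to roughly the following dict; the handler dict mirrors asdict(handler_config) above, while the fixture path and any schema beyond what the test touches are assumptions:

# Approximate shape of one entry in data["datasets"] after the loop above.
dataset_entry = {
    "data_paths": ["/path/to/single_turn_chat.jsonl"],  # hypothetical fixture path
    "data_handlers": [
        {
            "name": "apply_tokenizer_chat_template",
            "arguments": {
                "fn_kwargs": {"dataset_text_field": "new_formatted_field"},
                "batched": False,
                "remove_columns": "all",
            },
        }
    ],
}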

tuning/config/configs.py

Lines changed: 3 additions & 3 deletions
@@ -105,9 +105,9 @@ class DataArguments:
     chat_template: str = field(
         default=None,
         metadata={
-            "help": "chat template to use for tokenization. \
-                No need to pass this if the tokenizer already has a chat_template \
-                if passed, it will overwrite tokenizer.chat_template if it exists"
+            "help": "Chat template to use for tokenization. \
+                No need to pass this if the tokenizer already has a chat_template. \
+                If passed, it will overwrite tokenizer.chat_template if it exists."
         },
     )
     instruction_template: str = field(
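The override this help text describes presumably happens wherever the tokenizer is set up; a minimal sketch of that behavior, assuming a DataArguments instance named data_args (the assignment site is not part of this commit):

# A user-supplied template takes precedence over the tokenizer's built-in one.
if data_args.chat_template is not None:
    tokenizer.chat_template = data_args.chat_template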

tuning/data/data_handlers.py

Lines changed: 15 additions & 0 deletions
@@ -137,8 +137,23 @@ def replace_text(match_obj):
     }


+def apply_tokenizer_chat_template(
+    element: Dict[str, str],
+    tokenizer: AutoTokenizer,
+    dataset_text_field: str,
+    **kwargs,
+):
+    if tokenizer.chat_template is None:
+        raise ValueError("Tokenizer does not contain tokenizer.chat_template\
+                          please pass data_args.chat_template")
+    return {
+        f"{dataset_text_field}": tokenizer.apply_chat_template(element, tokenize=False)
+    }
+
+
 AVAILABLE_DATA_HANDLERS = {
     "tokenize_and_apply_input_masking": tokenize_and_apply_input_masking,
     "apply_dataset_formatting": apply_dataset_formatting,
     "apply_custom_data_formatting_template": apply_custom_data_formatting_template,
+    "apply_tokenizer_chat_template": apply_tokenizer_chat_template,
 }
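In isolation, the new handler just renders a conversation into a single text column. A minimal sketch of calling it directly, assuming a tokenizer that ships a chat template and a messages-style element (apply_chat_template accepts a list of role/content dicts; the exact shape a dataset row provides depends on the data):

from transformers import AutoTokenizer

# Any tokenizer that ships a chat_template works here.
tok = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")

element = [
    {"role": "user", "content": "Hello"},
    {"role": "assistant", "content": "Hi, how can I help?"},
]
row = apply_tokenizer_chat_template(
    element, tokenizer=tok, dataset_text_field="new_formatted_field"
)
print(row["new_formatted_field"])  # the rendered, untokenized chat string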

tuning/data/setup_dataprocessor.py

Lines changed: 26 additions & 5 deletions
@@ -151,6 +151,25 @@ def _get_dataset_formatting_handlers(data_args, packing):
     return [handler], dataset_text_field


+### Default Format 3
+def _get_chat_dataset_handlers(data_args, tokenizer_kwargs):
+
+    if data_args.dataset_text_field is None:
+        data_args.dataset_text_field = "new_formatted_field"
+
+    fn_kwargs = {}
+    fn_kwargs["dataset_text_field"] = data_args.dataset_text_field
+    fn_kwargs["tokenizer_kwargs"] = tokenizer_kwargs
+
+    kwargs = {"fn_kwargs": fn_kwargs, "batched": False, "remove_columns": "all"}
+
+    handlers = [
+        DataHandlerConfig("apply_tokenizer_chat_template", arguments=kwargs),
+    ]
+
+    return handlers, data_args.dataset_text_field
+
+
 ### Default Data format
 def _get_default_dataset_handlers(data_args, tokenizer_kwargs):

@@ -236,15 +255,17 @@ def _process_raw_data_args(
         handlers, dataset_text_field = _get_pretokenized_dataset_handlers(
             data_args, packing, (is_eval_dataset_present and not is_evaldata_tokenized)
         )
+    elif data_args.instruction_template and data_args.response_template:
+        # Data Format 2: Chat dataset with instruction and response template
+        # We don't do processing for chat dataset
+        handlers, dataset_text_field = _get_chat_dataset_handlers(
+            data_args, tokenizer_kwargs
+        )
     elif data_args.data_formatter_template or data_args.dataset_text_field:
-        # Data Format 2: Single Sequence Dataset
+        # Data Format 3: Single Sequence Dataset
         handlers, dataset_text_field = _get_dataset_formatting_handlers(
             data_args, packing
         )
-    elif data_args.instruction_template and data_args.response_template:
-        # Data Format 3: Chat dataset with instruction and response template
-        # We don't do processing for chat dataset
-        handlers, dataset_text_field = [], None
     else:
         # Default Data Format: Dataset with Input/Output Fields
         handlers, dataset_text_field = _get_default_dataset_handlers(
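Downstream, the instruction and response templates that select this branch are the same markers a completion-only collator needs to mask non-assistant tokens; a minimal sketch of that pairing with TRL, assuming the trainer wires the collator this way (the wiring itself is not part of this commit):

from trl import DataCollatorForCompletionOnlyLM

# Loss is computed only on tokens after "<|assistant|>"; "<|user|>" marks
# the start of each instruction segment in multi-turn conversations.
collator = DataCollatorForCompletionOnlyLM(
    response_template="<|assistant|>",
    instruction_template="<|user|>",
    tokenizer=tokenizer,  # assumes the tokenizer from model setup is in scope
)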
