Commit f44b370

Add multi turn chat support.
Signed-off-by: Dushyant Behl <dushyantbehl@in.ibm.com>
1 parent 4168c87 commit f44b370

File tree

6 files changed: +67 -19 lines changed

.pylintrc
tests/data/test_data_preprocessing_utils.py
tuning/config/configs.py
tuning/data/data_preprocessing_utils.py
tuning/data/setup_dataprocessor.py
tuning/sft_trainer.py

.pylintrc

Lines changed: 2 additions & 2 deletions
@@ -280,8 +280,8 @@ ignored-parents=
 # Maximum number of arguments for function / method.
 max-args=5
 
-# Maximum number of attributes for a class (see R0902).
-max-attributes=7
+# Maximum number of attributes for a class (custom).
+max-attributes=10
 
 # Maximum number of boolean expressions in an if statement (see R0916).
 max-bool-expr=5

tests/data/test_data_preprocessing_utils.py

Lines changed: 3 additions & 2 deletions
@@ -320,10 +320,11 @@ def test_get_data_collator(
     """Ensure that the correct collator type is fetched based on the data args"""
     collator = get_data_collator(
         packing,
-        response_template,
         AutoTokenizer.from_pretrained(MODEL_NAME),
-        is_pretokenized_dataset(formatted_train_dataset),
         max_seq_length,
+        response_template,
+        None,
+        is_pretokenized_dataset(formatted_train_dataset),
     )
     assert isinstance(collator, expected_collator)
 

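For reference, the updated positional order above is (packing, tokenizer, max_seq_length, response_template, instruction_template, is_traindata_tokenized). A hedged sketch of an extra check one could add for the new chat path, not part of this commit (MODEL_NAME is the constant already defined in this test module, and the template strings are illustrative placeholders):

    from transformers import AutoTokenizer
    from trl import DataCollatorForCompletionOnlyLM

    from tuning.data.data_preprocessing_utils import get_data_collator

    def test_get_data_collator_for_chat():
        collator = get_data_collator(
            False,                                   # packing
            AutoTokenizer.from_pretrained(MODEL_NAME),
            1024,                                    # max_seq_length
            "### Assistant:",                        # response_template
            "### Human:",                            # instruction_template (new in this commit)
            False,                                   # is_traindata_tokenized
        )
        assert isinstance(collator, DataCollatorForCompletionOnlyLM)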
tuning/config/configs.py

Lines changed: 15 additions & 0 deletions
@@ -102,6 +102,21 @@ class DataArguments:
             Supports both JSON and YAML based config files."
         },
     )
+    chat_template: str = field(
+        default=None,
+        metadata={
+            "help": "chat template to use for tokenization. \
+                No need to pass this if the tokenizer already has a chat_template \
+                if passed, it will overwrite tokenizer.chat_template if it exists"
+        },
+    )
+    instruction_template: str = field(
+        default=None,
+        metadata={
+            "help": "Should be provided for chat training. \
+                Piece of text that determines the start of human response"
+        },
+    )
 
 
 @dataclass

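Both new fields default to None, so existing configurations are unaffected; they only matter for chat-style training. A minimal sketch of how they might be populated, assuming the usual transformers HfArgumentParser wiring (an assumption here; the flag values and template strings are illustrative, not taken from this commit):

    from transformers import HfArgumentParser

    from tuning.config.configs import DataArguments

    parser = HfArgumentParser(DataArguments)
    (data_args,) = parser.parse_args_into_dataclasses(
        args=[
            "--chat_template", "{% for m in messages %}{{ m['content'] }}{% endfor %}",
            "--instruction_template", "### Human:",
            "--response_template", "### Assistant:",
        ]
    )
    # With both templates present, the chat branches added below
    # (data_preprocessing_utils.py and setup_dataprocessor.py) are taken.
    assert data_args.instruction_template and data_args.response_template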
tuning/data/data_preprocessing_utils.py

Lines changed: 17 additions & 2 deletions
@@ -24,10 +24,11 @@
 
 def get_data_collator(
     packing: bool,
-    response_template: Optional[str],
     tokenizer: AutoTokenizer,
-    is_traindata_tokenized: bool,
     max_seq_length: int,
+    response_template: Optional[str],
+    instruction_template: Optional[str],
+    is_traindata_tokenized: bool,
 ) -> Callable:
     """Create and return the the appropriate collator type based on the configuration for packing,
     response_template, and dataset_text_field.
@@ -49,6 +50,20 @@ def get_data_collator(
         Callable collator to be leveraged by the trainer.
     """
 
+    if response_template and instruction_template:
+        # response_template_ids = tokenizer.encode(
+        #     response_template, add_special_tokens=False
+        # )[2:]
+        # intruction_template_ids = tokenizer.encode(
+        #     instruction_template, add_special_tokens=False
+        # )[2:]
+        return DataCollatorForCompletionOnlyLM(
+            response_template=response_template,
+            instruction_template=instruction_template,
+            tokenizer=tokenizer,
+            ignore_index=configs.IGNORE_INDEX,
+        )
+
     if not packing:
         # TODO: near term - how response template ids are parsed out needs to be cleaned.
         # The [2:] here applies if response template has \n prefix, it is needed to strip \n,

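When both templates are supplied, get_data_collator now delegates to TRL's DataCollatorForCompletionOnlyLM, which keeps labels only for the assistant turns and sets everything else to the ignore index, so a multi-turn conversation trains only on the model's responses. A minimal sketch of that behaviour, using the gpt2 tokenizer and the "### Human:" / "### Assistant:" markers purely as placeholders:

    from transformers import AutoTokenizer
    from trl import DataCollatorForCompletionOnlyLM

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token

    collator = DataCollatorForCompletionOnlyLM(
        response_template="### Assistant:",
        instruction_template="### Human:",
        tokenizer=tokenizer,
    )  # ignore_index defaults to -100; the commit passes configs.IGNORE_INDEX here

    chat = (
        "### Human: What is the capital of France?\n"
        "### Assistant: Paris.\n"
        "### Human: And of Italy?\n"
        "### Assistant: Rome."
    )
    batch = collator([tokenizer(chat)])
    # Every label outside the two assistant turns is set to the ignore index,
    # so the loss covers only the responses in each turn.
    print(batch["labels"])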
tuning/data/setup_dataprocessor.py

Lines changed: 20 additions & 13 deletions
@@ -34,8 +34,8 @@
 from tuning.data.data_processors import get_datapreprocessor
 
 # In future we may make the fields configurable
-DEFAULT_JSON_INPUT_KEY = "input"
-DEFAULT_JSON_OUTPUT_KEY = "output"
+DEFAULT_INPUT_COLUMN = "input"
+DEFAULT_OUTPUT_COLUMN = "output"
 
 # check if the provided dataset is pretokenized or not
 # the check is taken from trl
@@ -145,12 +145,12 @@ def _get_dataset_formatting_handlers(data_args, packing):
     return [handler], dataset_text_field
 
 
-### Data format 3
-def _get_default_json_dataset_handlers(data_args, tokenizer_kwargs):
+### Default Data format
+def _get_default_dataset_handlers(data_args, tokenizer_kwargs):
 
     fn_kwargs = {}
-    fn_kwargs["input_field_name"] = DEFAULT_JSON_INPUT_KEY
-    fn_kwargs["output_field_name"] = DEFAULT_JSON_OUTPUT_KEY
+    fn_kwargs["input_field_name"] = DEFAULT_INPUT_COLUMN
+    fn_kwargs["output_field_name"] = DEFAULT_OUTPUT_COLUMN
     fn_kwargs["tokenizer_kwargs"] = tokenizer_kwargs
 
     kwargs = {
@@ -171,7 +171,9 @@ def _get_default_json_dataset_handlers(data_args, tokenizer_kwargs):
 # If a text field is specified, append the tokenizer's EOS token to it.
 # If a formatter template is provided, apply it and save the result.
 # Data remains un-tokenized.
-# Data Format 3: JSON Dataset with Input/Output Fields
+# Data Format 3: Chat datasets
+# User provides response_template and instruction_template.
+# Default Data Format: Dataset with Input/Output Fields
 # Combine input and output fields, tokenize the data, and apply input attention masking.
 # Requires both input and output fields; throws an error if missing.
 def _process_raw_data_args(
@@ -231,9 +233,13 @@ def _process_raw_data_args(
         handlers, dataset_text_field = _get_dataset_formatting_handlers(
             data_args, packing
         )
+    elif data_args.instruction_template and data_args.response_template:
+        # Data Format 3: Chat dataset with instruction and response template
+        # We don't do processing for chat dataset
+        handlers, dataset_text_field = [], None
     else:
-        # Data Format 3: JSON Dataset with Input/Output Fields
-        handlers, dataset_text_field = _get_default_json_dataset_handlers(
+        # Default Data Format: Dataset with Input/Output Fields
+        handlers, dataset_text_field = _get_default_dataset_handlers(
             data_args, tokenizer_kwargs
         )
 
@@ -299,13 +305,14 @@ def process_dataargs(
 
     data_collator = get_data_collator(
         train_args.packing,
-        data_args.response_template,
-        tokenizer,
+        tokenizer=tokenizer,
+        max_seq_length=max_seq_length,
+        response_template=data_args.response_template,
+        instruction_template=data_args.instruction_template,
         # Note: This check should not be removed.
         # Its important to recompute this post handling to
        # check if we already tokenized the dataset or not.
-        is_pretokenized_dataset(train_dataset),
-        max_seq_length,
+        is_traindata_tokenized=is_pretokenized_dataset(train_dataset),
    )
 
     dataset_kwargs = {}

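The practical consequence of the new elif branch is that chat data is expected to arrive already rendered as text containing the two markers; no handler reshapes it, and the label masking happens entirely in the collator. A hedged sketch of what such a record might look like (the "text" column name and the markers are assumptions for illustration, not mandated by this commit):

    from datasets import Dataset

    # Hypothetical pre-rendered chat records; with instruction_template and
    # response_template set, _process_raw_data_args applies no handlers to them.
    train_dataset = Dataset.from_list(
        [
            {
                "text": "### Human: Hello, who are you?\n"
                        "### Assistant: I am a helpful assistant.\n"
                        "### Human: What can you do?\n"
                        "### Assistant: I can answer questions."
            }
        ]
    )
    # Downstream, DataCollatorForCompletionOnlyLM (see data_preprocessing_utils.py above)
    # masks all tokens outside the "### Assistant:" turns.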
tuning/sft_trainer.py

Lines changed: 10 additions & 0 deletions
@@ -285,6 +285,16 @@ def train(
         multiple_of=model_args.embedding_size_multiple_of,
     )
 
+    if data_args.chat_template:
+        logger.info("adding chat_template to the tokenizer")
+        if tokenizer.chat_template:
+            logger.warning(
+                "replacing existing chat_template %s with the given chat_template %s",
+                tokenizer.chat_template,
+                data_args.chat_template,
+            )
+        tokenizer.chat_template = data_args.chat_template
+
     # Configure the collator and validate args related to packing prior to formatting the dataset
     data_collator = None
     logger.info("Packing is set to %s ", train_args.packing)

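Overriding tokenizer.chat_template changes how a list of chat messages is rendered into training text, for example through tokenizer.apply_chat_template. A minimal sketch of the effect; the toy Jinja template and the gpt2 tokenizer are placeholders, and a real run would pass the model's proper template via data_args.chat_template:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")

    # Toy template standing in for whatever data_args.chat_template would supply.
    tokenizer.chat_template = (
        "{% for message in messages %}"
        "{{ '### ' + message['role'] + ': ' + message['content'] + '\n' }}"
        "{% endfor %}"
    )

    messages = [
        {"role": "user", "content": "Summarize this commit."},
        {"role": "assistant", "content": "It adds multi turn chat support."},
        {"role": "user", "content": "Which files changed?"},
    ]
    print(tokenizer.apply_chat_template(messages, tokenize=False))
    # ### user: Summarize this commit.
    # ### assistant: It adds multi turn chat support.
    # ### user: Which files changed?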