Commit b655e1a
minor refactor to allow modular functions (#224)
* minor refactor to allow modular functions
* minor fix in import
* minor fix to imports
* fix linting
* fix formatting

Signed-off-by: Sukriti-Sharma4 <[email protected]>
1 parent 0be40e0 commit b655e1a

File tree

3 files changed (+162, -99 lines):

tests/utils/test_preprocessing_utils.py
tuning/sft_trainer.py
tuning/utils/preprocessing_utils.py

tests/utils/test_preprocessing_utils.py

Lines changed: 22 additions & 6 deletions
@@ -13,6 +13,7 @@
 )
 
 # Local
+from tuning.config import configs
 from tuning.utils.preprocessing_utils import (
     combine_sequence,
     get_data_trainer_kwargs,
@@ -180,14 +181,29 @@ def test_get_trainer_kwargs_with_custom_masking(use_validation_data):
     assert trainer_kwargs["formatting_func"] is not None
 
 
-# Tests for fetching train args
+# Tests for validating data args
+# Invalid args return ValueError
 @pytest.mark.parametrize(
-    "dataset_text_field, response_template",
+    "data_args, packing",
     [
-        ("input", None),
-        (None, "output"),
+        # dataset_text_field with no response_template
+        (
+            configs.DataArguments(
+                training_data_path=TWITTER_COMPLAINTS_DATA,
+                dataset_text_field="output",
+            ),
+            False,
+        ),
+        # response template with no dataset_text_field or formatter
+        (
+            configs.DataArguments(
+                training_data_path=TWITTER_COMPLAINTS_DATA,
+                response_template="\n### Label:",
+            ),
+            False,
+        ),
     ],
 )
-def test_validate_args(dataset_text_field, response_template):
+def test_validate_args(data_args, packing):
     with pytest.raises(ValueError):
-        validate_data_args(dataset_text_field, response_template)
+        validate_data_args(data_args, packing)
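The new parametrized cases can also be reproduced by hand outside of pytest's parametrization. The snippet below is a minimal sketch of the first case and is not part of the commit: it assumes this repository's tuning package and pytest are importable, and a placeholder JSON path stands in for the TWITTER_COMPLAINTS_DATA fixture.

# Hand-run sketch of the "dataset_text_field with no response_template" case.
# Assumes the tuning package and pytest are installed; the data path is a
# placeholder, since validate_data_args only asserts that it is a string.
import pytest

from tuning.config import configs
from tuning.utils.preprocessing_utils import validate_data_args

data_args = configs.DataArguments(
    training_data_path="twitter_complaints.json",  # placeholder for TWITTER_COMPLAINTS_DATA
    dataset_text_field="output",  # set without a response_template
)

# With packing disabled and no response template, completion masking cannot be
# applied, so the new validator is expected to raise ValueError.
with pytest.raises(ValueError):
    validate_data_args(data_args, packing=False)

Because the validator only checks that training_data_path is a string, the placeholder path never has to exist for this check to fire.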

tuning/sft_trainer.py

Lines changed: 7 additions & 30 deletions
@@ -34,7 +34,7 @@
     TrainerCallback,
 )
 from transformers.utils import is_accelerate_available, logging
-from trl import DataCollatorForCompletionOnlyLM, SFTConfig, SFTTrainer
+from trl import SFTConfig, SFTTrainer
 import datasets
 import fire
 import transformers
@@ -62,6 +62,7 @@
     USER_ERROR_EXIT_CODE,
     write_termination_log,
 )
+from tuning.utils.preprocessing_utils import get_data_collator, validate_data_args
 
 
 def train(
@@ -195,14 +196,6 @@ def train(
         }
     )
 
-    # TODO: near term - how response template ids are parsed out needs to be cleaned.
-    # The [2:] here applies if response template has \n prefix, it is needed to strip \n,
-    # otherwise template is not found. We will create issue to clean this out after we discuss
-    # data formats and collators we will support.
-    response_template_ids = tokenizer.encode(
-        data_args.response_template, add_special_tokens=False
-    )[2:]
-
     max_seq_length = min(train_args.max_seq_length, tokenizer.model_max_length)
     logger.info("Max sequence length is %s", max_seq_length)
     if train_args.max_seq_length > tokenizer.model_max_length:
@@ -244,31 +237,14 @@ def train(
         packing = True
     else:
         logger.info("Packing is set to False")
-        if data_args.response_template is None:
-            # TODO: Fix this, currently unreachable due to crashing in batch encoding tokenization
-            # We should do this validation up front, then do the encoding, then handle the collator
-            raise ValueError("Response template is None, needs to be set for training")
-        data_collator = DataCollatorForCompletionOnlyLM(
-            response_template_ids,
-            tokenizer=tokenizer,
-            ignore_index=configs.IGNORE_INDEX,
-        )
         packing = False
 
-    # Currently we support formatted datasets with single sequence instances.
-    if not (data_args.dataset_text_field or data_args.data_formatter_template):
-        raise ValueError(
-            "dataset_text_field and data_formatter_template are None. \
-            One of them needs to be set for training"
-        )
-    # Only one of dataset_text_field or data_formatter_template should be set.
-    if data_args.dataset_text_field and data_args.data_formatter_template:
-        raise ValueError(
-            "dataset_text_field and data_formatter_template are both set,\
-            but are mutually exclusive options"
-        )
+    # Validate if data args are set properly
+    validate_data_args(data_args, packing)
+    data_collator = get_data_collator(packing, data_args.response_template, tokenizer)
 
     # load the data by parsing JSON
+    ### TODO: all the jSON file formatting will be moved to a separate function
     data_files = {"train": data_args.training_data_path}
     if data_args.validation_data_path:
         data_files["validation"] = data_args.validation_data_path
@@ -310,6 +286,7 @@
         logger.info(
             "Validation dataset length is %s", len(formatted_validation_dataset)
         )
+    ### JSON file formatting ends here
 
     if framework is not None and framework.requires_agumentation:
         model, (peft_config,) = framework.augmentation(
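With this refactor, train() simply calls the two helpers imported above instead of building the completion-only collator inline. The sketch below shows that call pattern in isolation and is not code from the commit: the data path and the gpt2 tokenizer are placeholders, and any Hugging Face tokenizer would work the same way.

# Illustrative sketch of the new helper flow.
# Assumes the tuning package and transformers are installed; "gpt2" is only a
# small stand-in model and the data path is a placeholder string.
from transformers import AutoTokenizer

from tuning.config import configs
from tuning.utils.preprocessing_utils import get_data_collator, validate_data_args

data_args = configs.DataArguments(
    training_data_path="twitter_complaints.json",  # placeholder; only checked to be a str
    dataset_text_field="output",
    response_template="\n### Label:",
)
packing = False

# Fail fast on inconsistent argument combinations before any tokenization happens.
validate_data_args(data_args, packing)

tokenizer = AutoTokenizer.from_pretrained("gpt2")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # needed later when the collator pads batches

# With packing off and a response template set, this returns a
# DataCollatorForCompletionOnlyLM configured with the encoded template ids.
data_collator = get_data_collator(packing, data_args.response_template, tokenizer)
print(type(data_collator).__name__)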

tuning/utils/preprocessing_utils.py

Lines changed: 133 additions & 63 deletions
@@ -25,25 +25,141 @@
 from tuning.config import configs
 
 
-def validate_data_args(
+def validate_data_args(data_args: configs.DataArguments, packing: bool):
+
+    assert isinstance(
+        data_args.training_data_path, str
+    ), "Training data path has to be set and str"
+
+    # Dataset containing single sequence needs a response template for masking
+    if data_args.response_template is None and data_args.dataset_text_field is not None:
+        if packing is False:
+            raise ValueError(
+                "Since dataset_text_field is provided and packing is disabled, \
+                needs a corresponding response template for masking"
+            )
+
+    # Currently if packing is false, we require a response_template. This may change in future.
+    if packing is False:
+        if data_args.response_template is None:
+            raise ValueError(
+                "Response template is None, needs to be set for training \
+                with packing disabled."
+            )
+
+    if data_args.response_template:
+        # To use Response template, pass datasets with single sequence instances \
+        # or a formatter template to create single sequence on the fly.
+        if not (data_args.dataset_text_field or data_args.data_formatter_template):
+            raise ValueError(
+                "dataset_text_field and data_formatter_template are None. \
+                One of them needs to be set to use response_template"
+            )
+        # Only one of dataset_text_field or data_formatter_template should be set.
+        if data_args.dataset_text_field and data_args.data_formatter_template:
+            raise ValueError(
+                "dataset_text_field and data_formatter_template are both set,\
+                but are mutually exclusive options"
+            )
+    # TODO(s) In future seupport two more formats:
+    # 1. Allow no response template, and JSON with input/output fields and mask input
+
+    # 2. Allow pretokenized Dataset besides JSON.
+
+
+def get_data_collator(
+    packing: bool,
+    response_template: Optional[str],
+    tokenizer: AutoTokenizer,
+) -> Callable:
+    """Create and return the the appropriate collator type based on the configuration for packing,
+    response_template, and dataset_text_field.
+
+    Args:
+        packing: bool
+            Whether or not we should apply packing or not.
+        response_template: Optional[str]
+            Response template to be used for formatting by TRL.
+        tokenizer: AutoTokenizer
+            Loaded tokenizer object to be used by the collator.
+
+    Returns:
+        Callable
+            Callable collator to be leveraged by the trainer.
+    """
+    if not packing:
+        # TODO: near term - how response template ids are parsed out needs to be cleaned.
+        # The [2:] here applies if response template has \n prefix, it is needed to strip \n,
+        # otherwise template is not found. We will create issue to clean this out after we discuss
+        # data formats and collators we will support.
+        if response_template:
+            response_template_ids = tokenizer.encode(
+                response_template, add_special_tokens=False
+            )[2:]
+            return DataCollatorForCompletionOnlyLM(
+                response_template=response_template_ids,
+                tokenizer=tokenizer,
+                ignore_index=configs.IGNORE_INDEX,
+            )
+    # TO DO with future changes,
+    # 1. Support no packing and seq2seq colator without response template
+    # # if dataset_text_field is None and response_template is None:
+    # # Use the seq2seq data collator;
+    # # Note that this automatically pads labels with -100
+    # return DataCollatorForSeq2Seq(
+    #     tokenizer=tokenizer, padding=True, max_length=max_sequence_length
+    # )
+    # 2. add anything needed for preprocessed input
+
+
+###################################################################################
+### The functions below are not yet used. Iterative development towards new features
+
+
+def get_data_collator_temp(
+    packing: bool,
     dataset_text_field: Optional[str],
     response_template: Optional[str],
-):
-    # Dataset containing single sequence needs a single sequence and a response template
-    if dataset_text_field is None and response_template is not None:
-        raise ValueError(
-            "Needs a corresponding dataset_text_feld \
-            in which to look for response_template"
-        )
-    if response_template is None and dataset_text_field is not None:
-        raise ValueError(
-            "Since dataset_text_field is provided, \
-            needs a corresponding response template for masking"
-        )
-    # Dataset containing JSON with fields and a formatter template
-    # TO DO load JSON and check input/output field is present
+    max_sequence_length: int,
+    tokenizer: AutoTokenizer,
+) -> Callable:
+    """Create and return the the appropriate collator type based on the configuration for packing,
+    response_template, and dataset_text_field.
 
-    # in future : pretokenized Dataset may be added.
+    Args:
+        packing: bool
+            Whether or not we should apply packing or not.
+        dataset_text_field: Optional[str]
+            Dataset text field fto be used for formatting by TRL.
+        response_template: Optional[str]
+            Response template to be used for formatting by TRL.
+        max_sequence_length: int
+            Max sequence length to be used for sequence tokenization.
+        tokenizer: AutoTokenizer
+            Loaded tokenizer object to be used by the collator.
+
+    Returns:
+        Callable
+            Callable collator to be leveraged by the trainer.
+    """
+    if not packing:
+        if dataset_text_field is None and response_template is None:
+            # Use the seq2seq data collator; note that this automatically pads labels with -100
+            return DataCollatorForSeq2Seq(
+                tokenizer=tokenizer, padding=True, max_length=max_sequence_length
+            )
+        # TODO: near term - how response template ids are parsed out needs to be cleaned.
+        # The [2:] here applies if response template has \n prefix, it is needed to strip \n,
+        # otherwise template is not found. We will create issue to clean this out after we discuss
+        # data formats and collators we will support.
+        response_template_ids = tokenizer.encode(
+            response_template, add_special_tokens=False
+        )[2:]
+        return DataCollatorForCompletionOnlyLM(
+            response_template=response_template_ids,
+            tokenizer=tokenizer,
+            ignore_index=configs.IGNORE_INDEX,
+        )
 
 
 def get_data_trainer_kwargs(
@@ -82,7 +198,7 @@ def get_data_trainer_kwargs(
         Dict[str, Any]
             Data related kwargs to be used by the SFT Trainer.
     """
-    data_collator = get_data_collator(
+    data_collator = get_data_collator_temp(
        packing, dataset_text_field, response_template, max_sequence_length, tokenizer
     )
     eval_dataset = None
@@ -122,52 +238,6 @@
     return data_kwargs
 
 
-def get_data_collator(
-    packing: bool,
-    dataset_text_field: Optional[str],
-    response_template: Optional[str],
-    max_sequence_length: int,
-    tokenizer: AutoTokenizer,
-) -> Callable:
-    """Create and return the the appropriate collator type based on the configuration for packing,
-    response_template, and dataset_text_field.
-
-    Args:
-        packing: bool
-            Whether or not we should apply packing or not.
-        dataset_text_field: Optional[str]
-            Dataset text field fto be used for formatting by TRL.
-        response_template: Optional[str]
-            Response template to be used for formatting by TRL.
-        max_sequence_length: int
-            Max sequence length to be used for sequence tokenization.
-        tokenizer: AutoTokenizer
-            Loaded tokenizer object to be used by the collator.
-
-    Returns:
-        Callable
-            Callable collator to be leveraged by the trainer.
-    """
-    if not packing:
-        if dataset_text_field is None and response_template is None:
-            # Use the seq2seq data collator; note that this automatically pads labels with -100
-            return DataCollatorForSeq2Seq(
-                tokenizer=tokenizer, padding=True, max_length=max_sequence_length
-            )
-        # TODO: near term - how response template ids are parsed out needs to be cleaned.
-        # The [2:] here applies if response template has \n prefix, it is needed to strip \n,
-        # otherwise template is not found. We will create issue to clean this out after we discuss
-        # data formats and collators we will support.
-        response_template_ids = tokenizer.encode(
-            response_template, add_special_tokens=False
-        )[2:]
-        return DataCollatorForCompletionOnlyLM(
-            response_template=response_template_ids,
-            tokenizer=tokenizer,
-            ignore_index=configs.IGNORE_INDEX,
-        )
-
-
 def get_formatted_dataset(
     data_path: str, dataset_text_field: str, tokenizer: AutoTokenizer
 ) -> Dataset:
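Both collator helpers keep the existing [2:] workaround when turning the response template into token ids. The snippet below only isolates that encode-and-slice step; it is not from the commit, gpt2 is a stand-in tokenizer, and how many leading ids actually need to be dropped depends on the tokenizer in use, which is exactly what the in-code TODO flags for cleanup.

# Standalone look at the response-template encoding used by the collator helpers.
# Assumes transformers is installed; "gpt2" is only a placeholder tokenizer.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
template = "\n### Label:"  # same template used in the new test cases

ids = tokenizer.encode(template, add_special_tokens=False)
print(ids)      # full encoding, including the tokens produced for the leading "\n"
print(ids[2:])  # the truncated id sequence the collator searches for in each example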
