
Commit e045ca7

Add code comments and make code path clearer.
Remove the packing check, as packing support for pretokenised data has been merged into trl; see huggingface/trl#2011.

Signed-off-by: Dushyant Behl <[email protected]>
1 parent 7621173 commit e045ca7
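
With the packing check gone, packing can be requested for an already-tokenised dataset and the work is delegated to trl. A minimal, hedged sketch of that combination (MODEL_NAME and TWITTER_COMPLAINTS_TOKENIZED_JSONL are assumed importable from tests.artifacts.testdata like the other fixtures; this snippet is illustrative and not part of the commit):

# Illustrative only: pretokenised data plus packing, now accepted since trl
# handles packing for pretokenised datasets (huggingface/trl#2011).
from transformers import AutoTokenizer
from tests.artifacts.testdata import MODEL_NAME, TWITTER_COMPLAINTS_TOKENIZED_JSONL
from tuning.config import configs
from tuning.data.setup_dataprocessor import process_dataargs

data_args = configs.DataArguments(
    training_data_path=TWITTER_COMPLAINTS_TOKENIZED_JSONL,  # already-tokenised JSONL fixture
)
train_args = configs.TrainingArguments(
    packing=True,
    max_seq_length=1024,
    output_dir="tmp",
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
process_dataargs(data_args, tokenizer, train_args)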

File tree: 6 files changed (+302 / -376 lines)


tests/data/test_data_handlers.py

Lines changed: 38 additions & 1 deletion
@@ -24,7 +24,10 @@
 from tests.artifacts.testdata import MODEL_NAME, TWITTER_COMPLAINTS_DATA_JSONL
 
 # Local
-from tuning.data.data_handlers import apply_custom_data_formatting_template
+from tuning.data.data_handlers import (
+    apply_custom_data_formatting_template,
+    combine_sequence,
+)
 
 
 def test_apply_custom_formatting_template():
@@ -71,3 +74,37 @@ def test_apply_custom_formatting_template_gives_error_with_wrong_keys():
                 "template": template,
             },
         )
+
+
+@pytest.mark.parametrize(
+    "input_element,output_element,expected_res",
+    [
+        ("foo ", "bar", "foo bar"),
+        ("foo\n", "bar", "foo\nbar"),
+        ("foo\t", "bar", "foo\tbar"),
+        ("foo", "bar", "foo bar"),
+    ],
+)
+def test_combine_sequence(input_element, output_element, expected_res):
+    """Ensure that input / output elements are combined with correct whitespace handling."""
+    comb_seq = combine_sequence(input_element, output_element)
+    assert isinstance(comb_seq, str)
+    assert comb_seq == expected_res
+
+
+@pytest.mark.parametrize(
+    "input_element,output_element,expected_res",
+    [
+        ("foo ", "bar", "foo bar"),
+        ("foo\n", "bar", "foo\nbar"),
+        ("foo\t", "bar", "foo\tbar"),
+        ("foo", "bar", "foo bar"),
+    ],
+)
+def test_combine_sequence_adds_eos(input_element, output_element, expected_res):
+    """Ensure that input / output elements are combined with correct whitespace handling."""
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+    comb_seq = combine_sequence(input_element, output_element, tokenizer.eos_token)
+    expected_res += tokenizer.eos_token
+    assert isinstance(comb_seq, str)
+    assert comb_seq == expected_res
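
For orientation, a minimal sketch of the behaviour these tests pin down, using the combine_sequence helper added to tuning.data.data_handlers in this commit ("</s>" stands in for whatever the tokenizer's eos_token is):

from tuning.data.data_handlers import combine_sequence

assert combine_sequence("foo", "bar") == "foo bar"              # a space is inserted between the parts
assert combine_sequence("foo\n", "bar") == "foo\nbar"           # existing trailing whitespace is kept as-is
assert combine_sequence("foo", "bar", "</s>") == "foo bar</s>"  # the EOS token, if given, is appended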

tests/data/test_data_preprocessing_utils.py

Lines changed: 11 additions & 81 deletions
@@ -43,53 +43,15 @@
 # Local
 from tuning.config import configs
 from tuning.data.data_config import DataPreProcessorConfig, DataSetConfig
-from tuning.data.data_preprocessing_utils import (
-    combine_sequence,
-    get_data_collator,
-    validate_data_args,
-)
-from tuning.data.data_processors import HFBasedDataPreProcessor, get_datapreprocessor
+from tuning.data.data_preprocessing_utils import get_data_collator
+from tuning.data.data_processors import DataPreProcessor, get_datapreprocessor
 from tuning.data.setup_dataprocessor import (
     _process_dataconfig_file,
     is_pretokenized_dataset,
     process_dataargs,
 )
 
 
-@pytest.mark.parametrize(
-    "input_element,output_element,expected_res",
-    [
-        ("foo ", "bar", "foo bar"),
-        ("foo\n", "bar", "foo\nbar"),
-        ("foo\t", "bar", "foo\tbar"),
-        ("foo", "bar", "foo bar"),
-    ],
-)
-def test_combine_sequence(input_element, output_element, expected_res):
-    """Ensure that input / output elements are combined with correct whitespace handling."""
-    comb_seq = combine_sequence(input_element, output_element)
-    assert isinstance(comb_seq, str)
-    assert comb_seq == expected_res
-
-
-@pytest.mark.parametrize(
-    "input_element,output_element,expected_res",
-    [
-        ("foo ", "bar", "foo bar"),
-        ("foo\n", "bar", "foo\nbar"),
-        ("foo\t", "bar", "foo\tbar"),
-        ("foo", "bar", "foo bar"),
-    ],
-)
-def test_combine_sequence_adds_eos(input_element, output_element, expected_res):
-    """Ensure that input / output elements are combined with correct whitespace handling."""
-    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
-    comb_seq = combine_sequence(input_element, output_element, tokenizer.eos_token)
-    expected_res += tokenizer.eos_token
-    assert isinstance(comb_seq, str)
-    assert comb_seq == expected_res
-
-
 @pytest.mark.parametrize(
     "datafile, column_names",
     [
@@ -222,7 +184,6 @@ def test_load_dataset_without_dataconfig_and_datafile():
 )
 def test_is_pretokenized_data(data, result):
     """Ensure that the correct collator type is fetched based on the data args"""
-
     assert is_pretokenized_dataset(data=data) == result
 
 
@@ -361,43 +322,16 @@ def test_get_data_collator(
         ),
     ],
 )
-def test_validate_args(data_args, packing):
+def test_process_data_args_throws_error_where_needed(data_args, packing):
     """Ensure that respective errors are thrown for incorrect data arguments"""
     with pytest.raises(ValueError):
-        is_traindata_tokenized = is_pretokenized_dataset(data_args.training_data_path)
-        is_evaldata_tokenized = is_pretokenized_dataset(data_args.validation_data_path)
-        validate_data_args(
-            data_args, packing, is_traindata_tokenized, is_evaldata_tokenized
+        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+        TRAIN_ARGS = configs.TrainingArguments(
+            packing=packing,
+            max_seq_length=1024,
+            output_dir="tmp",  # Not needed but positional
        )
-
-
-@pytest.mark.parametrize(
-    "data_args, packing",
-    [
-        # pretokenized train dataset and no validation dataset passed
-        (
-            configs.DataArguments(
-                training_data_path=TWITTER_COMPLAINTS_TOKENIZED_JSONL,
-            ),
-            False,
-        ),
-        # pretokenized train and validation datasets
-        (
-            configs.DataArguments(
-                training_data_path=TWITTER_COMPLAINTS_TOKENIZED_JSONL,
-                validation_data_path=TWITTER_COMPLAINTS_TOKENIZED_JSONL,
-            ),
-            False,
-        ),
-    ],
-)
-def test_validate_args_pretokenized(data_args, packing):
-    """Ensure that supported data args do not error out when passing pretokenized datasets"""
-    is_traindata_tokenized = is_pretokenized_dataset(data_args.training_data_path)
-    is_evaldata_tokenized = is_pretokenized_dataset(data_args.validation_data_path)
-    validate_data_args(
-        data_args, packing, is_traindata_tokenized, is_evaldata_tokenized
-    )
+        (_, _, _, _, _, _) = process_dataargs(data_args, tokenizer, TRAIN_ARGS)
 
 
 @pytest.mark.parametrize(
@@ -448,11 +382,7 @@ def test_process_dataconfig_file(data_config_path, data_path):
     data_args = configs.DataArguments(data_config_path=temp_yaml_file_path)
 
     tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
-    packing = (False,)
-    max_seq_length = 1024
-    (train_set, _, _, _, _, _) = _process_dataconfig_file(
-        data_args, tokenizer, packing, max_seq_length
-    )
+    (train_set, _, _) = _process_dataconfig_file(data_args, tokenizer)
     assert isinstance(train_set, Dataset)
     if datasets_name == "text_dataset_input_output_masking":
         column_names = set(["input_ids", "attention_mask", "labels"])
@@ -625,7 +555,7 @@ def test_process_dataset_configs(datafile, column_names, datasetconfigname):
     """Test process_dataset_configs for expected output."""
     dataprocessor_config = DataPreProcessorConfig()
     tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
-    processor = HFBasedDataPreProcessor(
+    processor = DataPreProcessor(
         processor_config=dataprocessor_config,
         tokenizer=tokenizer,
     )
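
_process_dataconfig_file now takes only the data args and the tokenizer and returns three values, with packing and max_seq_length supplied later through TrainingArguments. A minimal sketch of the new call shape ("path/to/data_config.yaml" is a hypothetical path; the tests template the predefined config YAMLs instead):

from transformers import AutoTokenizer
from tests.artifacts.testdata import MODEL_NAME
from tuning.config import configs
from tuning.data.setup_dataprocessor import _process_dataconfig_file

data_args = configs.DataArguments(data_config_path="path/to/data_config.yaml")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Only the first element (the processed train dataset) is inspected in the test above.
(train_set, _, _) = _process_dataconfig_file(data_args, tokenizer)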

tuning/data/data_handlers.py

Lines changed: 53 additions & 7 deletions
@@ -15,26 +15,55 @@
 # Definition of some predefined data preprocessing functions that we need.
 
 # Standard
-from typing import Dict
+from typing import Dict, List
+import re
 
 # Third Party
 from transformers import AutoTokenizer
 
-# Local
-from tuning.data.data_preprocessing_utils import combine_sequence, custom_data_formatter
+
+### Utils for custom masking / manipulating input / output strs, etc
+def combine_sequence(input_element: str, output_element: str, eos_token: str = ""):
+    """Combines / concatenates input & output element.
+
+    Args:
+        input_element: str
+            Input component of the combined sequence.
+        output_element: str
+            Output component of the combined sequence.
+        eos_token: str
+            EOS token associated with the tokenizer. \
+            If passed, it will be concatenated at end
+
+    Returns:
+        str
+            Sequence combined with whitespace.
+    """
+    if not input_element.endswith((" ", "\n", "\t")) and not output_element.startswith(
+        (" ", "\n", "\t")
+    ):
+        return input_element + " " + output_element + eos_token
+    return input_element + output_element + eos_token
 
 
 def tokenize_and_apply_input_masking(
     element: Dict[str, str],
     tokenizer: AutoTokenizer,
+    column_names: List[str],
     input_field_name: str,
     output_field_name: str,
     **tokenizer_kwargs,
 ):
+    if (input_field_name or output_field_name) not in column_names:
+        raise ValueError(
+            f"Dataset should contain {input_field_name} \
+            and {output_field_name} field if \
+            no dataset_text_field or data_formatter_template specified"
+        )
+
     input_text = element[input_field_name]
     output_text = element[output_field_name]
 
-    # TODO: Eventually move the code here
     combined = combine_sequence(input_text, output_text, eos_token=tokenizer.eos_token)
 
     fn_kwargs = tokenizer_kwargs.get("fn_kwargs", {})
@@ -56,7 +85,10 @@ def tokenize_and_apply_input_masking(
 
 
 def apply_dataset_formatting(
-    element: Dict[str, str], tokenizer: AutoTokenizer, dataset_text_field: str, **kwargs
+    element: Dict[str, str],
+    tokenizer: AutoTokenizer,
+    dataset_text_field: str,
+    **kwargs,
 ):
     return {
         f"{dataset_text_field}": element[f"{dataset_text_field}"] + tokenizer.eos_token
@@ -85,8 +117,22 @@ def apply_custom_data_formatting_template(
 
     template += tokenizer.eos_token
 
-    # TODO: Eventually move the code here.
-    return custom_data_formatter(element, template, dataset_text_field)
+    def replace_text(match_obj):
+        captured_groups = match_obj.groups()
+        if len(captured_groups) != 1:
+            raise ValueError(
+                "Unexpectedly captured multiple groups in template formatting"
+            )
+
+        index_object = captured_groups[0]
+        if index_object not in element:
+            raise KeyError("Requested template string is not a valid key in dict")
+
+        return element[index_object]
+
+    return {
+        dataset_text_field: re.sub(r"{{([\s0-9a-zA-Z_\-\.]+)}}", replace_text, template)
+    }
 
 
 AVAILABLE_DATA_HANDLERS = {
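
The {{key}} substitution that used to live in custom_data_formatter is now inlined above. A standalone sketch of the same re.sub pattern applied to a hypothetical dataset row (element and template here are made-up values for illustration):

import re

element = {"input": "hello", "output": "world"}  # hypothetical dataset row
template = "Input: {{input}} Output: {{output}}"

def replace_text(match_obj):
    key = match_obj.group(1)  # the name captured between {{ and }}
    if key not in element:
        raise KeyError("Requested template string is not a valid key in dict")
    return element[key]

formatted = re.sub(r"{{([\s0-9a-zA-Z_\-\.]+)}}", replace_text, template)
assert formatted == "Input: hello Output: world"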

0 commit comments
