Skip to content

Commit 4334d6c

Browse files
Formatting consolidation main (#216)
* Add preprocessing utils for pretokenized datasets Signed-off-by: Alex-Brooks <[email protected]> * Add twitter input/output data Signed-off-by: Alex-Brooks <[email protected]> * Add tests for data preprocessing utilities Signed-off-by: Alex-Brooks <[email protected]> * Formatting, add hack for sidestepping validation Signed-off-by: Alex-Brooks <[email protected]> * Fix linting errors in data gen Signed-off-by: Alex-Brooks <[email protected]> * Add end to end pretokenized tests, formatting Signed-off-by: Alex-Brooks <[email protected]> * Add docstrings for preprocessor utils Signed-off-by: Alex-Brooks <[email protected]> * Rebase tests to new structure Signed-off-by: Alex-Brooks <[email protected]> * fix formatting Signed-off-by: Sukriti-Sharma4 <[email protected]> * fix linting Signed-off-by: Sukriti-Sharma4 <[email protected]> --------- Signed-off-by: Alex-Brooks <[email protected]> Signed-off-by: Sukriti-Sharma4 <[email protected]> Co-authored-by: Alex-Brooks <[email protected]>
1 parent 3f05c67 commit 4334d6c

File tree

4 files changed

+588
-0
lines changed

4 files changed

+588
-0
lines changed

tests/data/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@
2020
### Constants used for data
2121
DATA_DIR = os.path.join(os.path.dirname(__file__))
2222
TWITTER_COMPLAINTS_DATA = os.path.join(DATA_DIR, "twitter_complaints_small.json")
23+
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT = os.path.join(
24+
DATA_DIR, "twitter_complaints_input_output.json"
25+
)
2326
TWITTER_COMPLAINTS_JSON_FORMAT = os.path.join(DATA_DIR, "twitter_complaints_json.json")
2427
EMPTY_DATA = os.path.join(DATA_DIR, "empty_data.json")
2528
MALFORMATTED_DATA = os.path.join(DATA_DIR, "malformatted_data.json")
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
{"ID": 0, "Label": 2, "input": "@HMRCcustomers No this is my first job", "output": "no complaint"}
2+
{"ID": 1, "Label": 2, "input": "@KristaMariePark Thank you for your interest! If you decide to cancel, you can call Customer Care at 1-800-NYTIMES.", "output": "no complaint"}
3+
{"ID": 2, "Label": 1, "input": "If I can't get my 3rd pair of @beatsbydre powerbeats to work today I'm doneski man. This is a slap in my balls. Your next @Bose @BoseService", "output": "complaint"}
4+
{"ID": 3, "Label": 1, "input": "@EE On Rosneath Arial having good upload and download speeds but terrible latency 200ms. Why is this.", "output": "complaint"}
5+
{"ID": 4, "Label": 2, "input": "Couples wallpaper, so cute. :) #BrothersAtHome", "output": "no complaint"}
6+
{"ID": 5, "Label": 2, "input": "@mckelldogs This might just be me, but-- eyedrops? Artificial tears are so useful when you're sleep-deprived and sp\u2026 https://t.co/WRtNsokblG", "output": "no complaint"}
7+
{"ID": 6, "Label": 2, "input": "@Yelp can we get the exact calculations for a business rating (for example if its 4 stars but actually 4.2) or do we use a 3rd party site?", "output": "no complaint"}
8+
{"ID": 7, "Label": 1, "input": "@nationalgridus I have no water and the bill is current and paid. Can you do something about this?", "output": "complaint"}
9+
{"ID": 8, "Label": 1, "input": "Never shopping at @MACcosmetics again. Every time I go in there, their employees are super rude/condescending. I'll take my $$ to @Sephora", "output": "complaint"}
10+
{"ID": 9, "Label": 2, "input": "@JenniferTilly Merry Christmas to as well. You get more stunning every year \ufffd\ufffd", "output": "no complaint"}
11+
{"ID": 10, "Label": 2, "input": "@NortonSupport Thanks much.", "output": "no complaint"}
12+
{"ID": 11, "Label": 2, "input": "@VerizonSupport all of a sudden I can't connect to my primary wireless network but guest one works", "output": "no complaint"}
13+
{"ID": 12, "Label": 2, "input": "Aaaahhhhh!!!! My @Razer @PlayOverwatch d.va meka headset came in!!! I didn't even know it had shipped!!! So excited\u2026 https://t.co/4gXy9xED8d", "output": "no complaint"}
14+
{"ID": 13, "Label": 2, "input": "@Lin_Manuel @jmessinaphoto @VAMNit Omg a little squish!!!!! Enjoy and congrats!!!! I miss mine being so young! \ufffd\ufffd\ufffd\ufffd\ufffd\ufffd", "output": "no complaint"}
15+
{"ID": 14, "Label": 2, "input": "@IanJamesPoulter What's your secret to poaching eggs? Mine NEVER look that good.", "output": "no complaint"}
16+
{"ID": 15, "Label": 2, "input": "@AWSSupport When will be able Kinesis Firehose compatible with Elasticsearch 6.0? Thank you!", "output": "no complaint"}
17+
{"ID": 16, "Label": 2, "input": "@NCIS_CBS https://t.co/eeVL9Eu3bE", "output": "no complaint"}
18+
{"ID": 17, "Label": 2, "input": "@msetchell Via the settings? That\u2019s how I do it on master T\u2019s", "output": "no complaint"}
19+
{"ID": 18, "Label": 2, "input": "Today at work there was a low flying duck heading toward a crowd of people, and I yelled \"watch out! and I'm very disappointed with myself.", "output": "no complaint"}
20+
{"ID": 19, "Label": 1, "input": "@NortonSupport @NortonOnline What the hell is a dm 5-10 days to get money back bank account now overdrawn thanks guys", "output": "complaint"}
21+
{"ID": 20, "Label": 1, "input": "@united not happy with this delay from Newark to Manchester tonight :( only 30 mins free Wi-fi sucks ...", "output": "complaint"}
22+
{"ID": 21, "Label": 1, "input": "@ZARA_Care I've been waiting on a reply to my tweets and DMs for days now?", "output": "complaint"}
23+
{"ID": 22, "Label": 2, "input": "New Listing! Large 2 Family Home for Sale in #Passaic Park, #NJ #realestate #homesforsale Great Location!\u2026 https://t.co/IV4OrLXkMk", "output": "no complaint"}
24+
{"ID": 23, "Label": 1, "input": "@SouthwestAir I love you but when sending me flight changes please don't use military time #ignoranceisbliss", "output": "complaint"}
25+
{"ID": 24, "Label": 2, "input": "@JetBlue Completely understand but would prefer being on time to filling out forms....", "output": "no complaint"}
26+
{"ID": 25, "Label": 2, "input": "@nvidiacc I own two gtx 460 in sli. I want to try windows 8 dev preview. Which driver should I use. Can I use the windows 7 one.", "output": "no complaint"}
27+
{"ID": 26, "Label": 2, "input": "Just posted a photo https://t.co/RShFwCjPHu", "output": "no complaint"}
28+
{"ID": 27, "Label": 2, "input": "Love crescent rolls? Try adding pesto @PerdueChicken to them and you\u2019re going to love it! #Promotion #PerdueCrew -\u2026 https://t.co/KBHOfqCukH", "output": "no complaint"}
29+
{"ID": 28, "Label": 1, "input": "@TopmanAskUs please just give me my money back.", "output": "complaint"}
30+
{"ID": 29, "Label": 2, "input": "I just gave 5 stars to Tracee at @neimanmarcus for the great service I received!", "output": "no complaint"}
31+
{"ID": 30, "Label": 2, "input": "@FitbitSupport when are you launching new clock faces for Indian market", "output": "no complaint"}
32+
{"ID": 31, "Label": 1, "input": "@HPSupport my printer will not allow me to choose color instead it only prints monochrome #hppsdr #ijkhelp", "output": "complaint"}
33+
{"ID": 32, "Label": 1, "input": "@DIRECTV can I get a monthly charge double refund when it sprinkles outside and we lose reception? #IamEmbarrasedForYou", "output": "complaint"}
34+
{"ID": 33, "Label": 1, "input": "@AlfaRomeoCares Hi thanks for replying, could be my internet but link doesn't seem to be working", "output": "complaint"}
35+
{"ID": 34, "Label": 2, "input": "Looks tasty! Going to share with everyone I know #FebrezeONE #sponsored https://t.co/4AQI53npei", "output": "no complaint"}
36+
{"ID": 35, "Label": 2, "input": "@OnePlus_IN can OnePlus 5T do front camera portrait?", "output": "no complaint"}
37+
{"ID": 36, "Label": 1, "input": "@sho_help @showtime your arrive is terrible streaming is stop and start every couple mins. Get it together it's xmas", "output": "complaint"}
38+
{"ID": 37, "Label": 2, "input": "@KandraKPTV I just witnessed a huge building fire in Santa Monica California", "output": "no complaint"}
39+
{"ID": 38, "Label": 2, "input": "@fernrocks most definitely the latter for me", "output": "no complaint"}
40+
{"ID": 39, "Label": 1, "input": "@greateranglia Could I ask why the Area in front of BIC Station was not gritted withh all the snow.", "output": "complaint"}
41+
{"ID": 40, "Label": 2, "input": "I'm earning points with #CricketRewards https://t.co/GfpGhqqnhE", "output": "no complaint"}
42+
{"ID": 41, "Label": 2, "input": "@Schrapnel @comcast RIP me", "output": "no complaint"}
43+
{"ID": 42, "Label": 2, "input": "The wait is finally over, just joined @SquareUK, hope to get started real soon!", "output": "no complaint"}
44+
{"ID": 43, "Label": 2, "input": "@WholeFoods what's the best way to give feedback on a particular store to the regional/national office?", "output": "no complaint"}
45+
{"ID": 44, "Label": 2, "input": "@DanielNewman I honestly would believe anything. People are...too much sometimes.", "output": "no complaint"}
46+
{"ID": 45, "Label": 2, "input": "@asblough Yep! It should send you a notification with your driver\u2019s name and what time they\u2019ll be showing up!", "output": "no complaint"}
47+
{"ID": 46, "Label": 2, "input": "@Wavy2Timez for real", "output": "no complaint"}
48+
{"ID": 47, "Label": 1, "input": "@KenyaPower_Care no power in south b area... is it scheduled.", "output": "complaint"}
49+
{"ID": 48, "Label": 1, "input": "Honda won't do anything about water leaking in brand new car. Frustrated! @HondaCustSvc @AmericanHonda", "output": "complaint"}
50+
{"ID": 49, "Label": 1, "input": "@CBSNews @Dodge @ChryslerCares My driver side air bag has been recalled and replaced, but what about the passenger side?", "output": "complaint"}
Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,193 @@
1+
# Third Party
2+
from datasets import Dataset
3+
from datasets.exceptions import DatasetGenerationError
4+
from transformers import AutoTokenizer, DataCollatorForSeq2Seq
5+
from trl import DataCollatorForCompletionOnlyLM
6+
import pytest
7+
8+
# First Party
9+
from tests.data import (
10+
MALFORMATTED_DATA,
11+
TWITTER_COMPLAINTS_DATA,
12+
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT,
13+
)
14+
15+
# Local
16+
from tuning.utils.preprocessing_utils import (
17+
combine_sequence,
18+
get_data_trainer_kwargs,
19+
get_preprocessed_dataset,
20+
load_hf_dataset_from_jsonl_file,
21+
validate_data_args,
22+
)
23+
24+
25+
@pytest.mark.parametrize(
    "input_element,output_element,expected_res",
    [
        ("foo ", "bar", "foo bar"),
        ("foo\n", "bar", "foo\nbar"),
        ("foo\t", "bar", "foo\tbar"),
        ("foo", "bar", "foo bar"),
    ],
)
def test_combine_sequence(input_element, output_element, expected_res):
    """Input / output elements are joined with exactly one separating whitespace.

    A trailing space, newline, or tab on the input is kept as the separator;
    otherwise a single space is inserted between input and output.
    """
    combined = combine_sequence(input_element, output_element)
    # The combiner must always hand back a plain string...
    assert isinstance(combined, str)
    # ...whose content matches the expected whitespace handling exactly.
    assert combined == expected_res
39+
40+
41+
# Tests for loading the dataset from disk
def test_load_hf_dataset_from_jsonl_file():
    """Records loaded from a JSONL file expose the requested input/output fields."""
    in_field, out_field = "Tweet text", "text_label"
    dataset = load_hf_dataset_from_jsonl_file(
        TWITTER_COMPLAINTS_DATA,
        input_field_name=in_field,
        output_field_name=out_field,
    )
    # Pull the first record; every record is a dict keyed by the field names.
    first_record = next(iter(dataset))
    assert in_field in first_record
    assert out_field in first_record
54+
55+
56+
def test_load_hf_dataset_from_jsonl_file_wrong_keys():
    """Loading raises DatasetGenerationError when the requested keys are absent."""
    # "foo" / "bar" do not exist in the twitter complaints JSONL file.
    with pytest.raises(DatasetGenerationError):
        load_hf_dataset_from_jsonl_file(
            TWITTER_COMPLAINTS_DATA,
            input_field_name="foo",
            output_field_name="bar",
        )
62+
63+
64+
def test_load_hf_dataset_from_malformatted_data():
    """Loading raises DatasetGenerationError on badly formatted data.

    NOTE: the field names are irrelevant here — the file itself is malformed,
    so generation fails before any key lookup happens.
    """
    with pytest.raises(DatasetGenerationError):
        load_hf_dataset_from_jsonl_file(
            MALFORMATTED_DATA,
            input_field_name="foo",
            output_field_name="bar",
        )
71+
72+
73+
def test_load_hf_dataset_from_jsonl_file_duplicate_keys():
    """Using the same field for both input and output is rejected with ValueError."""
    duplicated_field = "Tweet text"
    with pytest.raises(ValueError):
        load_hf_dataset_from_jsonl_file(
            TWITTER_COMPLAINTS_DATA,
            input_field_name=duplicated_field,
            output_field_name=duplicated_field,
        )
81+
82+
83+
# Tests for custom masking / preprocessing logic
@pytest.mark.parametrize("max_sequence_length", [1, 10, 100, 1000])
def test_get_preprocessed_dataset(max_sequence_length):
    """Tokenized records are unpadded, input-masked, aligned, and truncated.

    For every record produced by preprocessing:
    - no padding has been applied yet (padding is deferred to the collator),
    - the source portion of the labels is masked out,
    - every tensor key has the same length, bounded by max_sequence_length.
    """
    tokenizer = AutoTokenizer.from_pretrained("Maykeye/TinyLLama-v0")
    preprocessed_data = get_preprocessed_dataset(
        data_path=TWITTER_COMPLAINTS_DATA,
        tokenizer=tokenizer,
        max_sequence_length=max_sequence_length,
        input_field_name="Tweet text",
        output_field_name="text_label",
    )
    for record in preprocessed_data:
        attention_mask = record["attention_mask"]
        # No 0s in the attention mask yet — padding is left to the collator.
        assert sum(attention_mask) == len(attention_mask)
        # Non-empty source text means the first label token is masked (-100).
        assert record["labels"][0] == -100
        # Every key in the record must map to a sequence of the same length...
        unique_lengths = {len(value) for value in record.values()}
        assert len(unique_lengths) == 1
        # ...and that shared length respects the truncation bound; padding up
        # to the max is handled separately by the collator.
        assert unique_lengths.pop() <= max_sequence_length
105+
106+
107+
# Tests for fetching train args
@pytest.mark.parametrize(
    "use_validation_data, collator_type, packing",
    [
        (True, None, True),
        (False, None, True),
        (True, DataCollatorForCompletionOnlyLM, False),
        (False, DataCollatorForCompletionOnlyLM, False),
    ],
)
def test_get_trainer_kwargs_with_response_template_and_text_field(
    use_validation_data, collator_type, packing
):
    """Trainer kwargs built from a response template / text field are well-formed.

    Checks the collator type (none when packing), the optional eval dataset,
    and that both datasets keep the raw twitter columns.
    """
    training_data_path = TWITTER_COMPLAINTS_DATA
    validation_data_path = training_data_path if use_validation_data else None
    # Columns we expect in the raw loaded twitter complaints dataset.
    expected_columns = {"Tweet text", "ID", "Label", "text_label", "output"}
    trainer_kwargs = get_data_trainer_kwargs(
        training_data_path=training_data_path,
        validation_data_path=validation_data_path,
        packing=packing,
        response_template="\n### Label:",
        max_sequence_length=100,
        tokenizer=AutoTokenizer.from_pretrained("Maykeye/TinyLLama-v0"),
        dataset_text_field="output",
    )
    assert len(trainer_kwargs) == 3

    # Packing implies no data collator; otherwise we get the parametrized type.
    collator = trainer_kwargs["data_collator"]
    if collator_type is None:
        assert collator is None
    else:
        assert isinstance(collator, collator_type)

    # An eval dataset only exists when validation data was supplied.
    eval_dataset = trainer_kwargs["eval_dataset"]
    if validation_data_path is None:
        assert eval_dataset is None
    else:
        assert isinstance(eval_dataset, Dataset)
        assert set(eval_dataset.column_names) == expected_columns

    train_dataset = trainer_kwargs["train_dataset"]
    assert isinstance(train_dataset, Dataset)
    assert set(train_dataset.column_names) == expected_columns
149+
150+
151+
@pytest.mark.parametrize("use_validation_data", [True, False])
def test_get_trainer_kwargs_with_custom_masking(use_validation_data):
    """Trainer kwargs for pretokenized (custom-masked) data are well-formed.

    With no response template / text field, the data is pretokenized: we expect
    a seq2seq collator, tokenized columns, and a formatting_func placeholder.
    """
    training_data_path = TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT
    validation_data_path = training_data_path if use_validation_data else None
    # Columns we expect after tokenization of the input/output twitter data.
    expected_columns = {"input_ids", "attention_mask", "labels"}
    trainer_kwargs = get_data_trainer_kwargs(
        training_data_path=training_data_path,
        validation_data_path=validation_data_path,
        packing=False,
        response_template=None,
        max_sequence_length=100,
        tokenizer=AutoTokenizer.from_pretrained("Maykeye/TinyLLama-v0"),
        dataset_text_field=None,
    )
    assert len(trainer_kwargs) == 4
    # Pretokenized data without packing uses the seq2seq collator for padding.
    assert isinstance(trainer_kwargs["data_collator"], DataCollatorForSeq2Seq)

    # An eval dataset only exists when validation data was supplied.
    eval_dataset = trainer_kwargs["eval_dataset"]
    if validation_data_path is None:
        assert eval_dataset is None
    else:
        assert isinstance(eval_dataset, Dataset)
        assert set(eval_dataset.column_names) == expected_columns

    train_dataset = trainer_kwargs["train_dataset"]
    assert isinstance(train_dataset, Dataset)
    assert set(train_dataset.column_names) == expected_columns
    # A formatting_func is supplied to sidestep TRL's validation.
    assert trainer_kwargs["formatting_func"] is not None
181+
182+
183+
# Tests for data arg validation
@pytest.mark.parametrize(
    "dataset_text_field, response_template",
    [
        ("input", None),
        (None, "output"),
    ],
)
def test_validate_args(dataset_text_field, response_template):
    """Supplying only one of dataset_text_field / response_template is rejected."""
    with pytest.raises(ValueError):
        validate_data_args(dataset_text_field, response_template)

0 commit comments

Comments
 (0)