
Commit 43df3a4

qgallouedec and kashif authored
🧳 Move zen generation script and fix tests (huggingface#2393)
* Move zen
* step -> stepwise_supervision
* Fix train_test_split shuffle issue
* Fix tests
* Update tests/test_sft_trainer.py
* Fix typo in key name

Co-authored-by: Kashif Rasul <[email protected]>
1 parent baee06f · commit 43df3a4

File tree

4 files changed: +39 -35 lines changed


examples/datasets/zen.py renamed to scripts/generate_zen_dataset.py

Lines changed: 18 additions & 18 deletions
@@ -28,13 +28,13 @@ class ScriptArguments:
             Fraction of the dataset to include in the test split.
         push_to_hub (`bool`, *optional*, defaults to `False`):
             Whether to push the dataset to the Hugging Face Hub.
-        repo_id (`str`, *optional*, defaults to `"trl-lib/zen"`):
+        repo_id (`str`, *optional*, defaults to `"trl-internal-testing/zen"`):
             Hugging Face repository ID to push the dataset to.
     """

     test_size: float = 0.1
     push_to_hub: bool = False
-    repo_id: str = "trl-lib/zen"
+    repo_id: str = "trl-internal-testing/zen"


 def main(test_size, push_to_hub, repo_id):
@@ -62,7 +62,7 @@ def main(test_size, push_to_hub, repo_id):
             "Namespaces are one honking great idea -- let's do more of those!",
         ],
     })
-    standard_language_modeling_dataset = standard_language_modeling_dataset.train_test_split(test_size=test_size)
+    standard_language_modeling_dataset = standard_language_modeling_dataset.train_test_split(test_size=test_size, shuffle=False)
     if push_to_hub:
         standard_language_modeling_dataset.push_to_hub(repo_id, config_name="standard_language_modeling")

@@ -89,7 +89,7 @@ def main(test_size, push_to_hub, repo_id):
             "Namespaces are one honking great",
         ],
     })
-    standard_prompt_only_dataset = standard_prompt_only_dataset.train_test_split(test_size=test_size)
+    standard_prompt_only_dataset = standard_prompt_only_dataset.train_test_split(test_size=test_size, shuffle=False)
     if push_to_hub:
         standard_prompt_only_dataset.push_to_hub(repo_id, config_name="standard_prompt_only")

@@ -137,7 +137,7 @@ def main(test_size, push_to_hub, repo_id):
             " idea -- let's do more of those!",
         ],
     })
-    standard_prompt_completion_dataset = standard_prompt_completion_dataset.train_test_split(test_size=test_size)
+    standard_prompt_completion_dataset = standard_prompt_completion_dataset.train_test_split(test_size=test_size, shuffle=False)
     if push_to_hub:
         standard_prompt_completion_dataset.push_to_hub(repo_id, config_name="standard_prompt_completion")

@@ -206,7 +206,7 @@ def main(test_size, push_to_hub, repo_id):
             " watermelon -- let's plant some!",
         ],
     })
-    standard_preference_dataset = standard_preference_dataset.train_test_split(test_size=test_size)
+    standard_preference_dataset = standard_preference_dataset.train_test_split(test_size=test_size, shuffle=False)
     if push_to_hub:
         standard_preference_dataset.push_to_hub(repo_id, config_name="standard_preference")

@@ -254,7 +254,7 @@ def main(test_size, push_to_hub, repo_id):
             "Namespaces are one honking great watermelon -- let's plant some!",
         ],
     })
-    standard_implicit_prompt_preference_dataset = standard_implicit_prompt_preference_dataset.train_test_split(test_size=test_size)
+    standard_implicit_prompt_preference_dataset = standard_implicit_prompt_preference_dataset.train_test_split(test_size=test_size, shuffle=False)
     if push_to_hub:
         standard_implicit_prompt_preference_dataset.push_to_hub(repo_id, config_name="standard_implicit_prompt_preference")

@@ -303,11 +303,11 @@ def main(test_size, push_to_hub, repo_id):
         ],
         "label": [True, False, False, True, True, False, True, False, True, True, False, True, True, False, True, False, True, False, False],
     })
-    standard_unpaired_preference_dataset = standard_unpaired_preference_dataset.train_test_split(test_size=test_size)
+    standard_unpaired_preference_dataset = standard_unpaired_preference_dataset.train_test_split(test_size=test_size, shuffle=False)
     if push_to_hub:
         standard_unpaired_preference_dataset.push_to_hub(repo_id, config_name="standard_unpaired_preference")

-    standard_step_dataset = Dataset.from_dict({
+    standard_stepwise_supervision_dataset = Dataset.from_dict({
         "prompt": [
             "Beautiful is better than",
             "Explicit is better than",
@@ -350,7 +350,7 @@ def main(test_size, push_to_hub, repo_id):
             [" of those great ideas,", " that solve many problems."],
             [" the code should still aim for balance."],
         ],
-        "label": [
+        "labels": [
             [False, True],
             [False, True, False],
             [False, True],
@@ -371,9 +371,9 @@ def main(test_size, push_to_hub, repo_id):
             [False]
         ]
     })
-    standard_step_dataset = standard_step_dataset.train_test_split(test_size=test_size)
+    standard_stepwise_supervision_dataset = standard_stepwise_supervision_dataset.train_test_split(test_size=test_size, shuffle=False)
     if push_to_hub:
-        standard_step_dataset.push_to_hub(repo_id, config_name="standard_step")
+        standard_stepwise_supervision_dataset.push_to_hub(repo_id, config_name="standard_stepwise_supervision")

     conversational_language_modeling_dataset = Dataset.from_dict({
         "messages": [
@@ -398,7 +398,7 @@ def main(test_size, push_to_hub, repo_id):
             [{"role": "user", "content": "Any great ideas?"}, {"role": "assistant", "content": "Namespaces are one honking great idea."}],
         ],
     })
-    conversational_language_modeling_dataset = conversational_language_modeling_dataset.train_test_split(test_size=test_size)
+    conversational_language_modeling_dataset = conversational_language_modeling_dataset.train_test_split(test_size=test_size, shuffle=False)
     if push_to_hub:
         conversational_language_modeling_dataset.push_to_hub(repo_id, config_name="conversational_language_modeling")

@@ -425,7 +425,7 @@ def main(test_size, push_to_hub, repo_id):
             [{"role": "user", "content": "Any great ideas?"}],
         ],
     })
-    conversational_prompt_only_dataset = conversational_prompt_only_dataset.train_test_split(test_size=test_size)
+    conversational_prompt_only_dataset = conversational_prompt_only_dataset.train_test_split(test_size=test_size, shuffle=False)
     if push_to_hub:
         conversational_prompt_only_dataset.push_to_hub(repo_id, config_name="conversational_prompt_only")

@@ -473,7 +473,7 @@ def main(test_size, push_to_hub, repo_id):
             [{"role": "assistant", "content": "Namespaces are one honking great idea."}],
         ],
     })
-    conversational_prompt_completion_dataset = conversational_prompt_completion_dataset.train_test_split(test_size=test_size)
+    conversational_prompt_completion_dataset = conversational_prompt_completion_dataset.train_test_split(test_size=test_size, shuffle=False)
     if push_to_hub:
         conversational_prompt_completion_dataset.push_to_hub(repo_id, config_name="conversational_prompt_completion")

@@ -542,7 +542,7 @@ def main(test_size, push_to_hub, repo_id):
             [{"role": "assistant", "content": "Recursion."}],
         ],
     })
-    conversational_preference_dataset = conversational_preference_dataset.train_test_split(test_size=test_size)
+    conversational_preference_dataset = conversational_preference_dataset.train_test_split(test_size=test_size, shuffle=False)
     if push_to_hub:
         conversational_preference_dataset.push_to_hub(repo_id, config_name="conversational_preference")

@@ -590,7 +590,7 @@ def main(test_size, push_to_hub, repo_id):
             [{"role": "user", "content": "Any great ideas?"}, {"role": "assistant", "content": "Recursion."}],
         ],
     })
-    conversational_implicit_prompt_preference_dataset = conversational_implicit_prompt_preference_dataset.train_test_split(test_size=test_size)
+    conversational_implicit_prompt_preference_dataset = conversational_implicit_prompt_preference_dataset.train_test_split(test_size=test_size, shuffle=False)
     if push_to_hub:
         conversational_implicit_prompt_preference_dataset.push_to_hub(repo_id, config_name="conversational_implicit_prompt_preference")

@@ -639,7 +639,7 @@ def main(test_size, push_to_hub, repo_id):
         ],
         "label": [True, True, True, False, True, True, True, False, True, False, True, False, True, False, False, True, True, True, True],
     })
-    conversational_unpaired_preference_dataset = conversational_unpaired_preference_dataset.train_test_split(test_size=test_size)
+    conversational_unpaired_preference_dataset = conversational_unpaired_preference_dataset.train_test_split(test_size=test_size, shuffle=False)
     if push_to_hub:
         conversational_unpaired_preference_dataset.push_to_hub(repo_id, config_name="conversational_unpaired_preference")
     # fmt: on
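Note: the repeated change in this file is the added `shuffle=False` argument to `datasets.Dataset.train_test_split`, which makes every zen split deterministic so the trainer tests can pin exact rows and sequence counts. A minimal sketch of that behavior on toy data (illustrative only, not part of the commit):

from datasets import Dataset

# Toy data, illustrative only: with shuffle=False the rows keep their original
# order, so the train split is the leading rows and the test split is the tail.
toy = Dataset.from_dict({"text": [f"aphorism {i}" for i in range(20)]})
split = toy.train_test_split(test_size=0.1, shuffle=False)
print(split["train"]["text"][:2])  # ['aphorism 0', 'aphorism 1']
print(split["test"]["text"])       # ['aphorism 18', 'aphorism 19']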

tests/test_bco_trainer.py

Lines changed: 10 additions & 8 deletions
@@ -160,10 +160,10 @@ def test_tokenize_and_process_tokens(self):
             self.assertListEqual(tokenized_dataset["prompt"], train_dataset["prompt"])
             self.assertListEqual(tokenized_dataset["completion"], train_dataset["completion"])
             self.assertListEqual(tokenized_dataset["label"], train_dataset["label"])
-            self.assertListEqual(tokenized_dataset["prompt_input_ids"][0], [31137])
-            self.assertListEqual(tokenized_dataset["prompt_attention_mask"][0], [1])
-            self.assertListEqual(tokenized_dataset["answer_input_ids"][0], [374, 2664, 1091, 16965, 13])
-            self.assertListEqual(tokenized_dataset["answer_attention_mask"][0], [1, 1, 1, 1, 1])
+            self.assertListEqual(tokenized_dataset["prompt_input_ids"][0], [46518, 374, 2664, 1091])
+            self.assertListEqual(tokenized_dataset["prompt_attention_mask"][0], [1, 1, 1, 1])
+            self.assertListEqual(tokenized_dataset["answer_input_ids"][0], [27261, 13])
+            self.assertListEqual(tokenized_dataset["answer_attention_mask"][0], [1, 1])

             fn_kwargs = {
                 "prefix": "",
@@ -178,13 +178,15 @@ def test_tokenize_and_process_tokens(self):
             self.assertListEqual(processed_dataset["prompt"], train_dataset["prompt"])
             self.assertListEqual(processed_dataset["completion"], train_dataset["completion"])
             self.assertListEqual(processed_dataset["label"], train_dataset["label"])
-            self.assertListEqual(processed_dataset["prompt_input_ids"][0], [31137])
-            self.assertListEqual(processed_dataset["prompt_attention_mask"][0], [1])
+            self.assertListEqual(processed_dataset["prompt_input_ids"][0], [46518, 374, 2664, 1091])
+            self.assertListEqual(processed_dataset["prompt_attention_mask"][0], [1, 1, 1, 1])
             self.assertListEqual(
-                processed_dataset["completion_input_ids"][0], [31137, 374, 2664, 1091, 16965, 13, 151645]
+                processed_dataset["completion_input_ids"][0], [46518, 374, 2664, 1091, 27261, 13, 151645]
             )
             self.assertListEqual(processed_dataset["completion_attention_mask"][0], [1, 1, 1, 1, 1, 1, 1])
-            self.assertListEqual(processed_dataset["completion_labels"][0], [-100, 374, 2664, 1091, 16965, 13, 151645])
+            self.assertListEqual(
+                processed_dataset["completion_labels"][0], [-100, -100, -100, -100, 27261, 13, 151645]
+            )

     @require_sklearn
     def test_bco_trainer_without_providing_ref_model(self):
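The updated expectations follow the usual labels convention: in `completion_labels`, the prompt tokens are masked with -100 (the ignore index) so only the answer and end-of-turn tokens contribute to the loss. A small self-contained check built from the ids asserted in this diff; the masking rule is the only convention being illustrated:

# Built from the token ids asserted above; not part of the test suite itself.
prompt_input_ids = [46518, 374, 2664, 1091]  # asserted prompt ids
answer_input_ids = [27261, 13]               # asserted answer ids
eot_token_id = 151645                        # end-of-turn id used in the fixtures

completion_input_ids = prompt_input_ids + answer_input_ids + [eot_token_id]
completion_labels = [-100] * len(prompt_input_ids) + answer_input_ids + [eot_token_id]

assert completion_input_ids == [46518, 374, 2664, 1091, 27261, 13, 151645]
assert completion_labels == [-100, -100, -100, -100, 27261, 13, 151645]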

tests/test_kto_trainer.py

Lines changed: 10 additions & 8 deletions
@@ -156,10 +156,10 @@ def test_tokenize_and_process_tokens(self):
             self.assertListEqual(tokenized_dataset["prompt"], train_dataset["prompt"])
             self.assertListEqual(tokenized_dataset["completion"], train_dataset["completion"])
             self.assertListEqual(tokenized_dataset["label"], train_dataset["label"])
-            self.assertListEqual(tokenized_dataset["prompt_input_ids"][0], [31137])
-            self.assertListEqual(tokenized_dataset["prompt_attention_mask"][0], [1])
-            self.assertListEqual(tokenized_dataset["answer_input_ids"][0], [374, 2664, 1091, 16965, 13])
-            self.assertListEqual(tokenized_dataset["answer_attention_mask"][0], [1, 1, 1, 1, 1])
+            self.assertListEqual(tokenized_dataset["prompt_input_ids"][0], [46518, 374, 2664, 1091])
+            self.assertListEqual(tokenized_dataset["prompt_attention_mask"][0], [1, 1, 1, 1])
+            self.assertListEqual(tokenized_dataset["answer_input_ids"][0], [27261, 13])
+            self.assertListEqual(tokenized_dataset["answer_attention_mask"][0], [1, 1])

             # Test corruption of (prompt, completion) pairs for KL dataset
             for batch_size in [2, 3]:
@@ -196,13 +196,15 @@ def test_tokenize_and_process_tokens(self):
             self.assertListEqual(processed_dataset["prompt"], train_dataset["prompt"])
             self.assertListEqual(processed_dataset["completion"], train_dataset["completion"])
             self.assertListEqual(processed_dataset["label"], train_dataset["label"])
-            self.assertListEqual(processed_dataset["prompt_input_ids"][0], [31137])
-            self.assertListEqual(processed_dataset["prompt_attention_mask"][0], [1])
+            self.assertListEqual(processed_dataset["prompt_input_ids"][0], [46518, 374, 2664, 1091])
+            self.assertListEqual(processed_dataset["prompt_attention_mask"][0], [1, 1, 1, 1])
             self.assertListEqual(
-                processed_dataset["completion_input_ids"][0], [31137, 374, 2664, 1091, 16965, 13, 151645]
+                processed_dataset["completion_input_ids"][0], [46518, 374, 2664, 1091, 27261, 13, 151645]
             )
             self.assertListEqual(processed_dataset["completion_attention_mask"][0], [1, 1, 1, 1, 1, 1, 1])
-            self.assertListEqual(processed_dataset["completion_labels"][0], [-100, 374, 2664, 1091, 16965, 13, 151645])
+            self.assertListEqual(
+                processed_dataset["completion_labels"][0], [-100, -100, -100, -100, 27261, 13, 151645]
+            )

     def test_kto_trainer_without_providing_ref_model(self):
         with tempfile.TemporaryDirectory() as tmp_dir:

tests/test_sft_trainer.py

Lines changed: 1 addition & 1 deletion
@@ -1172,7 +1172,7 @@ def test_sft_trainer_eval_packing(self):
             )

             self.assertEqual(len(trainer.train_dataset["input_ids"]), 46)  # w/ this dataset, we end up with 46 seqs
-            self.assertEqual(len(trainer.eval_dataset["input_ids"]), 5)  # w/ this dataset, we end up with 5 seqs
+            self.assertEqual(len(trainer.eval_dataset["input_ids"]), 6)  # w/ this dataset, we end up with 6 seqs

     def test_sft_trainer_no_packing(self):
         with tempfile.TemporaryDirectory() as tmp_dir:
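Taken together, the script changes mean the generated datasets default to the `trl-internal-testing/zen` repo id and the stepwise config is published as `standard_stepwise_supervision` with a `labels` column. A hedged loading sketch; it assumes the datasets have actually been pushed with `push_to_hub=True`, which the diff itself does not guarantee:

from datasets import load_dataset

# Repo id and config name come from the diff above; availability on the Hub is
# an assumption.
zen = load_dataset("trl-internal-testing/zen", "standard_stepwise_supervision")
print(zen)                        # DatasetDict with deterministic train/test splits
print(zen["train"]["labels"][0])  # per-step booleans, e.g. [False, True]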
