Commit c513487

support val_set_size for splitting test split from train with DPO (axolotl-ai-cloud#2572)
1 parent dda95e6 commit c513487

2 files changed (+170, −4 lines)

src/axolotl/core/trainers/dpo/trainer.py

Lines changed: 139 additions & 3 deletions
```diff
@@ -3,15 +3,29 @@
 """
 
 import gc
+import random
 from functools import wraps
-from typing import Any, Dict, Union
+from typing import Any, Dict, Optional, Union
 
+import pandas as pd
 import torch
+import wandb
+from accelerate import PartialState
+from datasets import Dataset, IterableDataset
 from peft.optimizers import create_loraplus_optimizer
 from torch import nn
-from transformers import Trainer
+from torch.utils.data import DataLoader
+from transformers import (
+    BaseImageProcessor,
+    FeatureExtractionMixin,
+    PreTrainedTokenizerBase,
+    ProcessorMixin,
+    Trainer,
+)
+from transformers.trainer_utils import EvalLoopOutput
 from transformers.utils import is_sagemaker_mp_enabled
-from trl import DPOTrainer
+from trl import DPOConfig, DPOTrainer, maybe_apply_chat_template, maybe_extract_prompt
+from trl.trainer.utils import log_table_to_comet_experiment
 
 from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
 from axolotl.core.trainers.utils import (
@@ -81,6 +95,64 @@ def push_to_hub(self, *args, **kwargs) -> str:
 
         return super().push_to_hub(*args, **kwargs)
 
+    # TODO: remove this once https://github.com/huggingface/trl/pull/3377 is in a release
+    def _prepare_dataset(
+        self,
+        dataset: Union[Dataset, IterableDataset],
+        processing_class: Union[
+            PreTrainedTokenizerBase,
+            BaseImageProcessor,
+            FeatureExtractionMixin,
+            ProcessorMixin,
+        ],
+        args: DPOConfig,
+        dataset_name: str,
+    ) -> Union[Dataset, IterableDataset]:
+        # Build the kwargs for the `map` function
+        map_kwargs: Dict[str, Any] = {"writer_batch_size": 10}
+        if isinstance(dataset, Dataset):  # IterableDataset does not support num_proc
+            map_kwargs["num_proc"] = args.dataset_num_proc
+
+        with PartialState().main_process_first():
+            # Extract prompt if needed
+            if isinstance(
+                dataset, Dataset
+            ):  # `IterableDataset.map` does not support `desc`
+                map_kwargs["desc"] = f"Extracting prompt in {dataset_name} dataset"
+            dataset = dataset.map(maybe_extract_prompt, **map_kwargs)
+
+            # Apply the chat template if needed
+            if isinstance(
+                dataset, Dataset
+            ):  # `IterableDataset.map` does not support `desc`
+                map_kwargs["desc"] = f"Applying chat template to {dataset_name} dataset"
+            dataset = dataset.map(
+                maybe_apply_chat_template,
+                fn_kwargs={"tokenizer": processing_class, "tools": args.tools},
+                **map_kwargs,
+            )
+
+            # Tokenize the dataset
+            if isinstance(
+                dataset, Dataset
+            ):  # `IterableDataset.map` does not support `desc`
+                map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset"
+
+            dataset = dataset.map(
+                self.tokenize_row if not self.is_vision_model else self.process_row,
+                remove_columns=["chosen", "rejected"],
+                fn_kwargs={
+                    "processing_class": processing_class,
+                    "max_prompt_length": args.max_prompt_length,
+                    "max_completion_length": args.max_completion_length,
+                    # for enc-dec, we add the special tokens ([bos_token] + prompt + [eos_token]; completion + [eos_token])
+                    "add_special_tokens": False,
+                },
+                **map_kwargs,
+            )
+
+        return dataset
+
     @staticmethod
     def tokenize_row(
         features,
```
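
For context, here is a minimal standalone sketch (not part of the commit) of what the first two `map` passes in `_prepare_dataset` do to a single preference example; the tokenizer checkpoint is an arbitrary assumption:

```python
from transformers import AutoTokenizer
from trl import maybe_apply_chat_template, maybe_extract_prompt

# Any tokenizer with a chat template works; this checkpoint is just an example.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

example = {
    "chosen": [
        {"role": "user", "content": "What color is the sky?"},
        {"role": "assistant", "content": "Blue."},
    ],
    "rejected": [
        {"role": "user", "content": "What color is the sky?"},
        {"role": "assistant", "content": "Green."},
    ],
}

# Pass 1: factor the shared conversation prefix out into a "prompt" column.
example = maybe_extract_prompt(example)
# Pass 2: render prompt/chosen/rejected to strings via the chat template.
example = maybe_apply_chat_template(example, tokenizer=tokenizer)
print(example["prompt"])    # templated user turn
print(example["chosen"])    # "Blue." continuation
print(example["rejected"])  # "Green." continuation
```

After these two passes, `tokenize_row` (or `process_row` for vision models) only has to deal with plain strings.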
```diff
@@ -124,3 +196,67 @@ def training_step(
             gc.collect()
             torch.cuda.empty_cache()
         return loss
+
+    # TODO: remove this once https://github.com/huggingface/trl/pull/3377 is in a release
+    def evaluation_loop(
+        self,
+        dataloader: DataLoader,
+        description: str,
+        prediction_loss_only: Optional[bool] = None,
+        ignore_keys: Optional[list[str]] = None,
+        metric_key_prefix: str = "eval",
+    ) -> EvalLoopOutput:
+        """
+        Overriding built-in evaluation loop to store metrics for each batch.
+        Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
+
+        Works both with or without labels.
+        """
+
+        # Sample and save to game log if requested (for one batch to save time)
+        if self.generate_during_eval:
+            # Generate random indices within the range of the total number of samples
+            num_samples = len(dataloader.dataset)
+            random_indices = random.sample(
+                range(num_samples), k=self.args.eval_batch_size
+            )
+
+            # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
+            random_batch_dataset = dataloader.dataset.select(random_indices)
+            random_batch = self.data_collator(random_batch_dataset)
+            random_batch = self._prepare_inputs(random_batch)
+
+            policy_output_decoded, ref_output_decoded = (
+                self.generate_from_model_and_ref(self.model, random_batch)
+            )
+
+            table = pd.DataFrame(
+                columns=["Prompt", "Policy", "Ref Model"],
+                data=[
+                    [prompt, pol[len(prompt) :], ref[len(prompt) :]]
+                    for prompt, pol, ref in zip(
+                        random_batch_dataset["prompt"],
+                        policy_output_decoded,
+                        ref_output_decoded,
+                    )
+                ],
+            )
+            if "wandb" in self.args.report_to and self.accelerator.is_main_process:
+                wandb.log({"game_log": wandb.Table(data=table)})
+
+            if "comet_ml" in self.args.report_to:
+                log_table_to_comet_experiment(
+                    name="game_log.csv",
+                    table=table,
+                )
+
+        # Base evaluation
+        initial_output = super().evaluation_loop(
+            dataloader,
+            description,
+            prediction_loss_only,
+            ignore_keys,
+            metric_key_prefix,
+        )
+
+        return initial_output
```
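
The `generate_during_eval` branch above avoids walking the whole `DataLoader` to reach a random batch; here is a small self-contained sketch of the same `random.sample` + `Dataset.select` pattern (dataset contents and batch size are made up):

```python
import random

from datasets import Dataset

dataset = Dataset.from_dict({"prompt": [f"prompt {i}" for i in range(100)]})
eval_batch_size = 4

# Draw k distinct row indices, then select them directly from the backing
# dataset; no iteration over batches is needed.
random_indices = random.sample(range(len(dataset)), k=eval_batch_size)
random_batch = dataset.select(random_indices)
print(random_batch["prompt"])
```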

src/axolotl/utils/data/rl.py

Lines changed: 31 additions & 1 deletion
```diff
@@ -204,7 +204,37 @@ def load_split(dataset_cfgs, _cfg):
     else:
         eval_dataset = load_split(cfg.test_datasets, cfg)
     if not eval_dataset:
-        eval_dataset = None
+        if cfg.val_set_size:
+            # ensure we end up with the same fingerprint by doing rank0 first and being able to cache
+            to_hash_train = (
+                train_dataset._fingerprint  # pylint: disable=protected-access
+                + "|"
+                + str(cfg.val_set_size)
+                + "|"
+                + "train"
+                + "|"
+                + str(cfg.seed or 42)
+            )
+            to_hash_test = (
+                train_dataset._fingerprint  # pylint: disable=protected-access
+                + "|"
+                + str(cfg.val_set_size)
+                + "|"
+                + "test"
+                + "|"
+                + str(cfg.seed or 42)
+            )
+            train_fingerprint = md5(to_hash_train)
+            test_fingerprint = md5(to_hash_test)
+            ds_w_test_split = train_dataset.train_test_split(
+                test_size=cfg.val_set_size,
+                seed=cfg.seed,
+                shuffle=False,
+                train_new_fingerprint=train_fingerprint,
+                test_new_fingerprint=test_fingerprint,
+            )
+            eval_dataset = ds_w_test_split["test"]
+            train_dataset = ds_w_test_split["train"]
 
     if not train_is_preprocessed:
         _save_preprocessed_ds(cfg, cfg.datasets, train_dataset)
```
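
A minimal sketch of the fingerprinting trick, assuming axolotl's `md5` helper hashes a string to a hex digest (the `hashlib` stand-in below is an assumption): deriving both fingerprints from the parent dataset's fingerprint means every rank computes, or loads from cache, an identical split.

```python
import hashlib

from datasets import Dataset

def md5(to_hash: str) -> str:
    # Hypothetical stand-in for axolotl's md5() helper.
    return hashlib.md5(to_hash.encode("utf-8")).hexdigest()

train_dataset = Dataset.from_dict({"prompt": [f"p{i}" for i in range(10)]})
val_set_size, seed = 0.2, 42

parent = train_dataset._fingerprint  # pylint: disable=protected-access
train_fingerprint = md5(f"{parent}|{val_set_size}|train|{seed}")
test_fingerprint = md5(f"{parent}|{val_set_size}|test|{seed}")

ds_w_test_split = train_dataset.train_test_split(
    test_size=val_set_size,
    seed=seed,
    shuffle=False,
    train_new_fingerprint=train_fingerprint,
    test_new_fingerprint=test_fingerprint,
)
# The splits carry the supplied fingerprints, so re-runs can hit the cache.
print(ds_w_test_split["train"]._fingerprint == train_fingerprint)  # True
```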
