|
3 | 3 | """
|
4 | 4 |
|
5 | 5 | import gc
|
| 6 | +import random |
6 | 7 | from functools import wraps
|
7 |
| -from typing import Any, Dict, Union |
| 8 | +from typing import Any, Dict, Optional, Union |
8 | 9 |
|
| 10 | +import pandas as pd |
9 | 11 | import torch
|
| 12 | +import wandb |
| 13 | +from accelerate import PartialState |
| 14 | +from datasets import Dataset, IterableDataset |
10 | 15 | from peft.optimizers import create_loraplus_optimizer
|
11 | 16 | from torch import nn
|
12 |
| -from transformers import Trainer |
| 17 | +from torch.utils.data import DataLoader |
| 18 | +from transformers import ( |
| 19 | + BaseImageProcessor, |
| 20 | + FeatureExtractionMixin, |
| 21 | + PreTrainedTokenizerBase, |
| 22 | + ProcessorMixin, |
| 23 | + Trainer, |
| 24 | +) |
| 25 | +from transformers.trainer_utils import EvalLoopOutput |
13 | 26 | from transformers.utils import is_sagemaker_mp_enabled
|
14 |
| -from trl import DPOTrainer |
| 27 | +from trl import DPOConfig, DPOTrainer, maybe_apply_chat_template, maybe_extract_prompt |
| 28 | +from trl.trainer.utils import log_table_to_comet_experiment |
15 | 29 |
|
16 | 30 | from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
|
17 | 31 | from axolotl.core.trainers.utils import (
|
@@ -81,6 +95,64 @@ def push_to_hub(self, *args, **kwargs) -> str:
|
81 | 95 |
|
82 | 96 | return super().push_to_hub(*args, **kwargs)
|
83 | 97 |
|
| 98 | + # TODO: remove this once https://github.com/huggingface/trl/pull/3377 is in a release |
| 99 | + def _prepare_dataset( |
| 100 | + self, |
| 101 | + dataset: Union[Dataset, IterableDataset], |
| 102 | + processing_class: Union[ |
| 103 | + PreTrainedTokenizerBase, |
| 104 | + BaseImageProcessor, |
| 105 | + FeatureExtractionMixin, |
| 106 | + ProcessorMixin, |
| 107 | + ], |
| 108 | + args: DPOConfig, |
| 109 | + dataset_name: str, |
| 110 | + ) -> Union[Dataset, IterableDataset]: |
| 111 | + # Build the kwargs for the `map` function |
| 112 | + map_kwargs: Dict[str, Any] = {"writer_batch_size": 10} |
| 113 | + if isinstance(dataset, Dataset): # IterableDataset does not support num_proc |
| 114 | + map_kwargs["num_proc"] = args.dataset_num_proc |
| 115 | + |
| 116 | + with PartialState().main_process_first(): |
| 117 | + # Extract prompt if needed |
| 118 | + if isinstance( |
| 119 | + dataset, Dataset |
| 120 | + ): # `IterableDataset.map` does not support `desc` |
| 121 | + map_kwargs["desc"] = f"Extracting prompt in {dataset_name} dataset" |
| 122 | + dataset = dataset.map(maybe_extract_prompt, **map_kwargs) |
| 123 | + |
| 124 | + # Apply the chat template if needed |
| 125 | + if isinstance( |
| 126 | + dataset, Dataset |
| 127 | + ): # `IterableDataset.map` does not support `desc` |
| 128 | + map_kwargs["desc"] = f"Applying chat template to {dataset_name} dataset" |
| 129 | + dataset = dataset.map( |
| 130 | + maybe_apply_chat_template, |
| 131 | + fn_kwargs={"tokenizer": processing_class, "tools": args.tools}, |
| 132 | + **map_kwargs, |
| 133 | + ) |
| 134 | + |
| 135 | + # Tokenize the dataset |
| 136 | + if isinstance( |
| 137 | + dataset, Dataset |
| 138 | + ): # `IterableDataset.map` does not support `desc` |
| 139 | + map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset" |
| 140 | + |
| 141 | + dataset = dataset.map( |
| 142 | + self.tokenize_row if not self.is_vision_model else self.process_row, |
| 143 | + remove_columns=["chosen", "rejected"], |
| 144 | + fn_kwargs={ |
| 145 | + "processing_class": processing_class, |
| 146 | + "max_prompt_length": args.max_prompt_length, |
| 147 | + "max_completion_length": args.max_completion_length, |
| 148 | + # for enc-dec, we add the special tokens ([bos_token] + prompt + [eos_token]; completion + [eos_token]) |
| 149 | + "add_special_tokens": False, |
| 150 | + }, |
| 151 | + **map_kwargs, |
| 152 | + ) |
| 153 | + |
| 154 | + return dataset |
| 155 | + |
84 | 156 | @staticmethod
|
85 | 157 | def tokenize_row(
|
86 | 158 | features,
|
@@ -124,3 +196,67 @@ def training_step(
|
124 | 196 | gc.collect()
|
125 | 197 | torch.cuda.empty_cache()
|
126 | 198 | return loss
|
| 199 | + |
| 200 | + # TODO: remove this once https://github.com/huggingface/trl/pull/3377 is in a release |
| 201 | + def evaluation_loop( |
| 202 | + self, |
| 203 | + dataloader: DataLoader, |
| 204 | + description: str, |
| 205 | + prediction_loss_only: Optional[bool] = None, |
| 206 | + ignore_keys: Optional[list[str]] = None, |
| 207 | + metric_key_prefix: str = "eval", |
| 208 | + ) -> EvalLoopOutput: |
| 209 | + """ |
| 210 | + Overriding built-in evaluation loop to store metrics for each batch. |
| 211 | + Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`. |
| 212 | +
|
| 213 | + Works both with or without labels. |
| 214 | + """ |
| 215 | + |
| 216 | + # Sample and save to game log if requested (for one batch to save time) |
| 217 | + if self.generate_during_eval: |
| 218 | + # Generate random indices within the range of the total number of samples |
| 219 | + num_samples = len(dataloader.dataset) |
| 220 | + random_indices = random.sample( |
| 221 | + range(num_samples), k=self.args.eval_batch_size |
| 222 | + ) |
| 223 | + |
| 224 | + # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader |
| 225 | + random_batch_dataset = dataloader.dataset.select(random_indices) |
| 226 | + random_batch = self.data_collator(random_batch_dataset) |
| 227 | + random_batch = self._prepare_inputs(random_batch) |
| 228 | + |
| 229 | + policy_output_decoded, ref_output_decoded = ( |
| 230 | + self.generate_from_model_and_ref(self.model, random_batch) |
| 231 | + ) |
| 232 | + |
| 233 | + table = pd.DataFrame( |
| 234 | + columns=["Prompt", "Policy", "Ref Model"], |
| 235 | + data=[ |
| 236 | + [prompt, pol[len(prompt) :], ref[len(prompt) :]] |
| 237 | + for prompt, pol, ref in zip( |
| 238 | + random_batch_dataset["prompt"], |
| 239 | + policy_output_decoded, |
| 240 | + ref_output_decoded, |
| 241 | + ) |
| 242 | + ], |
| 243 | + ) |
| 244 | + if "wandb" in self.args.report_to and self.accelerator.is_main_process: |
| 245 | + wandb.log({"game_log": wandb.Table(data=table)}) |
| 246 | + |
| 247 | + if "comet_ml" in self.args.report_to: |
| 248 | + log_table_to_comet_experiment( |
| 249 | + name="game_log.csv", |
| 250 | + table=table, |
| 251 | + ) |
| 252 | + |
| 253 | + # Base evaluation |
| 254 | + initial_output = super().evaluation_loop( |
| 255 | + dataloader, |
| 256 | + description, |
| 257 | + prediction_loss_only, |
| 258 | + ignore_keys, |
| 259 | + metric_key_prefix, |
| 260 | + ) |
| 261 | + |
| 262 | + return initial_output |
0 commit comments