|
19 | 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 | 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 | 21 | # SOFTWARE.
|
22 |
| - |
23 | 22 | # ruff: noqa: C901
|
24 | 23 | import logging
|
25 | 24 | import os
|
26 | 25 | import time
|
27 |
| -from typing import List, Optional, Tuple, Type, Union |
| 26 | +from dataclasses import dataclass |
| 27 | +from typing import Dict, List, Optional, Tuple, Type, Union |
28 | 28 |
|
29 | 29 | import torch
|
30 | 30 | import torch.nn.functional as F
|
31 | 31 | import transformers
|
32 | 32 | from datasets.download.streaming_download_manager import xPath
|
| 33 | +from pydantic import BaseModel |
33 | 34 | from torch.utils.data import DataLoader
|
34 | 35 | from torch.utils.data.distributed import DistributedSampler
|
35 | 36 | from tqdm import tqdm
|
36 | 37 | from transformers import AutoTokenizer, BatchEncoding
|
37 | 38 |
|
38 |
| -from lighteval.config.lighteval_config import FullNanotronConfig |
39 | 39 | from lighteval.data import (
|
40 | 40 | GenDistributedSampler,
|
41 | 41 | GenerativeTaskDatasetNanotron,
|
|
69 | 69 | if is_nanotron_available():
|
70 | 70 | from nanotron import distributed as dist
|
71 | 71 | from nanotron import logging
|
| 72 | + from nanotron.config import GeneralArgs, ModelArgs, TokenizerArgs |
| 73 | + from nanotron.config.parallelism_config import ParallelismArgs |
72 | 74 | from nanotron.generation.decode import decode_tokenized
|
| 75 | + from nanotron.generation.sampler import SamplerType |
73 | 76 | from nanotron.logging import human_format, log_rank
|
74 | 77 | from nanotron.models import build_model
|
75 | 78 | from nanotron.parallel.context import ParallelContext
|
|
83 | 86 |
|
84 | 87 | logger = logging.get_logger(__name__)
|
85 | 88 |
|
| 89 | +DEFAULT_GENERATION_SEED = 42 |
| 90 | + |
| 91 | + |
| 92 | +class GenerationArgs(BaseModel): |
| 93 | + sampler: Optional["SamplerType"] = None |
| 94 | + temperature: Optional[float] = None |
| 95 | + top_k: Optional[int] = None |
| 96 | + top_p: Optional[float] = None |
| 97 | + n_samples: Optional[int] = None |
| 98 | + eos: Optional[str] = None |
| 99 | + seed: Optional[int] = None |
| 100 | + use_cache: Optional[bool] = False |
| 101 | + |
| 102 | + def __post_init__(self): |
| 103 | + if self.seed is None: |
| 104 | + self.seed = DEFAULT_GENERATION_SEED |
| 105 | + |
| 106 | + |
| 107 | +@dataclass |
| 108 | +class LightEvalLoggingArgs: |
| 109 | + """Arguments related to logging for LightEval""" |
| 110 | + |
| 111 | + output_dir: str |
| 112 | + results_path_template: str | None = None |
| 113 | + save_details: bool = True |
| 114 | + push_to_hub: bool = False |
| 115 | + push_to_tensorboard: bool = False |
| 116 | + public_run: bool = False |
| 117 | + results_org: str | None = None |
| 118 | + tensorboard_metric_prefix: str = "eval" |
| 119 | + |
| 120 | + |
| 121 | +@dataclass |
| 122 | +class LightEvalTasksArgs: |
| 123 | + """Arguments related to tasks for LightEval""" |
| 124 | + |
| 125 | + tasks: str |
| 126 | + custom_tasks: Optional[str] = None |
| 127 | + max_samples: Optional[int] = None |
| 128 | + num_fewshot_seeds: Optional[int] = None |
| 129 | + |
| 130 | + dataset_loading_processes: int = 8 |
| 131 | + multichoice_continuations_start_space: Optional[bool] = None |
| 132 | + pairwise_tokenization: bool = False |
| 133 | + |
| 134 | + |
| 135 | +@dataclass |
| 136 | +class LightEvalConfig: |
| 137 | + """Arguments related to running LightEval on checkpoints. |
| 138 | +
|
| 139 | + All is optional because you can also use this class to later supply arguments to override |
| 140 | + the saved config when running LightEval after training. |
| 141 | + """ |
| 142 | + |
| 143 | + logging: LightEvalLoggingArgs |
| 144 | + tasks: LightEvalTasksArgs |
| 145 | + parallelism: "ParallelismArgs" |
| 146 | + batch_size: int = 0 |
| 147 | + generation: Optional[Union[GenerationArgs, Dict[str, GenerationArgs]]] = None |
| 148 | + |
| 149 | + |
| 150 | +@dataclass |
| 151 | +class FullNanotronConfig: |
| 152 | + lighteval_config: LightEvalConfig |
| 153 | + nanotron_model: "ModelArgs" |
| 154 | + nanotron_tokenizer: "TokenizerArgs" |
| 155 | + nanotron_general: "GeneralArgs" |
| 156 | + |
| 157 | + @property |
| 158 | + def generation_parameters(self): |
| 159 | + # Return the generation parameters from the lighteval config |
| 160 | + # or create default generation parameters if none are set |
| 161 | + if self.lighteval_config.generation: |
| 162 | + return self.lighteval_config.generation |
| 163 | + return GenerationArgs() |
| 164 | + |
86 | 165 |
|
87 | 166 | class NanotronLightevalModel(LightevalModel):
|
88 | 167 | # Default max sequence length setting for when no `max_length` is provided
|
|
0 commit comments