Skip to content

Commit 8aaf51c

Browse files
thomwolf, NathanHB, clefourrier
authored
Support for nanotron (#11)
Support for Nanotron models --- Co-authored-by: Nathan Habib <[email protected]> Co-authored-by: [email protected] <[email protected]> Co-authored-by: Clémentine Fourrier <[email protected]>
1 parent 1e837a9 commit 8aaf51c

16 files changed

+487
-471
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,4 +37,5 @@ repos:
3737
rev: 'v0.1.6'
3838
hooks:
3939
- id: ruff
40+
args: ['--fix']
4041
- id: ruff-format

README.md

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ It is still an early, internal version - it should be nice to use but don't expe
1111
In case of problems or question, feel free to open an issue!
1212

1313
## How to install and use
14-
### Requirements
14+
### Installation
1515
0) Create your virtual environment using virtualenv or conda depending on your preferences. We require Python3.10
1616

1717
1) Clone the package using `git clone`, then `cd lighteval-harness`, `pip install -e .` Once the dependencies are installed, `cd src`.
@@ -22,6 +22,12 @@ Optional:
2222

2323
2) Add your user token to the environment variable `HUGGING_FACE_HUB_TOKEN` if you want to push your results to the hub
2424

25+
For the linting:
26+
```bash
27+
pre-commit install
28+
pre-commit run --config .pre-commit-config.yaml --all-files
29+
```
30+
2531

2632
### Usage
2733
- Launching on CPU

pyproject.toml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,7 @@ optimum = ["optimum==1.12.0"]
8282
quantization = ["bitsandbytes>=0.41.0", "auto-gptq>=0.4.2"]
8383
adapters = ["peft==0.3.0"]
8484
nanotron = [
85-
"nanotron@git+https://github.com/huggingface/nanotron@8c1a49588d0745a6404644a86547c2dd6a63640e",
86-
"brrr@git+https://github.com/huggingface/brrr@e8a503e2ec08b34eed7522d331aec3bee8cdd29b",
85+
"nanotron@git+https://github.com/huggingface/nanotron",
8786
"tensorboardX"
8887
]
8988

run_evals_accelerate.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
import argparse
2+
3+
from lighteval.main_accelerate import CACHE_DIR, main
4+
5+
6+
def get_parser():
7+
parser = argparse.ArgumentParser()
8+
group = parser.add_mutually_exclusive_group(required=True)
9+
task_type_group = parser.add_mutually_exclusive_group(required=True)
10+
11+
# Model type 1) Base model
12+
weight_type_group = parser.add_mutually_exclusive_group()
13+
weight_type_group.add_argument(
14+
"--delta_weights",
15+
action="store_true",
16+
default=False,
17+
help="set to True of your model should be merged with a base model, also need to provide the base model name",
18+
)
19+
weight_type_group.add_argument(
20+
"--adapter_weights",
21+
action="store_true",
22+
default=False,
23+
help="set to True of your model has been trained with peft, also need to provide the base model name",
24+
)
25+
parser.add_argument(
26+
"--base_model", type=str, default=None, help="name of the base model to be used for delta or adapter weights"
27+
)
28+
29+
task_type_group.add_argument("--model_args")
30+
parser.add_argument("--model_dtype", type=str, default=None)
31+
parser.add_argument(
32+
"--multichoice_continuations_start_space",
33+
action="store_true",
34+
help="Whether to force multiple choice continuations to start with a space",
35+
)
36+
parser.add_argument(
37+
"--no_multichoice_continuations_start_space",
38+
action="store_true",
39+
help="Whether to force multiple choice continuations to not start with a space",
40+
)
41+
parser.add_argument("--use_chat_template", default=False, action="store_true")
42+
# Model type 2) TGI
43+
task_type_group.add_argument("--inference_server_address", type=str)
44+
parser.add_argument("--inference_server_auth", type=str, default=None)
45+
# Model type 3) Inference endpoints
46+
task_type_group.add_argument("--endpoint_model_name", type=str)
47+
parser.add_argument("--accelerator", type=str, default=None)
48+
parser.add_argument("--vendor", type=str, default=None)
49+
parser.add_argument("--region", type=str, default=None)
50+
parser.add_argument("--instance_size", type=str, default=None)
51+
parser.add_argument("--instance_type", type=str, default=None)
52+
parser.add_argument("--reuse_existing", default=False, action="store_true")
53+
# Debug
54+
parser.add_argument("--max_samples", type=int, default=None)
55+
parser.add_argument("--job_id", type=str, help="Optional Job ID for future reference", default="")
56+
# Saving
57+
parser.add_argument("--push_results_to_hub", default=False, action="store_true")
58+
parser.add_argument("--save_details", action="store_true")
59+
parser.add_argument("--push_details_to_hub", default=False, action="store_true")
60+
parser.add_argument(
61+
"--public_run", default=False, action="store_true", help="Push results and details to a public repo"
62+
)
63+
parser.add_argument("--cache_dir", type=str, default=CACHE_DIR)
64+
parser.add_argument(
65+
"--results_org",
66+
type=str,
67+
help="Hub organisation where you want to store the results. Your current token must have write access to it",
68+
)
69+
# Common parameters
70+
parser.add_argument("--output_dir", required=True)
71+
parser.add_argument("--override_batch_size", type=int, default=-1)
72+
parser.add_argument("--dataset_loading_processes", type=int, default=1)
73+
parser.add_argument(
74+
"--custom_tasks_file",
75+
type=str,
76+
default=None,
77+
help="Path to a file with custom tasks (a TASK list of dict and potentially prompt formating functions)",
78+
)
79+
group.add_argument(
80+
"--tasks",
81+
type=str,
82+
default=None,
83+
help="Id of a task, e.g. 'original|mmlu:abstract_algebra|5' or path to a texte file with a list of tasks",
84+
)
85+
parser.add_argument("--num_fewshot_seeds", type=int, default=1, help="Number of trials the few shots")
86+
return parser
87+
88+
89+
if __name__ == "__main__":
90+
parser = get_parser()
91+
args, unknowns = parser.parse_known_args()
92+
main(args)

run_evals_nanotron.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# flake8: noqa: C901
2+
import argparse
3+
4+
from lighteval.main_nanotron import main
5+
6+
7+
def get_parser():
8+
parser = argparse.ArgumentParser()
9+
parser.add_argument(
10+
"--checkpoint-config-path",
11+
type=str,
12+
required=True,
13+
help="Path to the brr checkpoint YAML or python config file, potentially on S3",
14+
)
15+
parser.add_argument(
16+
"--lighteval-override",
17+
type=str,
18+
help="Path to an optional YAML or python Lighteval config to override part of the checkpoint Lighteval config",
19+
)
20+
parser.add_argument(
21+
"--cache-dir",
22+
type=str,
23+
default="",
24+
help="Cache directory",
25+
)
26+
27+
return parser
28+
29+
30+
if __name__ == "__main__":
31+
parser = get_parser()
32+
args, unknowns = parser.parse_known_args()
33+
main(args.checkpoint_config_path, args.lighteval_override, args.cache_dir)

src/lighteval/data.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,37 @@ def _sorting_criteria(self, request: GreedyUntilRequest | GreedyUntilWithLogitsR
198198
return -(len(toks) + gen_length)
199199

200200

201+
class GenerativeTaskDatasetNanotron(DynamicBatchDataset):
202+
def __getitem__(self, index) -> Request:
203+
"""
204+
Get an item from the dataset depending on the split we are currently in.
205+
For instance, if we are in split 0, we will get the item at index 0, if
206+
we are in split 1, we will get the item at index self.split_size, etc.
207+
Used for dynamic batching.
208+
209+
Args:
210+
index (int): The index of the item.
211+
212+
Returns:
213+
Any: The item at the specified index.
214+
"""
215+
return index, self.sorted_data[index + self.split_start]
216+
217+
def _sorting_criteria(self, request) -> int:
218+
"""
219+
Collate function for generating batches.
220+
221+
Args:
222+
x (Any): The input data.
223+
224+
Returns:
225+
Any: The collated data.
226+
"""
227+
toks = request.tokenized_context
228+
gen_length = request.generation_size
229+
return -(len(toks) + gen_length)
230+
231+
201232
class GenDistributedSampler(DistributedSampler):
202233
"""A distributed sampler that copy the last element only when drop_last is False so we keep a small padding in the batches
203234
as our samples are sorted by length.

src/lighteval/evaluator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
import copy
66
from typing import Dict, Union
77

8+
from pytablewriter import LatexTableWriter, MarkdownTableWriter
9+
810
from lighteval.logging.evaluation_tracker import EvaluationTracker
911
from lighteval.logging.hierarchical_logger import hlog
1012
from lighteval.models.base_model import BaseModel
@@ -99,8 +101,6 @@ def evaluate( # noqa: C901
99101

100102
def make_results_table(result_dict):
101103
"""Generate table of results."""
102-
from pytablewriter import LatexTableWriter, MarkdownTableWriter
103-
104104
md_writer = MarkdownTableWriter()
105105
latex_writer = LatexTableWriter()
106106
md_writer.headers = ["Task", "Version", "Metric", "Value", "", "Stderr"]

src/lighteval/logging/evaluation_tracker.py

Lines changed: 74 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,11 @@
1818
TaskConfigLogger,
1919
VersionsLogger,
2020
)
21-
from lighteval.utils import is_nanotron_available
21+
from lighteval.utils import is_nanotron_available, obj_to_markdown
2222

2323

2424
if is_nanotron_available():
25-
from brrr.config import BrrrConfig
26-
from brrr.experiment_loggers import obj_to_markdown
27-
from nanotron.config import get_config_from_dict
25+
from nanotron.config import Config, get_config_from_dict
2826

2927

3028
class EnhancedJSONEncoder(json.JSONEncoder):
@@ -104,81 +102,81 @@ def save(
104102
105103
"""
106104
hlog("Saving experiment tracker")
107-
try:
108-
date_id = datetime.now().isoformat().replace(":", "-")
109-
110-
output_dir_results = Path(output_dir) / "results" / self.general_config_logger.model_name
111-
output_dir_details = Path(output_dir) / "details" / self.general_config_logger.model_name
112-
output_dir_details_sub_folder = output_dir_details / date_id
113-
output_dir_results.mkdir(parents=True, exist_ok=True)
114-
output_dir_details_sub_folder.mkdir(parents=True, exist_ok=True)
115-
116-
output_results_file = output_dir_results / f"results_{date_id}.json"
117-
output_results_in_details_file = output_dir_details / f"results_{date_id}.json"
118-
119-
hlog(f"Saving results to {output_results_file} and {output_results_in_details_file}")
120-
121-
to_dump = {
122-
"config_general": asdict(self.general_config_logger),
123-
"results": self.metrics_logger.metric_aggregated,
124-
"versions": self.versions_logger.versions,
125-
"config_tasks": self.task_config_logger.tasks_configs,
126-
"summary_tasks": self.details_logger.compiled_details,
127-
"summary_general": asdict(self.details_logger.compiled_details_over_all_tasks),
128-
}
129-
dumped = json.dumps(to_dump, cls=EnhancedJSONEncoder, indent=2)
130-
131-
with open(output_results_file, "w") as f:
132-
f.write(dumped)
133-
134-
with open(output_results_in_details_file, "w") as f:
135-
f.write(dumped)
136-
137-
for task_name, task_details in self.details_logger.details.items():
138-
output_file_details = output_dir_details_sub_folder / f"details_{task_name}_{date_id}.parquet"
139-
# Create a dataset from the dictionary
140-
try:
141-
dataset = Dataset.from_list([asdict(detail) for detail in task_details])
142-
except Exception:
143-
# We force cast to str to avoid formatting problems for nested objects
144-
dataset = Dataset.from_list(
145-
[{k: str(v) for k, v in asdict(detail).items()} for detail in task_details]
146-
)
105+
# try:
106+
date_id = datetime.now().isoformat().replace(":", "-")
147107

148-
# We don't keep 'id' around if it's there
149-
column_names = dataset.column_names
150-
if "id" in dataset.column_names:
151-
column_names = [t for t in dataset.column_names if t != "id"]
152-
153-
# Sort column names to make it easier later
154-
dataset = dataset.select_columns(sorted(column_names))
155-
# Save the dataset to a Parquet file
156-
dataset.to_parquet(output_file_details.as_posix())
157-
158-
if push_results_to_hub:
159-
self.api.upload_folder(
160-
repo_id=self.hub_results_repo if public else self.hub_private_results_repo,
161-
folder_path=output_dir_results,
162-
path_in_repo=self.general_config_logger.model_name,
163-
repo_type="dataset",
164-
commit_message=f"Updating model {self.general_config_logger.model_name}",
165-
)
108+
output_dir_results = Path(output_dir) / "results" / self.general_config_logger.model_name
109+
output_dir_details = Path(output_dir) / "details" / self.general_config_logger.model_name
110+
output_dir_details_sub_folder = output_dir_details / date_id
111+
output_dir_results.mkdir(parents=True, exist_ok=True)
112+
output_dir_details_sub_folder.mkdir(parents=True, exist_ok=True)
166113

167-
if push_details_to_hub:
168-
self.details_to_hub(
169-
model_name=self.general_config_logger.model_name,
170-
results_file_path=output_results_in_details_file,
171-
details_folder_path=output_dir_details_sub_folder,
172-
push_as_public=public,
173-
)
114+
output_results_file = output_dir_results / f"results_{date_id}.json"
115+
output_results_in_details_file = output_dir_details / f"results_{date_id}.json"
116+
117+
hlog(f"Saving results to {output_results_file} and {output_results_in_details_file}")
174118

175-
if push_results_to_tensorboard:
176-
self.push_results_to_tensorboard(
177-
results=self.metrics_logger.metric_aggregated, details=self.details_logger.details
119+
to_dump = {
120+
"config_general": asdict(self.general_config_logger),
121+
"results": self.metrics_logger.metric_aggregated,
122+
"versions": self.versions_logger.versions,
123+
"config_tasks": self.task_config_logger.tasks_configs,
124+
"summary_tasks": self.details_logger.compiled_details,
125+
"summary_general": asdict(self.details_logger.compiled_details_over_all_tasks),
126+
}
127+
dumped = json.dumps(to_dump, cls=EnhancedJSONEncoder, indent=2)
128+
129+
with open(output_results_file, "w") as f:
130+
f.write(dumped)
131+
132+
with open(output_results_in_details_file, "w") as f:
133+
f.write(dumped)
134+
135+
for task_name, task_details in self.details_logger.details.items():
136+
output_file_details = output_dir_details_sub_folder / f"details_{task_name}_{date_id}.parquet"
137+
# Create a dataset from the dictionary
138+
try:
139+
dataset = Dataset.from_list([asdict(detail) for detail in task_details])
140+
except Exception:
141+
# We force cast to str to avoid formatting problems for nested objects
142+
dataset = Dataset.from_list(
143+
[{k: str(v) for k, v in asdict(detail).items()} for detail in task_details]
178144
)
179-
except Exception as e:
180-
hlog("WARNING: Could not save results")
181-
hlog(repr(e))
145+
146+
# We don't keep 'id' around if it's there
147+
column_names = dataset.column_names
148+
if "id" in dataset.column_names:
149+
column_names = [t for t in dataset.column_names if t != "id"]
150+
151+
# Sort column names to make it easier later
152+
dataset = dataset.select_columns(sorted(column_names))
153+
# Save the dataset to a Parquet file
154+
dataset.to_parquet(output_file_details.as_posix())
155+
156+
if push_results_to_hub:
157+
self.api.upload_folder(
158+
repo_id=self.hub_results_repo if public else self.hub_private_results_repo,
159+
folder_path=output_dir_results,
160+
path_in_repo=self.general_config_logger.model_name,
161+
repo_type="dataset",
162+
commit_message=f"Updating model {self.general_config_logger.model_name}",
163+
)
164+
165+
if push_details_to_hub:
166+
self.details_to_hub(
167+
model_name=self.general_config_logger.model_name,
168+
results_file_path=output_results_in_details_file,
169+
details_folder_path=output_dir_details_sub_folder,
170+
push_as_public=public,
171+
)
172+
173+
if push_results_to_tensorboard:
174+
self.push_results_to_tensorboard(
175+
results=self.metrics_logger.metric_aggregated, details=self.details_logger.details
176+
)
177+
# except Exception as e:
178+
# hlog("WARNING: Could not save results")
179+
# hlog(repr(e))
182180

183181
def generate_final_dict(self) -> dict:
184182
"""Aggregates and returns all the logger's experiment information in a dictionary.
@@ -487,7 +485,7 @@ def push_results_to_tensorboard( # noqa: C901
487485
if not is_nanotron_available():
488486
hlog_warn("You cannot push results to tensorboard with having nanotron installed. Skipping")
489487
return
490-
config: BrrrConfig = get_config_from_dict(self.general_config_logger.config, config_class=BrrrConfig)
488+
config: Config = get_config_from_dict(self.general_config_logger.config, config_class=Config)
491489
lighteval_config = config.lighteval
492490
try:
493491
global_step = config.general.step

0 commit comments

Comments
 (0)