@@ -23,12 +23,12 @@
 # flake8: noqa: C901
 import os

+import yaml
 from typer import Option
 from typing_extensions import Annotated
+from yaml import SafeLoader


-CACHE_DIR: str = os.getenv("HF_HOME", "/scratch")
-
 HELP_PANEL_NAME_1 = "Common Parameters"
 HELP_PANEL_NAME_2 = "Logging Parameters"
 HELP_PANEL_NAME_3 = "Debug Parameters"
@@ -43,41 +43,43 @@ def nanotron(
         str, Option(help="Path to the nanotron checkpoint YAML or python config file, potentially on s3.")
     ],
     lighteval_config_path: Annotated[str, Option(help="Path to a YAML config to be used for the evaluation.")],
-    cache_dir: Annotated[str, Option(help="Cache directory for datasets and models.")] = CACHE_DIR,
 ):
     """
     Evaluate models using nanotron as backend.
     """
-    from nanotron.config import Config, get_config_from_file
+    from nanotron.config import GeneralArgs, ModelArgs, TokenizerArgs, get_config_from_dict, get_config_from_file

-    from lighteval.config.lighteval_config import FullNanotronConfig, LightEvalConfig
+    from lighteval.config.lighteval_config import (
+        FullNanotronConfig,
+        LightEvalConfig,
+    )
     from lighteval.logging.evaluation_tracker import EvaluationTracker
-    from lighteval.logging.hierarchical_logger import htrack_block
     from lighteval.pipeline import ParallelismManager, Pipeline, PipelineParameters
     from lighteval.utils.imports import NO_NANOTRON_ERROR_MSG, is_nanotron_available
-    from lighteval.utils.utils import EnvConfig
-
-    env_config = EnvConfig(token=os.getenv("HF_TOKEN"), cache_dir=cache_dir)

     if not is_nanotron_available():
         raise ImportError(NO_NANOTRON_ERROR_MSG)

-    with htrack_block("Load nanotron config"):
-        # Create nanotron config
-        if not checkpoint_config_path.endswith(".yaml"):
-            raise ValueError("The checkpoint path should point to a YAML file")
+    # Create nanotron config
+    if not checkpoint_config_path.endswith(".yaml"):
+        raise ValueError("The checkpoint path should point to a YAML file")
+
+    with open(checkpoint_config_path) as f:
+        nanotron_yaml = yaml.load(f, Loader=SafeLoader)

-        model_config = get_config_from_file(
-            checkpoint_config_path,
-            config_class=Config,
-            model_config_class=None,
+    model_config, tokenizer_config, general_config = [
+        get_config_from_dict(
+            nanotron_yaml[key],
+            config_class=config_class,
             skip_unused_config_keys=True,
             skip_null_keys=True,
         )
+        for key, config_class in [("model", ModelArgs), ("tokenizer", TokenizerArgs), ("general", GeneralArgs)]
+    ]

-        # We are getting an type error, because the get_config_from_file is not correctly typed,
-        lighteval_config: LightEvalConfig = get_config_from_file(lighteval_config_path, config_class=LightEvalConfig)  # type: ignore
-        nanotron_config = FullNanotronConfig(lighteval_config, model_config)
+    # Load lighteval config
+    lighteval_config: LightEvalConfig = get_config_from_file(lighteval_config_path, config_class=LightEvalConfig)  # type: ignore
+    nanotron_config = FullNanotronConfig(lighteval_config, model_config, tokenizer_config, general_config)

     evaluation_tracker = EvaluationTracker(
         output_dir=lighteval_config.logging.output_dir,
@@ -88,17 +90,15 @@ def nanotron(
         push_to_tensorboard=lighteval_config.logging.push_to_tensorboard,
         save_details=lighteval_config.logging.save_details,
         tensorboard_metric_prefix=lighteval_config.logging.tensorboard_metric_prefix,
-        nanotron_run_info=nanotron_config.nanotron_config.general,
+        nanotron_run_info=nanotron_config.nanotron_general,
     )

     pipeline_parameters = PipelineParameters(
         launcher_type=ParallelismManager.NANOTRON,
-        env_config=env_config,
         job_id=os.environ.get("SLURM_JOB_ID", 0),
         nanotron_checkpoint_path=checkpoint_config_path,
         dataset_loading_processes=lighteval_config.tasks.dataset_loading_processes,
         custom_tasks_directory=lighteval_config.tasks.custom_tasks,
-        override_batch_size=lighteval_config.batch_size,
         num_fewshot_seeds=1,
         max_samples=lighteval_config.tasks.max_samples,
         use_chat_template=False,
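For context, the change above replaces the old whole-file parse (`get_config_from_file(..., config_class=Config)`) with per-section parsing: only the `model`, `tokenizer`, and `general` blocks of the checkpoint YAML are deserialized into their nanotron dataclasses, which likely avoids validating training-only sections such as optimizer or parallelism settings at evaluation time. A minimal standalone sketch of the same pattern, assuming nanotron and PyYAML are installed; the checkpoint path below is hypothetical:

    # Load only the evaluation-relevant sections of a nanotron checkpoint YAML,
    # mirroring the pattern introduced in the diff above.
    # Assumes nanotron and PyYAML are installed; the path is hypothetical.
    import yaml
    from yaml import SafeLoader

    from nanotron.config import GeneralArgs, ModelArgs, TokenizerArgs, get_config_from_dict

    checkpoint_yaml = "/checkpoints/my-run/config.yaml"  # hypothetical path

    with open(checkpoint_yaml) as f:
        raw = yaml.load(f, Loader=SafeLoader)

    # Each section is deserialized independently of the rest of the file.
    model_config, tokenizer_config, general_config = [
        get_config_from_dict(
            raw[key],
            config_class=config_class,
            skip_unused_config_keys=True,  # ignore keys the dataclass does not define
            skip_null_keys=True,           # drop keys serialized as null
        )
        for key, config_class in [("model", ModelArgs), ("tokenizer", TokenizerArgs), ("general", GeneralArgs)]
    ]

    print(model_config, tokenizer_config, general_config)

Since the `cache_dir`/`EnvConfig` plumbing is also removed, the command presumably now takes just the two YAML paths, e.g. `lighteval nanotron --checkpoint-config-path ... --lighteval-config-path ...` (option names inferred from the typer signature).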