2323# flake8: noqa: C901
2424import os
2525
26+ import yaml
2627from typer import Option
2728from typing_extensions import Annotated
29+ from yaml import SafeLoader
2830
2931
30- CACHE_DIR : str = os .getenv ("HF_HOME" , "/scratch" )
31-
3232HELP_PANEL_NAME_1 = "Common Parameters"
3333HELP_PANEL_NAME_2 = "Logging Parameters"
3434HELP_PANEL_NAME_3 = "Debug Parameters"
@@ -43,41 +43,43 @@ def nanotron(
4343 str , Option (help = "Path to the nanotron checkpoint YAML or python config file, potentially on s3." )
4444 ],
4545 lighteval_config_path : Annotated [str , Option (help = "Path to a YAML config to be used for the evaluation." )],
46- cache_dir : Annotated [str , Option (help = "Cache directory for datasets and models." )] = CACHE_DIR ,
4746):
4847 """
4948 Evaluate models using nanotron as backend.
5049 """
51- from nanotron .config import Config , get_config_from_file
50+ from nanotron .config import GeneralArgs , ModelArgs , TokenizerArgs , get_config_from_dict , get_config_from_file
5251
53- from lighteval .config .lighteval_config import FullNanotronConfig , LightEvalConfig
52+ from lighteval .config .lighteval_config import (
53+ FullNanotronConfig ,
54+ LightEvalConfig ,
55+ )
5456 from lighteval .logging .evaluation_tracker import EvaluationTracker
55- from lighteval .logging .hierarchical_logger import htrack_block
5657 from lighteval .pipeline import ParallelismManager , Pipeline , PipelineParameters
5758 from lighteval .utils .imports import NO_NANOTRON_ERROR_MSG , is_nanotron_available
58- from lighteval .utils .utils import EnvConfig
59-
60- env_config = EnvConfig (token = os .getenv ("HF_TOKEN" ), cache_dir = cache_dir )
6159
6260 if not is_nanotron_available ():
6361 raise ImportError (NO_NANOTRON_ERROR_MSG )
6462
65- with htrack_block ("Load nanotron config" ):
66- # Create nanotron config
67- if not checkpoint_config_path .endswith (".yaml" ):
68- raise ValueError ("The checkpoint path should point to a YAML file" )
63+ # Create nanotron config
64+ if not checkpoint_config_path .endswith (".yaml" ):
65+ raise ValueError ("The checkpoint path should point to a YAML file" )
66+
67+ with open (checkpoint_config_path ) as f :
68+ nanotron_yaml = yaml .load (f , Loader = SafeLoader )
6969
70- model_config = get_config_from_file (
71- checkpoint_config_path ,
72- config_class = Config ,
73- model_config_class = None ,
70+ model_config , tokenizer_config , general_config = [
71+ get_config_from_dict (
72+ nanotron_yaml [ key ] ,
73+ config_class = config_class ,
7474 skip_unused_config_keys = True ,
7575 skip_null_keys = True ,
7676 )
77+ for key , config_class in [("model" , ModelArgs ), ("tokenizer" , TokenizerArgs ), ("general" , GeneralArgs )]
78+ ]
7779
78- # We are getting an type error, because the get_config_from_file is not correctly typed,
79- lighteval_config : LightEvalConfig = get_config_from_file (lighteval_config_path , config_class = LightEvalConfig ) # type: ignore
80- nanotron_config = FullNanotronConfig (lighteval_config , model_config )
80+ # Load lighteval config
81+ lighteval_config : LightEvalConfig = get_config_from_file (lighteval_config_path , config_class = LightEvalConfig ) # type: ignore
82+ nanotron_config = FullNanotronConfig (lighteval_config , model_config , tokenizer_config , general_config )
8183
8284 evaluation_tracker = EvaluationTracker (
8385 output_dir = lighteval_config .logging .output_dir ,
@@ -88,17 +90,15 @@ def nanotron(
8890 push_to_tensorboard = lighteval_config .logging .push_to_tensorboard ,
8991 save_details = lighteval_config .logging .save_details ,
9092 tensorboard_metric_prefix = lighteval_config .logging .tensorboard_metric_prefix ,
91- nanotron_run_info = nanotron_config .nanotron_config . general ,
93+ nanotron_run_info = nanotron_config .nanotron_general ,
9294 )
9395
9496 pipeline_parameters = PipelineParameters (
9597 launcher_type = ParallelismManager .NANOTRON ,
96- env_config = env_config ,
9798 job_id = os .environ .get ("SLURM_JOB_ID" , 0 ),
9899 nanotron_checkpoint_path = checkpoint_config_path ,
99100 dataset_loading_processes = lighteval_config .tasks .dataset_loading_processes ,
100101 custom_tasks_directory = lighteval_config .tasks .custom_tasks ,
101- override_batch_size = lighteval_config .batch_size ,
102102 num_fewshot_seeds = 1 ,
103103 max_samples = lighteval_config .tasks .max_samples ,
104104 use_chat_template = False ,
0 commit comments