3030"""
3131
3232import argparse
33+ import os
3334import sys
3435from typing import Any , List , Tuple
3536
4546cs .store (name = "llm_config" , node = LlmConfig )
4647
4748
48- # Need this global variable to pass an llm_config from yaml
49- # into the hydra-wrapped main function.
50- llm_config_from_yaml = None
51-
52-
5349def parse_config_arg () -> Tuple [str , List [Any ]]:
5450 parser = argparse .ArgumentParser (add_help = True )
5551 parser .add_argument ("--config" , type = str , help = "Path to the LlmConfig file" )
@@ -61,27 +57,36 @@ def pop_config_arg() -> str:
6157 """
6258 Removes '--config' and its value from sys.argv.
6359 Assumes --config is specified and argparse has already validated the args.
60+ Returns the config file path.
6461 """
6562 idx = sys .argv .index ("--config" )
6663 value = sys .argv [idx + 1 ]
6764 del sys .argv [idx : idx + 2 ]
6865 return value
6966
7067
71- @hydra .main (version_base = None , config_name = "llm_config" )
72- def hydra_main (llm_config : LlmConfig ) -> None :
73- global llm_config_from_yaml
68+ def add_hydra_config_args (config_file_path : str ) -> None :
69+ """
70+ Breaks down the config file path into directory and filename,
71+ resolves the directory to an absolute path, and adds the
72+ --config_path and --config_name arguments to sys.argv.
73+ """
74+ config_dir = os .path .dirname (config_file_path )
75+ config_name = os .path .basename (config_file_path )
76+
77+ # Resolve to absolute path
78+ config_dir_abs = os .path .abspath (config_dir )
79+
80+ # Add the hydra config arguments to sys.argv
81+ sys .argv .extend (["--config-path" , config_dir_abs , "--config-name" , config_name ])
82+
7483
75- # Override the LlmConfig constructed from the provide yaml config file
76- # with the CLI overrides.
77- if llm_config_from_yaml :
78- # Get CLI overrides (excluding defaults list).
79- overrides_list : List [str ] = list (HydraConfig .get ().overrides .get ("task" , []))
80- override_cfg = OmegaConf .from_dotlist (overrides_list )
81- merged_config = OmegaConf .merge (llm_config_from_yaml , override_cfg )
82- export_llama (merged_config )
83- else :
84- export_llama (OmegaConf .to_object (llm_config ))
84+ @hydra .main (version_base = None , config_name = "llm_config" , config_path = None )
85+ def hydra_main (llm_config : LlmConfig ) -> None :
86+ structured = OmegaConf .structured (LlmConfig )
87+ merged = OmegaConf .merge (structured , llm_config )
88+ llm_config_obj = OmegaConf .to_object (merged )
89+ export_llama (llm_config_obj )
8590
8691
8792def main () -> None :
@@ -90,13 +95,12 @@ def main() -> None:
9095 if config :
9196 global llm_config_from_yaml
9297 # Pop out --config and its value so that they are not parsed by
93- # Hyra 's main.
98+ # Hydra 's main.
9499 config_file_path = pop_config_arg ()
95- default_llm_config = LlmConfig ()
96- # Construct the LlmConfig from the config yaml file.
97- default_llm_config = LlmConfig ()
98- from_yaml = OmegaConf .load (config_file_path )
99- llm_config_from_yaml = OmegaConf .merge (default_llm_config , from_yaml )
100+
101+ # Add hydra config_path and config_name arguments to sys.argv.
102+ add_hydra_config_args (config_file_path )
103+
100104 hydra_main ()
101105
102106
0 commit comments