3030"""
3131
3232import argparse
33+ import os
3334import sys
3435from typing import Any , List , Tuple
3536
@@ -61,27 +62,47 @@ def pop_config_arg() -> str:
6162 """
6263 Removes '--config' and its value from sys.argv.
6364 Assumes --config is specified and argparse has already validated the args.
65+ Returns the config file path.
6466 """
6567 idx = sys .argv .index ("--config" )
6668 value = sys .argv [idx + 1 ]
6769 del sys .argv [idx : idx + 2 ]
6870 return value
6971
7072
71- @hydra .main (version_base = None , config_name = "llm_config" )
73+ def add_hydra_config_args (config_file_path : str ) -> None :
74+ """
75+ Breaks down the config file path into directory and filename,
76+ resolves the directory to an absolute path, and adds the
77+ --config_path and --config_name arguments to sys.argv.
78+ """
79+ config_dir = os .path .dirname (config_file_path )
80+ config_name = os .path .basename (config_file_path )
81+
82+ # Resolve to absolute path
83+ config_dir_abs = os .path .abspath (config_dir )
84+
85+ # Add the hydra config arguments to sys.argv
86+ sys .argv .extend (["--config-path" , config_dir_abs , "--config-name" , config_name ])
87+
88+
89+ @hydra .main (version_base = None , config_name = "llm_config" , config_path = None )
7290def hydra_main (llm_config : LlmConfig ) -> None :
73- global llm_config_from_yaml
91+ # global llm_config_from_yaml
92+
93+ # # Override the LlmConfig constructed from the provide yaml config file
94+ # # with the CLI overrides.
95+ # if llm_config_from_yaml:
96+ # # Get CLI overrides (excluding defaults list).
97+ # overrides_list: List[str] = list(HydraConfig.get().overrides.get("task", []))
98+ # override_cfg = OmegaConf.from_dotlist(overrides_list)
99+ # merged_config = OmegaConf.merge(llm_config_from_yaml, override_cfg)
100+ # export_llama(merged_config)
101+ # else:
102+ # export_llama(OmegaConf.to_object(llm_config))
74103
75- # Override the LlmConfig constructed from the provide yaml config file
76- # with the CLI overrides.
77- if llm_config_from_yaml :
78- # Get CLI overrides (excluding defaults list).
79- overrides_list : List [str ] = list (HydraConfig .get ().overrides .get ("task" , []))
80- override_cfg = OmegaConf .from_dotlist (overrides_list )
81- merged_config = OmegaConf .merge (llm_config_from_yaml , override_cfg )
82- export_llama (merged_config )
83- else :
84- export_llama (OmegaConf .to_object (llm_config ))
104+ breakpoint ()
105+ export_llama (OmegaConf .to_object (llm_config ))
85106
86107
87108def main () -> None :
@@ -90,13 +111,17 @@ def main() -> None:
90111 if config :
91112 global llm_config_from_yaml
92113 # Pop out --config and its value so that they are not parsed by
93- # Hyra 's main.
114+ # Hydra 's main.
94115 config_file_path = pop_config_arg ()
95- default_llm_config = LlmConfig ()
96- # Construct the LlmConfig from the config yaml file.
97- default_llm_config = LlmConfig ()
98- from_yaml = OmegaConf .load (config_file_path )
99- llm_config_from_yaml = OmegaConf .merge (default_llm_config , from_yaml )
116+
117+ # Add hydra config_path and config_name arguments to sys.argv.
118+ add_hydra_config_args (config_file_path )
119+
120+ # # Construct the LlmConfig from the config yaml file.
121+ # default_llm_config = LlmConfig()
122+ # from_yaml = OmegaConf.load(config_file_path)
123+ # llm_config_from_yaml = OmegaConf.merge(default_llm_config, from_yaml)
124+
100125 hydra_main ()
101126
102127
0 commit comments