Skip to content

Commit 26078ae

Browse files
committed
Try splitting config into path and name
1 parent c6362fb commit 26078ae

File tree

1 file changed

+28
-24
lines changed

1 file changed

+28
-24
lines changed

extension/llm/export/export_llm.py

Lines changed: 28 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
"""
3131

3232
import argparse
33+
import os
3334
import sys
3435
from typing import Any, List, Tuple
3536

@@ -45,11 +46,6 @@
4546
cs.store(name="llm_config", node=LlmConfig)
4647

4748

48-
# Need this global variable to pass an llm_config from yaml
49-
# into the hydra-wrapped main function.
50-
llm_config_from_yaml = None
51-
52-
5349
def parse_config_arg() -> Tuple[str, List[Any]]:
5450
parser = argparse.ArgumentParser(add_help=True)
5551
parser.add_argument("--config", type=str, help="Path to the LlmConfig file")
@@ -61,27 +57,36 @@ def pop_config_arg() -> str:
6157
"""
6258
Removes '--config' and its value from sys.argv.
6359
Assumes --config is specified and argparse has already validated the args.
60+
Returns the config file path.
6461
"""
6562
idx = sys.argv.index("--config")
6663
value = sys.argv[idx + 1]
6764
del sys.argv[idx : idx + 2]
6865
return value
6966

7067

71-
@hydra.main(version_base=None, config_name="llm_config")
72-
def hydra_main(llm_config: LlmConfig) -> None:
73-
global llm_config_from_yaml
68+
def add_hydra_config_args(config_file_path: str) -> None:
69+
"""
70+
Breaks down the config file path into directory and filename,
71+
resolves the directory to an absolute path, and adds the
72+
--config_path and --config_name arguments to sys.argv.
73+
"""
74+
config_dir = os.path.dirname(config_file_path)
75+
config_name = os.path.basename(config_file_path)
76+
77+
# Resolve to absolute path
78+
config_dir_abs = os.path.abspath(config_dir)
79+
80+
# Add the hydra config arguments to sys.argv
81+
sys.argv.extend(["--config-path", config_dir_abs, "--config-name", config_name])
82+
7483

75-
# Override the LlmConfig constructed from the provide yaml config file
76-
# with the CLI overrides.
77-
if llm_config_from_yaml:
78-
# Get CLI overrides (excluding defaults list).
79-
overrides_list: List[str] = list(HydraConfig.get().overrides.get("task", []))
80-
override_cfg = OmegaConf.from_dotlist(overrides_list)
81-
merged_config = OmegaConf.merge(llm_config_from_yaml, override_cfg)
82-
export_llama(merged_config)
83-
else:
84-
export_llama(OmegaConf.to_object(llm_config))
84+
@hydra.main(version_base=None, config_name="llm_config", config_path=None)
85+
def hydra_main(llm_config: LlmConfig) -> None:
86+
structured = OmegaConf.structured(LlmConfig)
87+
merged = OmegaConf.merge(structured, llm_config)
88+
llm_config_obj = OmegaConf.to_object(merged)
89+
export_llama(llm_config_obj)
8590

8691

8792
def main() -> None:
@@ -90,13 +95,12 @@ def main() -> None:
9095
if config:
9196
global llm_config_from_yaml
9297
# Pop out --config and its value so that they are not parsed by
93-
# Hyra's main.
98+
# Hydra's main.
9499
config_file_path = pop_config_arg()
95-
default_llm_config = LlmConfig()
96-
# Construct the LlmConfig from the config yaml file.
97-
default_llm_config = LlmConfig()
98-
from_yaml = OmegaConf.load(config_file_path)
99-
llm_config_from_yaml = OmegaConf.merge(default_llm_config, from_yaml)
100+
101+
# Add hydra config_path and config_name arguments to sys.argv.
102+
add_hydra_config_args(config_file_path)
103+
100104
hydra_main()
101105

102106

0 commit comments

Comments
 (0)