Skip to content

Commit e2bcd13

Browse files
committed
Try splitting config into path and name
1 parent c6362fb commit e2bcd13

File tree

1 file changed

+43
-18
lines changed

1 file changed

+43
-18
lines changed

extension/llm/export/export_llm.py

Lines changed: 43 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
"""
3131

3232
import argparse
33+
import os
3334
import sys
3435
from typing import Any, List, Tuple
3536

@@ -61,27 +62,47 @@ def pop_config_arg() -> str:
6162
"""
6263
Removes '--config' and its value from sys.argv.
6364
Assumes --config is specified and argparse has already validated the args.
65+
Returns the config file path.
6466
"""
6567
idx = sys.argv.index("--config")
6668
value = sys.argv[idx + 1]
6769
del sys.argv[idx : idx + 2]
6870
return value
6971

7072

71-
@hydra.main(version_base=None, config_name="llm_config")
73+
def add_hydra_config_args(config_file_path: str) -> None:
74+
"""
75+
Breaks down the config file path into directory and filename,
76+
resolves the directory to an absolute path, and adds the
77+
--config_path and --config_name arguments to sys.argv.
78+
"""
79+
config_dir = os.path.dirname(config_file_path)
80+
config_name = os.path.basename(config_file_path)
81+
82+
# Resolve to absolute path
83+
config_dir_abs = os.path.abspath(config_dir)
84+
85+
# Add the hydra config arguments to sys.argv
86+
sys.argv.extend(["--config-path", config_dir_abs, "--config-name", config_name])
87+
88+
89+
@hydra.main(version_base=None, config_name="llm_config", config_path=None)
7290
def hydra_main(llm_config: LlmConfig) -> None:
73-
global llm_config_from_yaml
91+
# global llm_config_from_yaml
92+
93+
# # Override the LlmConfig constructed from the provide yaml config file
94+
# # with the CLI overrides.
95+
# if llm_config_from_yaml:
96+
# # Get CLI overrides (excluding defaults list).
97+
# overrides_list: List[str] = list(HydraConfig.get().overrides.get("task", []))
98+
# override_cfg = OmegaConf.from_dotlist(overrides_list)
99+
# merged_config = OmegaConf.merge(llm_config_from_yaml, override_cfg)
100+
# export_llama(merged_config)
101+
# else:
102+
# export_llama(OmegaConf.to_object(llm_config))
74103

75-
# Override the LlmConfig constructed from the provide yaml config file
76-
# with the CLI overrides.
77-
if llm_config_from_yaml:
78-
# Get CLI overrides (excluding defaults list).
79-
overrides_list: List[str] = list(HydraConfig.get().overrides.get("task", []))
80-
override_cfg = OmegaConf.from_dotlist(overrides_list)
81-
merged_config = OmegaConf.merge(llm_config_from_yaml, override_cfg)
82-
export_llama(merged_config)
83-
else:
84-
export_llama(OmegaConf.to_object(llm_config))
104+
breakpoint()
105+
export_llama(OmegaConf.to_object(llm_config))
85106

86107

87108
def main() -> None:
@@ -90,13 +111,17 @@ def main() -> None:
90111
if config:
91112
global llm_config_from_yaml
92113
# Pop out --config and its value so that they are not parsed by
93-
# Hyra's main.
114+
# Hydra's main.
94115
config_file_path = pop_config_arg()
95-
default_llm_config = LlmConfig()
96-
# Construct the LlmConfig from the config yaml file.
97-
default_llm_config = LlmConfig()
98-
from_yaml = OmegaConf.load(config_file_path)
99-
llm_config_from_yaml = OmegaConf.merge(default_llm_config, from_yaml)
116+
117+
# Add hydra config_path and config_name arguments to sys.argv.
118+
add_hydra_config_args(config_file_path)
119+
120+
# # Construct the LlmConfig from the config yaml file.
121+
# default_llm_config = LlmConfig()
122+
# from_yaml = OmegaConf.load(config_file_path)
123+
# llm_config_from_yaml = OmegaConf.merge(default_llm_config, from_yaml)
124+
100125
hydra_main()
101126

102127

0 commit comments

Comments
 (0)