Skip to content

Commit 7f69bc0

Browse files
committed
Allow CLI overrides
1 parent 752f6a7 commit 7f69bc0

File tree

2 files changed

+75
-59
lines changed

2 files changed

+75
-59
lines changed

extension/llm/export/export_llm.py

Lines changed: 28 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -38,14 +38,20 @@
3838
from executorch.examples.models.llama.config.llm_config import LlmConfig
3939
from executorch.examples.models.llama.export_llama_lib import export_llama
4040
from hydra.core.config_store import ConfigStore
41+
from hydra.core.hydra_config import HydraConfig
42+
from hydra.core.override_parser.overrides_parser import OverridesParser
4143
from omegaconf import OmegaConf
4244

4345
cs = ConfigStore.instance()
4446
cs.store(name="llm_config", node=LlmConfig)
4547

4648

49+
# Need this global variable to pass an llm_config from yaml
50+
# into the hydra-wrapped main function.
51+
llm_config_from_yaml = None
52+
53+
4754
def parse_config_arg() -> Tuple[str, List[Any]]:
48-
"""First parse out the arg for whether to use Hydra or the old CLI."""
4955
parser = argparse.ArgumentParser(add_help=True)
5056
parser.add_argument("--config", type=str, help="Path to the LlmConfig file")
5157
args, remaining = parser.parse_known_args()
@@ -65,28 +71,34 @@ def pop_config_arg() -> str:
6571

6672
@hydra.main(version_base=None, config_name="llm_config")
6773
def hydra_main(llm_config: LlmConfig) -> None:
68-
export_llama(OmegaConf.to_object(llm_config))
74+
global llm_config_from_yaml
75+
76+
# Override the LlmConfig constructed from the provided yaml config file
77+
# with the CLI overrides.
78+
if llm_config_from_yaml:
79+
# Get CLI overrides (excluding defaults list).
80+
overrides_list: List[str] = list(HydraConfig.get().overrides.get("task", []))
81+
override_cfg = OmegaConf.from_dotlist(overrides_list)
82+
merged_config = OmegaConf.merge(llm_config_from_yaml, override_cfg)
83+
export_llama(merged_config)
84+
else:
85+
export_llama(OmegaConf.to_object(llm_config))
6986

7087

7188
def main() -> None:
89+
# First parse out the arg for whether to use Hydra or the old CLI.
7290
config, remaining_args = parse_config_arg()
7391
if config:
74-
# Check if there are any remaining hydra CLI args when --config is specified
75-
# This might change in the future to allow overriding config file values
76-
if remaining_args:
77-
raise ValueError(
78-
"Cannot specify additional CLI arguments when using --config. "
79-
f"Found: {remaining_args}. Use either --config file or hydra CLI args, not both."
80-
)
81-
92+
global llm_config_from_yaml
93+
# Pop out --config and its value so that they are not parsed by
94+
# Hydra's main.
8295
config_file_path = pop_config_arg()
8396
default_llm_config = LlmConfig()
84-
llm_config_from_file = OmegaConf.load(config_file_path)
85-
# Override defaults with values specified in the .yaml provided by --config.
86-
merged_llm_config = OmegaConf.merge(default_llm_config, llm_config_from_file)
87-
export_llama(merged_llm_config)
88-
else:
89-
hydra_main()
97+
# Construct the LlmConfig from the config yaml file.
98+
default_llm_config = LlmConfig()
99+
from_yaml = OmegaConf.load(config_file_path)
100+
llm_config_from_yaml = OmegaConf.merge(default_llm_config, from_yaml)
101+
hydra_main()
90102

91103

92104
if __name__ == "__main__":

extension/llm/export/test/test_export_llm.py

Lines changed: 47 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -21,27 +21,37 @@ class TestExportLlm(unittest.TestCase):
2121
def test_parse_config_arg_with_config(self) -> None:
2222
"""Test parse_config_arg when --config is provided."""
2323
# Mock sys.argv to include --config
24-
test_argv = ["script.py", "--config", "test_config.yaml", "extra", "args"]
24+
test_argv = ["export_llm.py", "--config", "test_config.yaml", "extra", "args"]
2525
with patch.object(sys, "argv", test_argv):
2626
config_path, remaining = parse_config_arg()
2727
self.assertEqual(config_path, "test_config.yaml")
2828
self.assertEqual(remaining, ["extra", "args"])
2929

3030
def test_parse_config_arg_without_config(self) -> None:
3131
"""Test parse_config_arg when --config is not provided."""
32-
test_argv = ["script.py", "debug.verbose=True"]
32+
test_argv = ["export_llm.py", "debug.verbose=True"]
3333
with patch.object(sys, "argv", test_argv):
3434
config_path, remaining = parse_config_arg()
3535
self.assertIsNone(config_path)
3636
self.assertEqual(remaining, ["debug.verbose=True"])
3737

3838
def test_pop_config_arg(self) -> None:
3939
"""Test pop_config_arg removes --config and its value from sys.argv."""
40-
test_argv = ["script.py", "--config", "test_config.yaml", "other", "args"]
40+
test_argv = ["export_llm.py", "--config", "test_config.yaml", "other", "args"]
4141
with patch.object(sys, "argv", test_argv):
4242
config_path = pop_config_arg()
4343
self.assertEqual(config_path, "test_config.yaml")
44-
self.assertEqual(sys.argv, ["script.py", "other", "args"])
44+
self.assertEqual(sys.argv, ["export_llm.py", "other", "args"])
45+
46+
def test_with_cli_args(self) -> None:
47+
"""Test main function with only hydra CLI args."""
48+
test_argv = ["export_llm.py", "debug.verbose=True"]
49+
with patch.object(sys, "argv", test_argv):
50+
with patch(
51+
"executorch.extension.llm.export.export_llm.hydra_main"
52+
) as mock_hydra:
53+
main()
54+
mock_hydra.assert_called_once()
4555

4656
@patch("executorch.extension.llm.export.export_llm.export_llama")
4757
def test_with_config(self, mock_export_llama: MagicMock) -> None:
@@ -70,7 +80,7 @@ def test_with_config(self, mock_export_llama: MagicMock) -> None:
7080
config_file = f.name
7181

7282
try:
73-
test_argv = ["script.py", "--config", config_file]
83+
test_argv = ["export_llm.py", "--config", config_file]
7484
with patch.object(sys, "argv", test_argv):
7585
main()
7686

@@ -99,54 +109,48 @@ def test_with_config(self, mock_export_llama: MagicMock) -> None:
99109
finally:
100110
os.unlink(config_file)
101111

102-
def test_with_cli_args(self) -> None:
103-
"""Test main function with only hydra CLI args."""
104-
test_argv = ["script.py", "debug.verbose=True"]
105-
with patch.object(sys, "argv", test_argv):
106-
with patch(
107-
"executorch.extension.llm.export.export_llm.hydra_main"
108-
) as mock_hydra:
109-
main()
110-
mock_hydra.assert_called_once()
111-
112-
def test_config_with_cli_args_error(self) -> None:
113-
"""Test that --config rejects additional CLI arguments to prevent mixing approaches."""
112+
@patch("executorch.extension.llm.export.export_llm.export_llama")
113+
def test_with_config(self, mock_export_llama: MagicMock) -> None:
114+
"""Test main function with --config file and no hydra args."""
114115
# Create a temporary config file
115116
with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
116-
f.write("base:\n checkpoint: /path/to/checkpoint.pth")
117-
config_file = f.name
118-
119-
try:
120-
test_argv = ["script.py", "--config", config_file, "debug.verbose=True"]
121-
with patch.object(sys, "argv", test_argv):
122-
with self.assertRaises(ValueError) as cm:
123-
main()
124-
125-
error_msg = str(cm.exception)
126-
self.assertIn(
127-
"Cannot specify additional CLI arguments when using --config",
128-
error_msg,
129-
)
130-
finally:
131-
os.unlink(config_file)
132-
133-
def test_config_rejects_multiple_cli_args(self) -> None:
134-
"""Test that --config rejects multiple CLI arguments (not just single ones)."""
135-
with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
136-
f.write("export:\n max_seq_length: 128")
117+
f.write(
118+
"""
119+
base:
120+
model_class: llama2
121+
model:
122+
dtype_override: fp16
123+
backend:
124+
xnnpack:
125+
enabled: False
126+
"""
127+
)
137128
config_file = f.name
138129

139130
try:
140131
test_argv = [
141-
"script.py",
132+
"export_llm.py",
142133
"--config",
143134
config_file,
144-
"debug.verbose=True",
145-
"export.output_dir=/tmp",
135+
"base.model_class=stories110m",
136+
"backend.xnnpack.enabled=True",
146137
]
147138
with patch.object(sys, "argv", test_argv):
148-
with self.assertRaises(ValueError):
149-
main()
139+
main()
140+
141+
# Verify export_llama was called with config
142+
mock_export_llama.assert_called_once()
143+
called_config = mock_export_llama.call_args[0][0]
144+
self.assertEqual(
145+
called_config["base"]["model_class"], "stories110m"
146+
) # Override from CLI.
147+
self.assertEqual(
148+
called_config["model"]["dtype_override"].value, "fp16"
149+
) # From yaml.
150+
self.assertEqual(
151+
called_config["backend"]["xnnpack"]["enabled"],
152+
True, # Override from CLI.
153+
)
150154
finally:
151155
os.unlink(config_file)
152156

0 commit comments

Comments (0)