Commit 99bc44d

jackzhxng and hinriksnaer authored and committed
Allow Hydra CLI to override yaml config (pytorch#11926)
Custom workaround to allow combining `--config` and Hydra CLI options. Something like this would be possible:

```
python -m extension.llm.export.export_llm --config llama_xnnpack.yaml export.max_seq_length=1024 backend.xnnpack.extended_ops=True
```
1 parent 059752e commit 99bc44d

3 files changed (+101, -120 lines)

extension/llm/export/README.md

Lines changed: 14 additions & 46 deletions
````diff
@@ -23,9 +23,9 @@ The LLM export process transforms a model from its original format to an optimized
 
 ## Usage
 
-The export API supports two configuration approaches:
+The export API supports a Hydra-style CLI where you can configure the export using a YAML file and/or CLI args.
 
-### Option 1: Hydra CLI Arguments
+### Hydra CLI Arguments
 
 Use structured configuration arguments directly on the command line:
 
@@ -41,7 +41,7 @@ python -m extension.llm.export.export_llm \
   quantization.qmode=8da4w
 ```
 
-### Option 2: Configuration File
+### Configuration File
 
 Create a YAML configuration file and reference it:
 
@@ -78,53 +78,21 @@ debug:
   verbose: true
 ```
 
-**Important**: You cannot mix both approaches. Use either CLI arguments OR a config file, not both.
+You can also still provide additional overrides using CLI args:
 
-## Example Commands
-
-### Export Qwen3 0.6B with XNNPACK backend and quantization
 ```bash
-python -m extension.llm.export.export_llm \
-  base.model_class=qwen3_0_6b \
-  base.params=examples/models/qwen3/0_6b_config.json \
-  base.metadata='{"get_bos_id": 151644, "get_eos_ids":[151645]}' \
-  model.use_kv_cache=true \
-  model.use_sdpa_with_kv_cache=true \
-  model.dtype_override=FP32 \
-  export.max_seq_length=512 \
-  export.output_name=qwen3_0_6b.pte \
-  quantization.qmode=8da4w \
-  backend.xnnpack.enabled=true \
-  backend.xnnpack.extended_ops=true \
-  debug.verbose=true
+python -m extension.llm.export.export_llm \
+  --config my_config.yaml \
+  base.model_class="llama2" \
+  +export.max_context_length=1024
 ```
 
-### Export Phi-4-Mini with custom checkpoint
-```bash
-python -m extension.llm.export.export_llm \
-  base.model_class=phi_4_mini \
-  base.checkpoint=/path/to/phi4_checkpoint.pth \
-  base.params=examples/models/phi-4-mini/config.json \
-  base.metadata='{"get_bos_id":151643, "get_eos_ids":[151643]}' \
-  model.use_kv_cache=true \
-  model.use_sdpa_with_kv_cache=true \
-  export.max_seq_length=256 \
-  export.output_name=phi4_mini.pte \
-  backend.xnnpack.enabled=true \
-  debug.verbose=true
-```
+Note that if a config file is specified and you want to specify a CLI arg that is not in the config, you need to prepend it with a `+`. You can read more about this in the Hydra [docs](https://hydra.cc/docs/advanced/override_grammar/basic/).
 
-### Export with CoreML backend (iOS optimization)
-```bash
-python -m extension.llm.export.export_llm \
-  base.model_class=llama3 \
-  model.use_kv_cache=true \
-  export.max_seq_length=128 \
-  backend.coreml.enabled=true \
-  backend.coreml.compute_units=ALL \
-  quantization.pt2e_quantize=coreml_c4w \
-  debug.verbose=true
-```
+
+## Example Commands
+
+Please refer to the docs for some of our example supported models ([Llama](https://github.com/pytorch/executorch/blob/main/examples/models/llama/README.md), [Qwen3](https://github.com/pytorch/executorch/tree/main/examples/models/qwen3/README.md), [Phi-4-mini](https://github.com/pytorch/executorch/tree/main/examples/models/phi_4_mini/README.md)).
 
 ## Configuration Options
 
@@ -134,4 +102,4 @@ For a complete reference of all available configuration options, see the [LlmConfig
 
 - [Llama Examples](../../../examples/models/llama/README.md) - Comprehensive Llama export guide
 - [LLM Runner](../runner/) - Running exported models
-- [ExecuTorch Documentation](https://pytorch.org/executorch/) - Framework overview
+- [ExecuTorch Documentation](https://pytorch.org/executorch/) - Framework overview
````
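
To make the override grammar the new README references concrete, here is a minimal standalone Hydra app. This is an illustrative sketch, not code from this commit: `app.py` and `my_config.yaml` are hypothetical names, and the config contents are invented.

```python
# app.py - minimal Hydra app demonstrating plain vs. `+`-prefixed overrides.
# Assumes a my_config.yaml next to this file, e.g.:
#   export:
#     max_seq_length: 512
import hydra
from omegaconf import DictConfig, OmegaConf


@hydra.main(version_base=None, config_path=".", config_name="my_config")
def app(cfg: DictConfig) -> None:
    # Print the final configuration after Hydra applies CLI overrides.
    print(OmegaConf.to_yaml(cfg))


if __name__ == "__main__":
    app()
```

Running `python app.py export.max_seq_length=1024` overrides a key that already exists in the YAML, while `python app.py +export.max_context_length=1024` needs the `+` prefix because the key is absent from the config, which is exactly the rule the updated README describes for `export_llm`.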

extension/llm/export/export_llm.py

Lines changed: 31 additions & 18 deletions
````diff
@@ -30,6 +30,7 @@
 """
 
 import argparse
+import os
 import sys
 from typing import Any, List, Tuple
 
@@ -45,7 +46,6 @@
 
 
 def parse_config_arg() -> Tuple[str, List[Any]]:
-    """First parse out the arg for whether to use Hydra or the old CLI."""
     parser = argparse.ArgumentParser(add_help=True)
     parser.add_argument("--config", type=str, help="Path to the LlmConfig file")
     args, remaining = parser.parse_known_args()
@@ -56,37 +56,50 @@ def pop_config_arg() -> str:
     """
     Removes '--config' and its value from sys.argv.
     Assumes --config is specified and argparse has already validated the args.
+    Returns the config file path.
     """
     idx = sys.argv.index("--config")
     value = sys.argv[idx + 1]
     del sys.argv[idx : idx + 2]
     return value
 
 
-@hydra.main(version_base=None, config_name="llm_config")
+def add_hydra_config_args(config_file_path: str) -> None:
+    """
+    Breaks down the config file path into directory and filename,
+    resolves the directory to an absolute path, and adds the
+    --config-path and --config-name arguments to sys.argv.
+    """
+    config_dir = os.path.dirname(config_file_path)
+    config_name = os.path.basename(config_file_path)
+
+    # Resolve to absolute path
+    config_dir_abs = os.path.abspath(config_dir)
+
+    # Add the hydra config arguments to sys.argv
+    sys.argv.extend(["--config-path", config_dir_abs, "--config-name", config_name])
+
+
+@hydra.main(version_base=None, config_name="llm_config", config_path=None)
 def hydra_main(llm_config: LlmConfig) -> None:
-    export_llama(OmegaConf.to_object(llm_config))
+    structured = OmegaConf.structured(LlmConfig)
+    merged = OmegaConf.merge(structured, llm_config)
+    llm_config_obj = OmegaConf.to_object(merged)
+    export_llama(llm_config_obj)
 
 
 def main() -> None:
+    # First parse out the arg for whether to use Hydra or the old CLI.
     config, remaining_args = parse_config_arg()
     if config:
-        # Check if there are any remaining hydra CLI args when --config is specified
-        # This might change in the future to allow overriding config file values
-        if remaining_args:
-            raise ValueError(
-                "Cannot specify additional CLI arguments when using --config. "
-                f"Found: {remaining_args}. Use either --config file or hydra CLI args, not both."
-            )
-
+        # Pop out --config and its value so that they are not parsed by
+        # Hydra's main.
         config_file_path = pop_config_arg()
-        default_llm_config = LlmConfig()
-        llm_config_from_file = OmegaConf.load(config_file_path)
-        # Override defaults with values specified in the .yaml provided by --config.
-        merged_llm_config = OmegaConf.merge(default_llm_config, llm_config_from_file)
-        export_llama(merged_llm_config)
-    else:
-        hydra_main()
+
+        # Add hydra config-path and config-name arguments to sys.argv.
+        add_hydra_config_args(config_file_path)
+
+    hydra_main()
 
 
 if __name__ == "__main__":
````
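
The notable change in `hydra_main` is that the incoming `DictConfig` is merged into `OmegaConf.structured(LlmConfig)` before being converted back to an object, so values from the YAML and CLI are validated against the schema and any fields they do not mention keep their defaults. A minimal sketch of that pattern, using a simplified stand-in for `LlmConfig` (the real class in executorch has many more fields):

```python
from dataclasses import dataclass, field

from omegaconf import OmegaConf


@dataclass
class ExportConfig:
    max_seq_length: int = 128
    output_name: str = "model.pte"


@dataclass
class LlmConfig:  # simplified stand-in for the real executorch LlmConfig
    export: ExportConfig = field(default_factory=ExportConfig)


# Pretend this DictConfig arrived from a --config YAML plus CLI overrides.
incoming = OmegaConf.create({"export": {"max_seq_length": 1024}})

structured = OmegaConf.structured(LlmConfig)    # schema with defaults
merged = OmegaConf.merge(structured, incoming)  # overrides win, defaults kept
cfg = OmegaConf.to_object(merged)               # back to a typed LlmConfig

assert cfg.export.max_seq_length == 1024        # from the override
assert cfg.export.output_name == "model.pte"    # default preserved
```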

extension/llm/export/test/test_export_llm.py

Lines changed: 56 additions & 56 deletions
````diff
@@ -21,27 +21,37 @@ class TestExportLlm(unittest.TestCase):
     def test_parse_config_arg_with_config(self) -> None:
         """Test parse_config_arg when --config is provided."""
         # Mock sys.argv to include --config
-        test_argv = ["script.py", "--config", "test_config.yaml", "extra", "args"]
+        test_argv = ["export_llm.py", "--config", "test_config.yaml", "extra", "args"]
         with patch.object(sys, "argv", test_argv):
             config_path, remaining = parse_config_arg()
             self.assertEqual(config_path, "test_config.yaml")
             self.assertEqual(remaining, ["extra", "args"])
 
     def test_parse_config_arg_without_config(self) -> None:
         """Test parse_config_arg when --config is not provided."""
-        test_argv = ["script.py", "debug.verbose=True"]
+        test_argv = ["export_llm.py", "debug.verbose=True"]
         with patch.object(sys, "argv", test_argv):
             config_path, remaining = parse_config_arg()
             self.assertIsNone(config_path)
             self.assertEqual(remaining, ["debug.verbose=True"])
 
     def test_pop_config_arg(self) -> None:
         """Test pop_config_arg removes --config and its value from sys.argv."""
-        test_argv = ["script.py", "--config", "test_config.yaml", "other", "args"]
+        test_argv = ["export_llm.py", "--config", "test_config.yaml", "other", "args"]
         with patch.object(sys, "argv", test_argv):
             config_path = pop_config_arg()
             self.assertEqual(config_path, "test_config.yaml")
-            self.assertEqual(sys.argv, ["script.py", "other", "args"])
+            self.assertEqual(sys.argv, ["export_llm.py", "other", "args"])
+
+    def test_with_cli_args(self) -> None:
+        """Test main function with only hydra CLI args."""
+        test_argv = ["export_llm.py", "debug.verbose=True"]
+        with patch.object(sys, "argv", test_argv):
+            with patch(
+                "executorch.extension.llm.export.export_llm.hydra_main"
+            ) as mock_hydra:
+                main()
+                mock_hydra.assert_called_once()
 
     @patch("executorch.extension.llm.export.export_llm.export_llama")
     def test_with_config(self, mock_export_llama: MagicMock) -> None:
@@ -70,83 +80,73 @@ def test_with_config(self, mock_export_llama: MagicMock) -> None:
             config_file = f.name
 
         try:
-            test_argv = ["script.py", "--config", config_file]
+            test_argv = ["export_llm.py", "--config", config_file]
             with patch.object(sys, "argv", test_argv):
                 main()
 
             # Verify export_llama was called with config
             mock_export_llama.assert_called_once()
             called_config = mock_export_llama.call_args[0][0]
             self.assertEqual(
-                called_config["base"]["tokenizer_path"], "/path/to/tokenizer.json"
-            )
-            self.assertEqual(called_config["base"]["model_class"], "llama2")
-            self.assertEqual(called_config["base"]["preq_mode"].value, "8da4w")
-            self.assertEqual(called_config["model"]["dtype_override"].value, "fp16")
-            self.assertEqual(called_config["export"]["max_seq_length"], 256)
-            self.assertEqual(
-                called_config["quantization"]["pt2e_quantize"].value, "xnnpack_dynamic"
-            )
-            self.assertEqual(
-                called_config["quantization"]["use_spin_quant"].value, "cuda"
+                called_config.base.tokenizer_path, "/path/to/tokenizer.json"
             )
+            self.assertEqual(called_config.base.model_class, "llama2")
+            self.assertEqual(called_config.base.preq_mode.value, "8da4w")
+            self.assertEqual(called_config.model.dtype_override.value, "fp16")
+            self.assertEqual(called_config.export.max_seq_length, 256)
             self.assertEqual(
-                called_config["backend"]["coreml"]["quantize"].value, "c4w"
+                called_config.quantization.pt2e_quantize.value, "xnnpack_dynamic"
             )
+            self.assertEqual(called_config.quantization.use_spin_quant.value, "cuda")
+            self.assertEqual(called_config.backend.coreml.quantize.value, "c4w")
             self.assertEqual(
-                called_config["backend"]["coreml"]["compute_units"].value, "cpu_and_gpu"
+                called_config.backend.coreml.compute_units.value, "cpu_and_gpu"
             )
         finally:
             os.unlink(config_file)
 
-    def test_with_cli_args(self) -> None:
-        """Test main function with only hydra CLI args."""
-        test_argv = ["script.py", "debug.verbose=True"]
-        with patch.object(sys, "argv", test_argv):
-            with patch(
-                "executorch.extension.llm.export.export_llm.hydra_main"
-            ) as mock_hydra:
-                main()
-                mock_hydra.assert_called_once()
-
-    def test_config_with_cli_args_error(self) -> None:
-        """Test that --config rejects additional CLI arguments to prevent mixing approaches."""
+    @patch("executorch.extension.llm.export.export_llm.export_llama")
+    def test_with_config_and_cli(self, mock_export_llama: MagicMock) -> None:
+        """Test main function with a --config file plus additional hydra CLI overrides."""
         # Create a temporary config file
         with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
-            f.write("base:\n  checkpoint: /path/to/checkpoint.pth")
-            config_file = f.name
-
-        try:
-            test_argv = ["script.py", "--config", config_file, "debug.verbose=True"]
-            with patch.object(sys, "argv", test_argv):
-                with self.assertRaises(ValueError) as cm:
-                    main()
-
-            error_msg = str(cm.exception)
-            self.assertIn(
-                "Cannot specify additional CLI arguments when using --config",
-                error_msg,
-            )
-        finally:
-            os.unlink(config_file)
-
-    def test_config_rejects_multiple_cli_args(self) -> None:
-        """Test that --config rejects multiple CLI arguments (not just single ones)."""
-        with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
-            f.write("export:\n  max_seq_length: 128")
+            f.write(
+                """
+base:
+  model_class: llama2
+model:
+  dtype_override: fp16
+backend:
+  xnnpack:
+    enabled: False
+"""
+            )
             config_file = f.name
 
         try:
             test_argv = [
-                "script.py",
+                "export_llm.py",
                 "--config",
                 config_file,
-                "debug.verbose=True",
-                "export.output_dir=/tmp",
+                "base.model_class=stories110m",
+                "backend.xnnpack.enabled=True",
             ]
             with patch.object(sys, "argv", test_argv):
-                with self.assertRaises(ValueError):
-                    main()
+                main()
+
+            # Verify export_llama was called with config
+            mock_export_llama.assert_called_once()
+            called_config = mock_export_llama.call_args[0][0]
+            self.assertEqual(
+                called_config.base.model_class, "stories110m"
+            )  # Override from CLI.
+            self.assertEqual(
+                called_config.model.dtype_override.value, "fp16"
+            )  # From yaml.
+            self.assertEqual(
+                called_config.backend.xnnpack.enabled,
+                True,  # Override from CLI.
+            )
         finally:
             os.unlink(config_file)
 
````
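
A side note on the assertions above: they compare `called_config.model.dtype_override.value` rather than the field itself because, after `OmegaConf.to_object`, enum-typed fields come back as Python `Enum` members. A small sketch with a simplified stand-in schema (the real enums live in executorch's `LlmConfig`):

```python
from dataclasses import dataclass
from enum import Enum

from omegaconf import OmegaConf


class DtypeOverride(Enum):  # simplified stand-in enum
    fp32 = "fp32"
    fp16 = "fp16"


@dataclass
class ModelConfig:
    dtype_override: DtypeOverride = DtypeOverride.fp32


schema = OmegaConf.structured(ModelConfig)
merged = OmegaConf.merge(schema, OmegaConf.create({"dtype_override": "fp16"}))
model = OmegaConf.to_object(merged)

assert model.dtype_override is DtypeOverride.fp16  # an Enum member, not a str
assert model.dtype_override.value == "fp16"        # hence the `.value` in the tests
```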
