Commit a147d30

Ability to specify full file configs from anywhere
ghstack-source-id: ceeb670
ghstack-comment-id: 2986530890
Pull-Request: pytorch/executorch#11809
1 parent a1dec07 commit a147d30

File tree

6 files changed: +327 -6 lines changed


examples/models/llama/config/llm_config.py

Lines changed: 3 additions & 1 deletion
@@ -65,7 +65,9 @@ class BaseConfig:
     params: Model parameters, such as n_layers, hidden_size, etc.
         If left empty will use defaults specified in model_args.py.
     checkpoint: Path to the checkpoint file.
-        If left empty, the model will be initialized with random weights.
+        If left empty, the model will either be initialized with random weights
+        if it is a Llama model or the weights will be downloaded from HuggingFace
+        if it is a non-Llama model.
     checkpoint_dir: Path to directory containing sharded checkpoint files.
     tokenizer_path: Path to the tokenizer file.
     metadata: Json string containing metadata information.

examples/models/llama/export_llama_lib.py

Lines changed: 6 additions & 2 deletions
@@ -53,6 +53,8 @@
 )
 from executorch.util.activation_memory_profiler import generate_memory_trace
 
+from omegaconf import DictConfig
+
 from ..model_factory import EagerModelFactory
 from .source_transformation.apply_spin_quant_r1_r2 import (
     fuse_layer_norms,
@@ -571,12 +573,14 @@ def canonical_path(path: Union[str, Path], *, dir: bool = False) -> str:
 
 
 def export_llama(
-    export_options: Union[argparse.Namespace, LlmConfig],
+    export_options: Union[argparse.Namespace, LlmConfig, DictConfig],
 ) -> str:
     if isinstance(export_options, argparse.Namespace):
         # Legacy CLI.
         llm_config = LlmConfig.from_args(export_options)
-    elif isinstance(export_options, LlmConfig):
+    elif isinstance(export_options, LlmConfig) or isinstance(
+        export_options, DictConfig
+    ):
         # Hydra CLI.
         llm_config = export_options
     else:
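
With this change, `export_llama` also accepts an OmegaConf `DictConfig` alongside `argparse.Namespace` and `LlmConfig`. A minimal sketch of what that enables when calling it programmatically (the YAML path is hypothetical and its keys are assumed to follow the `LlmConfig` schema; the `export_llm` entry point further below additionally merges the file over `LlmConfig` defaults first):

```python
from omegaconf import OmegaConf

from executorch.examples.models.llama.export_llama_lib import export_llama

# OmegaConf.load returns a DictConfig, which export_llama now routes through
# the same branch as an LlmConfig coming from the Hydra CLI.
config = OmegaConf.load("my_config.yaml")  # hypothetical config file
output_path = export_llama(config)  # returns a str, per the signature above
```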

extension/llm/export/README.md

Lines changed: 145 additions & 0 deletions
@@ -0,0 +1,145 @@
# LLM Export API

This directory contains the unified API for exporting Large Language Models (LLMs) to ExecuTorch. The `export_llm` module provides a streamlined interface to convert various LLM architectures to optimized `.pte` files for on-device inference.

## Overview

The LLM export process transforms a model from its original format to an optimized representation suitable for mobile and edge devices. This involves several key steps:

1. **Model Instantiation**: Load the model architecture and weights from sources like Hugging Face
2. **Source Transformations**: Apply model-specific optimizations and quantization
3. **IR Export**: Convert to intermediate representations (EXIR, Edge dialect)
4. **Graph Transformations**: Apply backend-specific optimizations and PT2E quantization
5. **Backend Delegation**: Partition operations to hardware-specific backends (XNNPACK, CoreML, QNN, etc.)
6. **Serialization**: Export to final ExecuTorch `.pte` format

## Supported Models

- **Llama**: Llama 2, Llama 3, Llama 3.1, Llama 3.2 (1B, 3B, 8B variants)
- **Qwen**: Qwen 2.5, Qwen 3 (0.6B, 1.7B, 4B variants)
- **Phi**: Phi-3-Mini, Phi-4-Mini
- **Stories**: Stories110M (educational model)
- **SmolLM**: SmolLM2

## Installation

First, install the required dependencies:

```bash
./extension/llm/install_requirements.sh
```

## Usage

The export API supports two configuration approaches:

### Option 1: Hydra CLI Arguments

Use structured configuration arguments directly on the command line:

```bash
python -m extension.llm.export.export_llm \
    base.model_class=llama3 \
    model.use_sdpa_with_kv_cache=True \
    model.use_kv_cache=True \
    export.max_seq_length=128 \
    debug.verbose=True \
    backend.xnnpack.enabled=True \
    backend.xnnpack.extended_ops=True \
    quantization.qmode=8da4w
```

### Option 2: Configuration File

Create a YAML configuration file and reference it:

```bash
python -m extension.llm.export.export_llm --config my_config.yaml
```

Example `my_config.yaml`:

```yaml
base:
  model_class: llama3
  tokenizer_path: /path/to/tokenizer.json

model:
  use_kv_cache: true
  use_sdpa_with_kv_cache: true
  enable_dynamic_shape: true

export:
  max_seq_length: 512
  output_dir: ./exported_models
  output_name: llama3_optimized.pte

quantization:
  qmode: 8da4w
  group_size: 32

backend:
  xnnpack:
    enabled: true
    extended_ops: true

debug:
  verbose: true
```

**Important**: You cannot mix both approaches. Use either CLI arguments OR a config file, not both.

## Example Commands

### Export Qwen3 0.6B with XNNPACK backend and quantization

```bash
python -m extension.llm.export.export_llm \
    base.model_class=qwen3-0_6b \
    base.params=examples/models/qwen3/0_6b_config.json \
    base.metadata='{"get_bos_id": 151644, "get_eos_ids":[151645]}' \
    model.use_kv_cache=true \
    model.use_sdpa_with_kv_cache=true \
    model.dtype_override=FP32 \
    export.max_seq_length=512 \
    export.output_name=qwen3_0_6b.pte \
    quantization.qmode=8da4w \
    backend.xnnpack.enabled=true \
    backend.xnnpack.extended_ops=true \
    debug.verbose=true
```

### Export Phi-4-Mini with custom checkpoint

```bash
python -m extension.llm.export.export_llm \
    base.model_class=phi_4_mini \
    base.checkpoint=/path/to/phi4_checkpoint.pth \
    base.params=examples/models/phi-4-mini/config.json \
    base.metadata='{"get_bos_id":151643, "get_eos_ids":[151643]}' \
    model.use_kv_cache=true \
    model.use_sdpa_with_kv_cache=true \
    export.max_seq_length=256 \
    export.output_name=phi4_mini.pte \
    backend.xnnpack.enabled=true \
    debug.verbose=true
```

### Export with CoreML backend (iOS optimization)

```bash
python -m extension.llm.export.export_llm \
    base.model_class=llama3 \
    model.use_kv_cache=true \
    export.max_seq_length=128 \
    backend.coreml.enabled=true \
    backend.coreml.compute_units=ALL \
    quantization.pt2e_quantize=coreml_c4w \
    debug.verbose=true
```

## Configuration Options

For a complete reference of all available configuration options, see the [LlmConfig class definition](../../../examples/models/llama/config/llm_config.py) which documents all supported parameters for base, model, export, quantization, backend, and debug configurations.

## Further Reading

- [Llama Examples](../../../examples/models/llama/README.md) - Comprehensive Llama export guide
- [LLM Runner](../runner/) - Running exported models
- [ExecuTorch Documentation](https://pytorch.org/executorch/) - Framework overview
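
The dotted keys in Option 1 and the nested YAML in Option 2 describe the same `LlmConfig` tree. A small illustrative sketch of that correspondence using plain OmegaConf (this is not how `export_llm` parses its arguments, which goes through Hydra; the values are arbitrary examples):

```python
from omegaconf import OmegaConf

# Dotted keys, as they appear on the Hydra command line ...
from_cli = OmegaConf.from_dotlist(
    ["export.max_seq_length=128", "backend.xnnpack.enabled=True"]
)

# ... express the same nested structure as the YAML config file form.
from_yaml = OmegaConf.create(
    """
export:
  max_seq_length: 128
backend:
  xnnpack:
    enabled: true
"""
)

print(OmegaConf.to_yaml(from_cli))
print(OmegaConf.to_yaml(from_yaml))  # same nested structure either way
```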

extension/llm/export/export_llm.py

Lines changed: 52 additions & 3 deletions
@@ -23,23 +23,72 @@
     backend.xnnpack.enabled=True \
     backend.xnnpack.extended_ops=True \
     quantization.qmode="8da4w"
+
+Example usage using config file:
+python -m extension.llm.export.export_llm \
+    --config example_llm_config.yaml
 """
 
+import argparse
+import sys
+from typing import Any, List, Tuple
+
 import hydra
+import yaml
 
 from executorch.examples.models.llama.config.llm_config import LlmConfig
 from executorch.examples.models.llama.export_llama_lib import export_llama
 from hydra.core.config_store import ConfigStore
-from omegaconf import OmegaConf
+from omegaconf import DictConfig, OmegaConf
 
 cs = ConfigStore.instance()
 cs.store(name="llm_config", node=LlmConfig)
 
 
-@hydra.main(version_base=None, config_path=None, config_name="llm_config")
-def main(llm_config: LlmConfig) -> None:
+def parse_config_arg() -> Tuple[str, List[Any]]:
+    """First parse out the arg for whether to use Hydra or the old CLI."""
+    parser = argparse.ArgumentParser(add_help=True)
+    parser.add_argument("--config", type=str, help="Path to the LlmConfig file")
+    args, remaining = parser.parse_known_args()
+    return args.config, remaining
+
+
+def pop_config_arg() -> str:
+    """
+    Removes '--config' and its value from sys.argv.
+    Assumes --config is specified and argparse has already validated the args.
+    """
+    idx = sys.argv.index("--config")
+    value = sys.argv[idx + 1]
+    del sys.argv[idx : idx + 2]
+    return value
+
+
+@hydra.main(version_base=None, config_name="llm_config")
+def hydra_main(llm_config: LlmConfig) -> None:
     export_llama(OmegaConf.to_object(llm_config))
 
 
+def main() -> None:
+    config, remaining_args = parse_config_arg()
+    if config:
+        # Check if there are any remaining hydra CLI args when --config is specified
+        # This might change in the future to allow overriding config file values
+        if remaining_args:
+            raise ValueError(
+                "Cannot specify additional CLI arguments when using --config. "
+                f"Found: {remaining_args}. Use either --config file or hydra CLI args, not both."
+            )
+
+        config_file_path = pop_config_arg()
+        default_llm_config = LlmConfig()
+        llm_config_from_file = OmegaConf.load(config_file_path)
+        # Override defaults with values specified in the .yaml provided by --config.
+        merged_llm_config = OmegaConf.merge(default_llm_config, llm_config_from_file)
+        export_llama(merged_llm_config)
+    else:
+        hydra_main()
+
+
 if __name__ == "__main__":
     main()
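
The `--config` branch above overlays the YAML file onto the `LlmConfig` defaults rather than replacing them. A minimal sketch of that merge behavior, with `OmegaConf.create` standing in for `OmegaConf.load` on a real file and a hypothetical override value:

```python
from executorch.examples.models.llama.config.llm_config import LlmConfig
from omegaconf import OmegaConf

# Defaults come from the LlmConfig dataclass; the YAML only needs the keys it overrides.
defaults = LlmConfig()
from_file = OmegaConf.create({"export": {"max_seq_length": 256}})  # stand-in for a loaded .yaml

merged = OmegaConf.merge(defaults, from_file)
print(merged.export.max_seq_length)  # 256, taken from the file
print(merged.model.use_kv_cache)  # untouched LlmConfig default
```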
Lines changed: 112 additions & 0 deletions
@@ -0,0 +1,112 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os
import sys
import tempfile
import unittest
from unittest.mock import MagicMock, patch

from executorch.examples.models.llama.config.llm_config import LlmConfig
from executorch.extension.llm.export.export_llm import main, parse_config_arg, pop_config_arg


class TestExportLlm(unittest.TestCase):
    def test_parse_config_arg_with_config(self) -> None:
        """Test parse_config_arg when --config is provided."""
        # Mock sys.argv to include --config
        test_argv = ["script.py", "--config", "test_config.yaml", "extra", "args"]
        with patch.object(sys, "argv", test_argv):
            config_path, remaining = parse_config_arg()
            self.assertEqual(config_path, "test_config.yaml")
            self.assertEqual(remaining, ["extra", "args"])

    def test_parse_config_arg_without_config(self) -> None:
        """Test parse_config_arg when --config is not provided."""
        test_argv = ["script.py", "debug.verbose=True"]
        with patch.object(sys, "argv", test_argv):
            config_path, remaining = parse_config_arg()
            self.assertIsNone(config_path)
            self.assertEqual(remaining, ["debug.verbose=True"])

    def test_pop_config_arg(self) -> None:
        """Test pop_config_arg removes --config and its value from sys.argv."""
        test_argv = ["script.py", "--config", "test_config.yaml", "other", "args"]
        with patch.object(sys, "argv", test_argv):
            config_path = pop_config_arg()
            self.assertEqual(config_path, "test_config.yaml")
            self.assertEqual(sys.argv, ["script.py", "other", "args"])

    @patch("executorch.extension.llm.export.export_llm.export_llama")
    def test_with_config(self, mock_export_llama: MagicMock) -> None:
        """Test main function with --config file and no hydra args."""
        # Create a temporary config file
        with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
            f.write("""
base:
  tokenizer_path: /path/to/tokenizer.json
export:
  max_seq_length: 256
""")
            config_file = f.name

        try:
            test_argv = ["script.py", "--config", config_file]
            with patch.object(sys, "argv", test_argv):
                main()

            # Verify export_llama was called with config
            mock_export_llama.assert_called_once()
            called_config = mock_export_llama.call_args[0][0]
            self.assertEqual(called_config["base"]["tokenizer_path"], "/path/to/tokenizer.json")
            self.assertEqual(called_config["export"]["max_seq_length"], 256)
        finally:
            os.unlink(config_file)

    def test_with_cli_args(self) -> None:
        """Test main function with only hydra CLI args."""
        test_argv = ["script.py", "debug.verbose=True"]
        with patch.object(sys, "argv", test_argv):
            with patch("executorch.extension.llm.export.export_llm.hydra_main") as mock_hydra:
                main()
                mock_hydra.assert_called_once()

    def test_config_with_cli_args_error(self) -> None:
        """Test that --config rejects additional CLI arguments to prevent mixing approaches."""
        # Create a temporary config file
        with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
            f.write("base:\n checkpoint: /path/to/checkpoint.pth")
            config_file = f.name

        try:
            test_argv = ["script.py", "--config", config_file, "debug.verbose=True"]
            with patch.object(sys, "argv", test_argv):
                with self.assertRaises(ValueError) as cm:
                    main()

            error_msg = str(cm.exception)
            self.assertIn("Cannot specify additional CLI arguments when using --config", error_msg)
        finally:
            os.unlink(config_file)

    def test_config_rejects_multiple_cli_args(self) -> None:
        """Test that --config rejects multiple CLI arguments (not just single ones)."""
        with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
            f.write("export:\n max_seq_length: 128")
            config_file = f.name

        try:
            test_argv = ["script.py", "--config", config_file, "debug.verbose=True", "export.output_dir=/tmp"]
            with patch.object(sys, "argv", test_argv):
                with self.assertRaises(ValueError):
                    main()
        finally:
            os.unlink(config_file)


if __name__ == "__main__":
    unittest.main()
Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
#!/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Install requirements for LLM extension
pip install hydra-core>=1.3.0 omegaconf>=2.3.0
