Commit 745fad1

Update
[ghstack-poisoned]
2 parents: 1af1b27 + acd2079

File tree: 3 files changed (+71, -24 lines)


examples/models/llama/export_llama_lib.py

Lines changed: 30 additions & 7 deletions
@@ -702,7 +702,11 @@ def _prepare_for_llama_export(llm_config: LlmConfig) -> LLMEdgeManager:
         checkpoint=llm_config.base.checkpoint,
         checkpoint_dtype=DType.from_torch_dtype(checkpoint_dtype),  # type: ignore
         tokenizer_path=llm_config.base.tokenizer_path,
-        use_spin_quant=llm_config.quantization.use_spin_quant.value if llm_config.quantization.use_spin_quant else None,
+        use_spin_quant=(
+            llm_config.quantization.use_spin_quant.value
+            if llm_config.quantization.use_spin_quant
+            else None
+        ),
         embedding_quantize=llm_config.quantization.embedding_quantize,
         use_shared_embedding=llm_config.model.use_shared_embedding,
         quantization_mode=llm_config.quantization.qmode,
@@ -726,7 +730,9 @@ def _prepare_for_llama_export(llm_config: LlmConfig) -> LLMEdgeManager:
         vulkan=llm_config.backend.vulkan.enabled,
         use_qat=llm_config.quantization.use_qat,
         use_lora=llm_config.base.use_lora,
-        preq_mode=llm_config.base.preq_mode.value if llm_config.base.preq_mode else None,
+        preq_mode=(
+            llm_config.base.preq_mode.value if llm_config.base.preq_mode else None
+        ),
         preq_group_size=llm_config.base.preq_group_size,
         preq_embedding_quantize=llm_config.base.preq_embedding_quantize,
         local_global_attention=llm_config.model.local_global_attention,
@@ -738,7 +744,12 @@ def _prepare_for_llama_export(llm_config: LlmConfig) -> LLMEdgeManager:
 
 def get_quantizer_and_quant_params(llm_config):
     pt2e_quant_params = get_pt2e_quantization_params(
-        llm_config.quantization.pt2e_quantize.value if llm_config.quantization.pt2e_quantize else None, llm_config.quantization.qmode
+        (
+            llm_config.quantization.pt2e_quantize.value
+            if llm_config.quantization.pt2e_quantize
+            else None
+        ),
+        llm_config.quantization.qmode,
     )
     quantizers = get_pt2e_quantizers(pt2e_quant_params, llm_config.export.so_library)
     quant_dtype = None
@@ -750,13 +761,17 @@ def get_quantizer_and_quant_params(llm_config):
         quantizers.append(qnn_quantizer)
     if llm_config.backend.coreml.enabled and llm_config.quantization.pt2e_quantize:
         assert len(quantizers) == 0, "Should not enable both xnnpack / qnn and coreml"
-        coreml_quantizer = get_coreml_quantizer(llm_config.quantization.pt2e_quantize.value)
+        coreml_quantizer = get_coreml_quantizer(
+            llm_config.quantization.pt2e_quantize.value
+        )
         quantizers.append(coreml_quantizer)
     if llm_config.backend.vulkan.enabled and llm_config.quantization.pt2e_quantize:
         assert (
             len(quantizers) == 0
         ), "Should not enable both vulkan and other quantizers"
-        vulkan_quantizer = get_vulkan_quantizer(llm_config.quantization.pt2e_quantize.value)
+        vulkan_quantizer = get_vulkan_quantizer(
+            llm_config.quantization.pt2e_quantize.value
+        )
         quantizers.append(vulkan_quantizer)
     logging.info(f"Applying quantizers: {quantizers}")
     return pt2e_quant_params, quantizers, quant_dtype
@@ -1076,9 +1091,17 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager:  # noqa: C901
         enable_dynamic_shape=llm_config.model.enable_dynamic_shape,
         use_kv_cache=llm_config.model.use_kv_cache,
         embedding_quantize=llm_config.quantization.embedding_quantize,
-        pt2e_quantize=llm_config.quantization.pt2e_quantize.value if llm_config.quantization.pt2e_quantize else None,
+        pt2e_quantize=(
+            llm_config.quantization.pt2e_quantize.value
+            if llm_config.quantization.pt2e_quantize
+            else None
+        ),
         coreml_ios=llm_config.backend.coreml.ios,
-        coreml_quantize=llm_config.backend.coreml.quantize.value if llm_config.backend.coreml.quantize else None,
+        coreml_quantize=(
+            llm_config.backend.coreml.quantize.value
+            if llm_config.backend.coreml.quantize
+            else None
+        ),
         coreml_compute_units=llm_config.backend.coreml.compute_units.value,
         use_qnn_sha=llm_config.backend.qnn.use_sha,
         num_sharding=llm_config.backend.qnn.num_sharding,
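
The changes above are mechanical: each single-line "x.value if x else None" expression is wrapped in parentheses so it can break across lines, and long single-argument calls are split one argument per line. A minimal sketch of the optional-enum pattern, using toy stand-ins (ToySpinQuant and ToyConfig are illustrative, not ExecuTorch types):

# Sketch of the guarded .value access used throughout the diff above.
# ToySpinQuant and ToyConfig are hypothetical stand-ins for LlmConfig types.
from dataclasses import dataclass
from enum import Enum
from typing import Optional


class ToySpinQuant(Enum):
    CUDA = "cuda"


@dataclass
class ToyConfig:
    use_spin_quant: Optional[ToySpinQuant] = None


def resolve(config: ToyConfig) -> Optional[str]:
    # Parentheses let the conditional wrap over several lines without
    # backslash continuations once the attribute chain grows long.
    return (
        config.use_spin_quant.value
        if config.use_spin_quant
        else None
    )


assert resolve(ToyConfig()) is None
assert resolve(ToyConfig(ToySpinQuant.CUDA)) == "cuda"

The parenthesized form is what auto-formatters such as Black emit for conditional expressions that exceed the line-length limit, consistent with the purely cosmetic nature of this commit.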

extension/llm/export/export_llm.py

Lines changed: 2 additions & 3 deletions
@@ -34,12 +34,11 @@
 from typing import Any, List, Tuple
 
 import hydra
-import yaml
 
 from executorch.examples.models.llama.config.llm_config import LlmConfig
 from executorch.examples.models.llama.export_llama_lib import export_llama
 from hydra.core.config_store import ConfigStore
-from omegaconf import DictConfig, OmegaConf
+from omegaconf import OmegaConf
 
 cs = ConfigStore.instance()
 cs.store(name="llm_config", node=LlmConfig)
@@ -79,7 +78,7 @@ def main() -> None:
             "Cannot specify additional CLI arguments when using --config. "
             f"Found: {remaining_args}. Use either --config file or hydra CLI args, not both."
         )
-
+
     config_file_path = pop_config_arg()
     default_llm_config = LlmConfig()
     llm_config_from_file = OmegaConf.load(config_file_path)
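
The import cleanup above drops yaml and DictConfig because only OmegaConf is needed to read the config file; the second hunk only strips trailing whitespace from a blank line. A minimal sketch of the load-and-merge flow the surrounding code performs, assuming a toy schema (Defaults and Base are hypothetical stand-ins for LlmConfig, and OmegaConf.create stands in for OmegaConf.load on a real file):

# Sketch of merging a YAML config over structured defaults with OmegaConf.
from dataclasses import dataclass, field

from omegaconf import OmegaConf


@dataclass
class Base:
    model_class: str = "llama3"
    tokenizer_path: str = ""


@dataclass
class Defaults:  # hypothetical stand-in for LlmConfig
    base: Base = field(default_factory=Base)


default_cfg = OmegaConf.structured(Defaults())
file_cfg = OmegaConf.create({"base": {"model_class": "llama2"}})
merged = OmegaConf.merge(default_cfg, file_cfg)  # file values win
assert merged.base.model_class == "llama2"
assert merged.base.tokenizer_path == ""  # untouched default preserved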

extension/llm/export/test/test_export_llm.py

Lines changed: 39 additions & 14 deletions
@@ -10,8 +10,11 @@
 import unittest
 from unittest.mock import MagicMock, patch
 
-from executorch.examples.models.llama.config.llm_config import LlmConfig
-from executorch.extension.llm.export.export_llm import main, parse_config_arg, pop_config_arg
+from executorch.extension.llm.export.export_llm import (
+    main,
+    parse_config_arg,
+    pop_config_arg,
+)
 
 
 class TestExportLlm(unittest.TestCase):
@@ -45,7 +48,8 @@ def test_with_config(self, mock_export_llama: MagicMock) -> None:
         """Test main function with --config file and no hydra args."""
         # Create a temporary config file
         with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
-            f.write("""
+            f.write(
+                """
 base:
   model_class: llama2
   tokenizer_path: /path/to/tokenizer.json
@@ -61,7 +65,8 @@ def test_with_config(self, mock_export_llama: MagicMock) -> None:
   coreml:
     quantize: c4w
     compute_units: cpu_and_gpu
-""")
+"""
+            )
             config_file = f.name
 
         try:
@@ -72,23 +77,35 @@ def test_with_config(self, mock_export_llama: MagicMock) -> None:
             # Verify export_llama was called with config
             mock_export_llama.assert_called_once()
             called_config = mock_export_llama.call_args[0][0]
-            self.assertEqual(called_config["base"]["tokenizer_path"], "/path/to/tokenizer.json")
+            self.assertEqual(
+                called_config["base"]["tokenizer_path"], "/path/to/tokenizer.json"
+            )
             self.assertEqual(called_config["base"]["model_class"], "llama2")
             self.assertEqual(called_config["base"]["preq_mode"].value, "8da4w")
             self.assertEqual(called_config["model"]["dtype_override"].value, "fp16")
             self.assertEqual(called_config["export"]["max_seq_length"], 256)
-            self.assertEqual(called_config["quantization"]["pt2e_quantize"].value, "xnnpack_dynamic")
-            self.assertEqual(called_config["quantization"]["use_spin_quant"].value, "cuda")
-            self.assertEqual(called_config["backend"]["coreml"]["quantize"].value, "c4w")
-            self.assertEqual(called_config["backend"]["coreml"]["compute_units"].value, "cpu_and_gpu")
+            self.assertEqual(
+                called_config["quantization"]["pt2e_quantize"].value, "xnnpack_dynamic"
+            )
+            self.assertEqual(
+                called_config["quantization"]["use_spin_quant"].value, "cuda"
+            )
+            self.assertEqual(
+                called_config["backend"]["coreml"]["quantize"].value, "c4w"
+            )
+            self.assertEqual(
+                called_config["backend"]["coreml"]["compute_units"].value, "cpu_and_gpu"
+            )
         finally:
             os.unlink(config_file)
 
     def test_with_cli_args(self) -> None:
         """Test main function with only hydra CLI args."""
         test_argv = ["script.py", "debug.verbose=True"]
         with patch.object(sys, "argv", test_argv):
-            with patch("executorch.extension.llm.export.export_llm.hydra_main") as mock_hydra:
+            with patch(
+                "executorch.extension.llm.export.export_llm.hydra_main"
+            ) as mock_hydra:
                 main()
                 mock_hydra.assert_called_once()
 
@@ -104,9 +121,12 @@ def test_config_with_cli_args_error(self) -> None:
             with patch.object(sys, "argv", test_argv):
                 with self.assertRaises(ValueError) as cm:
                     main()
-
+
                 error_msg = str(cm.exception)
-                self.assertIn("Cannot specify additional CLI arguments when using --config", error_msg)
+                self.assertIn(
+                    "Cannot specify additional CLI arguments when using --config",
+                    error_msg,
+                )
         finally:
             os.unlink(config_file)
 
@@ -117,7 +137,13 @@ def test_config_rejects_multiple_cli_args(self) -> None:
             config_file = f.name
 
         try:
-            test_argv = ["script.py", "--config", config_file, "debug.verbose=True", "export.output_dir=/tmp"]
+            test_argv = [
+                "script.py",
+                "--config",
+                config_file,
+                "debug.verbose=True",
+                "export.output_dir=/tmp",
+            ]
             with patch.object(sys, "argv", test_argv):
                 with self.assertRaises(ValueError):
                     main()
@@ -127,4 +153,3 @@ def test_config_rejects_multiple_cli_args(self) -> None:
 
 if __name__ == "__main__":
     unittest.main()
-
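
Most of this diff is line-wrapping (plus dropping the unused LlmConfig import), but the tests share one fixture pattern worth isolating: write a temporary YAML config with delete=False, hand its path to the code under test, and unlink it in a finally block so the file is cleaned up even when the test raises. A minimal sketch of that pattern (consume_config is a hypothetical stand-in for the export_llm entry point):

# Sketch of the temp-config fixture used by the tests above.
import os
import tempfile

import yaml  # used here only to read the file back


def consume_config(path: str) -> dict:
    # Hypothetical stand-in for the code under test.
    with open(path) as f:
        return yaml.safe_load(f)


with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
    f.write("base:\n  model_class: llama2\n")
    config_file = f.name

try:
    cfg = consume_config(config_file)
    assert cfg["base"]["model_class"] == "llama2"
finally:
    os.unlink(config_file)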
