Skip to content

Commit 73c9424

Browse files
committed
Update
[ghstack-poisoned]
2 parents 78e8224 + 0cffae8 commit 73c9424

File tree

1 file changed

+19
-100
lines changed

1 file changed

+19
-100
lines changed

extension/llm/export/test/test_export_llm.py

Lines changed: 19 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,7 @@
1010
import unittest
1111
from unittest.mock import MagicMock, patch
1212

13-
from executorch.examples.models.llama.config.llm_config import (
14-
LlmConfig,
15-
ModelType,
16-
PreqMode,
17-
DtypeOverride,
18-
Pt2eQuantize,
19-
SpinQuant,
20-
CoreMLQuantize,
21-
CoreMLComputeUnit
22-
)
13+
from executorch.examples.models.llama.config.llm_config import LlmConfig
2314
from executorch.extension.llm.export.export_llm import main, parse_config_arg, pop_config_arg
2415

2516

@@ -56,9 +47,20 @@ def test_with_config(self, mock_export_llama: MagicMock) -> None:
5647
with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
5748
f.write("""
5849
base:
50+
model_class: llama2
5951
tokenizer_path: /path/to/tokenizer.json
52+
preq_mode: preq_8da4w
53+
model:
54+
dtype_override: fp16
6055
export:
6156
max_seq_length: 256
57+
quantization:
58+
pt2e_quantize: xnnpack_dynamic
59+
use_spin_quant: cuda
60+
backend:
61+
coreml:
62+
quantize: c4w
63+
compute_units: cpu_and_gpu
6264
""")
6365
config_file = f.name
6466

@@ -71,7 +73,14 @@ def test_with_config(self, mock_export_llama: MagicMock) -> None:
7173
mock_export_llama.assert_called_once()
7274
called_config = mock_export_llama.call_args[0][0]
7375
self.assertEqual(called_config["base"]["tokenizer_path"], "/path/to/tokenizer.json")
76+
self.assertEqual(called_config["base"]["model_class"], "llama2")
77+
self.assertEqual(called_config["base"]["preq_mode"], "preq_8da4w")
78+
self.assertEqual(called_config["model"]["dtype_override"], "fp16")
7479
self.assertEqual(called_config["export"]["max_seq_length"], 256)
80+
self.assertEqual(called_config["quantization"]["pt2e_quantize"], "xnnpack_dynamic")
81+
self.assertEqual(called_config["quantization"]["use_spin_quant"], "cuda")
82+
self.assertEqual(called_config["backend"]["coreml"]["quantize"], "c4w")
83+
self.assertEqual(called_config["backend"]["coreml"]["compute_units"], "cpu_and_gpu")
7584
finally:
7685
os.unlink(config_file)
7786

@@ -115,96 +124,6 @@ def test_config_rejects_multiple_cli_args(self) -> None:
115124
finally:
116125
os.unlink(config_file)
117126

118-
def test_enum_fields(self) -> None:
119-
"""Test that all enum fields work correctly with their lowercase keys."""
120-
# Test ModelType enum
121-
for enum_value in ModelType:
122-
self.assertIsNotNone(enum_value.value)
123-
self.assertTrue(isinstance(enum_value.value, str))
124-
125-
# Test specific enum values that were changed from uppercase to lowercase
126-
self.assertEqual(ModelType.stories110m.value, "stories110m")
127-
self.assertEqual(ModelType.llama2.value, "llama2")
128-
self.assertEqual(ModelType.llama3.value, "llama3")
129-
self.assertEqual(ModelType.llama3_1.value, "llama3_1")
130-
self.assertEqual(ModelType.llama3_2.value, "llama3_2")
131-
self.assertEqual(ModelType.llama3_2_vision.value, "llama3_2_vision")
132-
self.assertEqual(ModelType.static_llama.value, "static_llama")
133-
self.assertEqual(ModelType.qwen2_5.value, "qwen2_5")
134-
self.assertEqual(ModelType.qwen3_0_6b.value, "qwen3-0_6b")
135-
self.assertEqual(ModelType.qwen3_1_7b.value, "qwen3-1_7b")
136-
self.assertEqual(ModelType.qwen3_4b.value, "qwen3-4b")
137-
self.assertEqual(ModelType.phi_4_mini.value, "phi_4_mini")
138-
self.assertEqual(ModelType.smollm2.value, "smollm2")
139-
140-
# Test PreqMode enum
141-
self.assertEqual(PreqMode.preq_8da4w.value, "8da4w")
142-
self.assertEqual(PreqMode.preq_8da4w_out_8da8w.value, "8da4w_output_8da8w")
143-
144-
# Test DtypeOverride enum
145-
self.assertEqual(DtypeOverride.fp32.value, "fp32")
146-
self.assertEqual(DtypeOverride.fp16.value, "fp16")
147-
self.assertEqual(DtypeOverride.bf16.value, "bf16")
148-
149-
# Test Pt2eQuantize enum
150-
self.assertEqual(Pt2eQuantize.xnnpack_dynamic.value, "xnnpack_dynamic")
151-
self.assertEqual(Pt2eQuantize.xnnpack_dynamic_qc4.value, "xnnpack_dynamic_qc4")
152-
self.assertEqual(Pt2eQuantize.qnn_8a8w.value, "qnn_8a8w")
153-
self.assertEqual(Pt2eQuantize.qnn_16a16w.value, "qnn_16a16w")
154-
self.assertEqual(Pt2eQuantize.qnn_16a4w.value, "qnn_16a4w")
155-
self.assertEqual(Pt2eQuantize.coreml_c4w.value, "coreml_c4w")
156-
self.assertEqual(Pt2eQuantize.coreml_8a_c8w.value, "coreml_8a_c8w")
157-
self.assertEqual(Pt2eQuantize.coreml_8a_c4w.value, "coreml_8a_c4w")
158-
self.assertEqual(Pt2eQuantize.coreml_baseline_8a_c8w.value, "coreml_baseline_8a_c8w")
159-
self.assertEqual(Pt2eQuantize.coreml_baseline_8a_c4w.value, "coreml_baseline_8a_c4w")
160-
self.assertEqual(Pt2eQuantize.vulkan_8w.value, "vulkan_8w")
161-
162-
# Test SpinQuant enum
163-
self.assertEqual(SpinQuant.cuda.value, "cuda")
164-
self.assertEqual(SpinQuant.native.value, "native")
165-
166-
# Test CoreMLQuantize enum
167-
self.assertEqual(CoreMLQuantize.b4w.value, "b4w")
168-
self.assertEqual(CoreMLQuantize.c4w.value, "c4w")
169-
170-
# Test CoreMLComputeUnit enum
171-
self.assertEqual(CoreMLComputeUnit.cpu_only.value, "cpu_only")
172-
self.assertEqual(CoreMLComputeUnit.cpu_and_gpu.value, "cpu_and_gpu")
173-
self.assertEqual(CoreMLComputeUnit.cpu_and_ne.value, "cpu_and_ne")
174-
self.assertEqual(CoreMLComputeUnit.all.value, "all")
175-
176-
def test_enum_configuration(self) -> None:
177-
"""Test that enum fields can be properly set in LlmConfig."""
178-
config = LlmConfig()
179-
180-
# Test setting ModelType
181-
config.base.model_class = ModelType.llama3
182-
self.assertEqual(config.base.model_class.value, "llama3")
183-
184-
# Test setting DtypeOverride
185-
config.model.dtype_override = DtypeOverride.fp16
186-
self.assertEqual(config.model.dtype_override.value, "fp16")
187-
188-
# Test setting PreqMode
189-
config.base.preq_mode = PreqMode.preq_8da4w
190-
self.assertEqual(config.base.preq_mode.value, "8da4w")
191-
192-
# Test setting Pt2eQuantize
193-
config.quantization.pt2e_quantize = Pt2eQuantize.xnnpack_dynamic
194-
self.assertEqual(config.quantization.pt2e_quantize.value, "xnnpack_dynamic")
195-
196-
# Test setting SpinQuant
197-
config.quantization.use_spin_quant = SpinQuant.cuda
198-
self.assertEqual(config.quantization.use_spin_quant.value, "cuda")
199-
200-
# Test setting CoreMLQuantize
201-
config.backend.coreml.quantize = CoreMLQuantize.c4w
202-
self.assertEqual(config.backend.coreml.quantize.value, "c4w")
203-
204-
# Test setting CoreMLComputeUnit
205-
config.backend.coreml.compute_units = CoreMLComputeUnit.cpu_and_gpu
206-
self.assertEqual(config.backend.coreml.compute_units.value, "cpu_and_gpu")
207-
208127

209128
if __name__ == "__main__":
210129
unittest.main()

0 commit comments

Comments
 (0)