1010import unittest
1111from unittest .mock import MagicMock , patch
1212
13- from executorch .examples .models .llama .config .llm_config import (
14- LlmConfig ,
15- ModelType ,
16- PreqMode ,
17- DtypeOverride ,
18- Pt2eQuantize ,
19- SpinQuant ,
20- CoreMLQuantize ,
21- CoreMLComputeUnit
22- )
13+ from executorch .examples .models .llama .config .llm_config import LlmConfig
2314from executorch .extension .llm .export .export_llm import main , parse_config_arg , pop_config_arg
2415
2516
@@ -56,9 +47,20 @@ def test_with_config(self, mock_export_llama: MagicMock) -> None:
5647 with tempfile .NamedTemporaryFile (mode = "w" , suffix = ".yaml" , delete = False ) as f :
5748 f .write ("""
5849base:
50+ model_class: llama2
5951 tokenizer_path: /path/to/tokenizer.json
52+ preq_mode: preq_8da4w
53+ model:
54+ dtype_override: fp16
6055export:
6156 max_seq_length: 256
57+ quantization:
58+ pt2e_quantize: xnnpack_dynamic
59+ use_spin_quant: cuda
60+ backend:
61+ coreml:
62+ quantize: c4w
63+ compute_units: cpu_and_gpu
6264""" )
6365 config_file = f .name
6466
@@ -71,7 +73,14 @@ def test_with_config(self, mock_export_llama: MagicMock) -> None:
7173 mock_export_llama .assert_called_once ()
7274 called_config = mock_export_llama .call_args [0 ][0 ]
7375 self .assertEqual (called_config ["base" ]["tokenizer_path" ], "/path/to/tokenizer.json" )
76+ self .assertEqual (called_config ["base" ]["model_class" ], "llama2" )
77+ self .assertEqual (called_config ["base" ]["preq_mode" ], "preq_8da4w" )
78+ self .assertEqual (called_config ["model" ]["dtype_override" ], "fp16" )
7479 self .assertEqual (called_config ["export" ]["max_seq_length" ], 256 )
80+ self .assertEqual (called_config ["quantization" ]["pt2e_quantize" ], "xnnpack_dynamic" )
81+ self .assertEqual (called_config ["quantization" ]["use_spin_quant" ], "cuda" )
82+ self .assertEqual (called_config ["backend" ]["coreml" ]["quantize" ], "c4w" )
83+ self .assertEqual (called_config ["backend" ]["coreml" ]["compute_units" ], "cpu_and_gpu" )
7584 finally :
7685 os .unlink (config_file )
7786
@@ -115,96 +124,6 @@ def test_config_rejects_multiple_cli_args(self) -> None:
115124 finally :
116125 os .unlink (config_file )
117126
118- def test_enum_fields (self ) -> None :
119- """Test that all enum fields work correctly with their lowercase keys."""
120- # Test ModelType enum
121- for enum_value in ModelType :
122- self .assertIsNotNone (enum_value .value )
123- self .assertTrue (isinstance (enum_value .value , str ))
124-
125- # Test specific enum values that were changed from uppercase to lowercase
126- self .assertEqual (ModelType .stories110m .value , "stories110m" )
127- self .assertEqual (ModelType .llama2 .value , "llama2" )
128- self .assertEqual (ModelType .llama3 .value , "llama3" )
129- self .assertEqual (ModelType .llama3_1 .value , "llama3_1" )
130- self .assertEqual (ModelType .llama3_2 .value , "llama3_2" )
131- self .assertEqual (ModelType .llama3_2_vision .value , "llama3_2_vision" )
132- self .assertEqual (ModelType .static_llama .value , "static_llama" )
133- self .assertEqual (ModelType .qwen2_5 .value , "qwen2_5" )
134- self .assertEqual (ModelType .qwen3_0_6b .value , "qwen3-0_6b" )
135- self .assertEqual (ModelType .qwen3_1_7b .value , "qwen3-1_7b" )
136- self .assertEqual (ModelType .qwen3_4b .value , "qwen3-4b" )
137- self .assertEqual (ModelType .phi_4_mini .value , "phi_4_mini" )
138- self .assertEqual (ModelType .smollm2 .value , "smollm2" )
139-
140- # Test PreqMode enum
141- self .assertEqual (PreqMode .preq_8da4w .value , "8da4w" )
142- self .assertEqual (PreqMode .preq_8da4w_out_8da8w .value , "8da4w_output_8da8w" )
143-
144- # Test DtypeOverride enum
145- self .assertEqual (DtypeOverride .fp32 .value , "fp32" )
146- self .assertEqual (DtypeOverride .fp16 .value , "fp16" )
147- self .assertEqual (DtypeOverride .bf16 .value , "bf16" )
148-
149- # Test Pt2eQuantize enum
150- self .assertEqual (Pt2eQuantize .xnnpack_dynamic .value , "xnnpack_dynamic" )
151- self .assertEqual (Pt2eQuantize .xnnpack_dynamic_qc4 .value , "xnnpack_dynamic_qc4" )
152- self .assertEqual (Pt2eQuantize .qnn_8a8w .value , "qnn_8a8w" )
153- self .assertEqual (Pt2eQuantize .qnn_16a16w .value , "qnn_16a16w" )
154- self .assertEqual (Pt2eQuantize .qnn_16a4w .value , "qnn_16a4w" )
155- self .assertEqual (Pt2eQuantize .coreml_c4w .value , "coreml_c4w" )
156- self .assertEqual (Pt2eQuantize .coreml_8a_c8w .value , "coreml_8a_c8w" )
157- self .assertEqual (Pt2eQuantize .coreml_8a_c4w .value , "coreml_8a_c4w" )
158- self .assertEqual (Pt2eQuantize .coreml_baseline_8a_c8w .value , "coreml_baseline_8a_c8w" )
159- self .assertEqual (Pt2eQuantize .coreml_baseline_8a_c4w .value , "coreml_baseline_8a_c4w" )
160- self .assertEqual (Pt2eQuantize .vulkan_8w .value , "vulkan_8w" )
161-
162- # Test SpinQuant enum
163- self .assertEqual (SpinQuant .cuda .value , "cuda" )
164- self .assertEqual (SpinQuant .native .value , "native" )
165-
166- # Test CoreMLQuantize enum
167- self .assertEqual (CoreMLQuantize .b4w .value , "b4w" )
168- self .assertEqual (CoreMLQuantize .c4w .value , "c4w" )
169-
170- # Test CoreMLComputeUnit enum
171- self .assertEqual (CoreMLComputeUnit .cpu_only .value , "cpu_only" )
172- self .assertEqual (CoreMLComputeUnit .cpu_and_gpu .value , "cpu_and_gpu" )
173- self .assertEqual (CoreMLComputeUnit .cpu_and_ne .value , "cpu_and_ne" )
174- self .assertEqual (CoreMLComputeUnit .all .value , "all" )
175-
176- def test_enum_configuration (self ) -> None :
177- """Test that enum fields can be properly set in LlmConfig."""
178- config = LlmConfig ()
179-
180- # Test setting ModelType
181- config .base .model_class = ModelType .llama3
182- self .assertEqual (config .base .model_class .value , "llama3" )
183-
184- # Test setting DtypeOverride
185- config .model .dtype_override = DtypeOverride .fp16
186- self .assertEqual (config .model .dtype_override .value , "fp16" )
187-
188- # Test setting PreqMode
189- config .base .preq_mode = PreqMode .preq_8da4w
190- self .assertEqual (config .base .preq_mode .value , "8da4w" )
191-
192- # Test setting Pt2eQuantize
193- config .quantization .pt2e_quantize = Pt2eQuantize .xnnpack_dynamic
194- self .assertEqual (config .quantization .pt2e_quantize .value , "xnnpack_dynamic" )
195-
196- # Test setting SpinQuant
197- config .quantization .use_spin_quant = SpinQuant .cuda
198- self .assertEqual (config .quantization .use_spin_quant .value , "cuda" )
199-
200- # Test setting CoreMLQuantize
201- config .backend .coreml .quantize = CoreMLQuantize .c4w
202- self .assertEqual (config .backend .coreml .quantize .value , "c4w" )
203-
204- # Test setting CoreMLComputeUnit
205- config .backend .coreml .compute_units = CoreMLComputeUnit .cpu_and_gpu
206- self .assertEqual (config .backend .coreml .compute_units .value , "cpu_and_gpu" )
207-
208127
209128if __name__ == "__main__" :
210129 unittest .main ()
0 commit comments