diff --git a/examples/models/llama/config/llm_config.py b/examples/models/llama/config/llm_config.py
index 201e3a5414a..0504b386f45 100644
--- a/examples/models/llama/config/llm_config.py
+++ b/examples/models/llama/config/llm_config.py
@@ -16,7 +16,6 @@
 import ast
 import re
 from dataclasses import dataclass, field
-from enum import Enum
 from typing import ClassVar, List, Optional
 
 
@@ -25,32 +24,27 @@
 ################################################################################
 
 
-class ModelType(str, Enum):
-    STORIES110M = "stories110m"
-    LLAMA2 = "llama2"
-    LLAMA3 = "llama3"
-    LLAMA3_1 = "llama3_1"
-    LLAMA3_2 = "llama3_2"
-    LLAMA3_2_VISION = "llama3_2_vision"
-    STATIC_LLAMA = "static_llama"
-    QWEN2_5 = "qwen2_5"
-    QWEN3_0_6B = "qwen3-0_6b"
-    QWEN3_1_7B = "qwen3-1_7b"
-    QWEN3_4B = "qwen3-4b"
-    PHI_4_MINI = "phi_4_mini"
-    SMOLLM2 = "smollm2"
+MODEL_TYPE_OPTIONS = [
+    "stories110m",
+    "llama2",
+    "llama3",
+    "llama3_1",
+    "llama3_2",
+    "llama3_2_vision",
+    "static_llama",
+    "qwen2_5",
+    "qwen3-0_6b",
+    "qwen3-1_7b",
+    "qwen3-4b",
+    "phi_4_mini",
+    "smollm2",
+]
 
 
-class PreqMode(str, Enum):
-    """
-    If you are dealing with pre-quantized checkpoints, this used to
-    be the way to specify them. Now you don't need to specify these
-    options if you use a TorchAo-prequantized checkpoint, but they
-    are still around to preserve backward compatibility.
-    """
-
-    PREQ_8DA4W = "8da4w"
-    PREQ_8DA4W_OUT_8DA8W = "8da4w_output_8da8w"
+PREQ_MODE_OPTIONS = [
+    "8da4w",
+    "8da4w_output_8da8w",
+]
 
 
 @dataclass
@@ -82,7 +76,7 @@ class BaseConfig:
     are loaded.
     """
 
-    model_class: ModelType = ModelType.LLAMA3
+    model_class: str = "llama3"
     params: Optional[str] = None
     checkpoint: Optional[str] = None
     checkpoint_dir: Optional[str] = None
@@ -90,26 +84,28 @@ class BaseConfig:
     metadata: Optional[str] = None
     use_lora: int = 0
     fairseq2: bool = False
-    preq_mode: Optional[PreqMode] = None
+    preq_mode: Optional[str] = None
     preq_group_size: int = 32
     preq_embedding_quantize: str = "8,0"
 
+    def __post_init__(self):
+        if self.model_class not in MODEL_TYPE_OPTIONS:
+            raise ValueError(f"model_class must be one of {MODEL_TYPE_OPTIONS}, got '{self.model_class}'")
+
+        if self.preq_mode is not None and self.preq_mode not in PREQ_MODE_OPTIONS:
+            raise ValueError(f"preq_mode must be one of {PREQ_MODE_OPTIONS}, got '{self.preq_mode}'")
+
 
 ################################################################################
 ################################# ModelConfig ##################################
 ################################################################################
 
 
-class DtypeOverride(str, Enum):
-    """
-    DType of the model. Highly recommended to use "fp32", unless you want to
-    export without a backend, in which case you can also use "bf16". "fp16"
-    is not recommended.
-    """
-
-    FP32 = "fp32"
-    FP16 = "fp16"
-    BF16 = "bf16"
+DTYPE_OVERRIDE_OPTIONS = [
+    "fp32",
+    "fp16",
+    "bf16",
+]
 
 
 @dataclass
@@ -147,7 +143,7 @@ class ModelConfig:
         [16] pattern specifies all layers have a sliding window of 16.
     """
 
-    dtype_override: DtypeOverride = DtypeOverride.FP32
+    dtype_override: str = "fp32"
     enable_dynamic_shape: bool = True
     use_shared_embedding: bool = False
     use_sdpa_with_kv_cache: bool = False
@@ -160,6 +156,9 @@ class ModelConfig:
     local_global_attention: Optional[List[int]] = None
 
     def __post_init__(self):
+        if self.dtype_override not in DTYPE_OVERRIDE_OPTIONS:
+            raise ValueError(f"dtype_override must be one of {DTYPE_OVERRIDE_OPTIONS}, got '{self.dtype_override}'")
+
         self._validate_attention_sink()
         self._validate_local_global_attention()
 
@@ -261,31 +260,25 @@ class DebugConfig:
 ################################################################################
 
 
-class Pt2eQuantize(str, Enum):
-    """
-    Type of backend-specific Pt2e quantization strategy to use.
-
-    Pt2e uses a different quantization library that is graph-based
-    compared to `qmode`, which is also specified in the QuantizationConfig
-    and is source transform-based.
-    """
+PT2E_QUANTIZE_OPTIONS = [
+    "xnnpack_dynamic",
+    "xnnpack_dynamic_qc4",
+    "qnn_8a8w",
+    "qnn_16a16w",
+    "qnn_16a4w",
+    "coreml_c4w",
+    "coreml_8a_c8w",
+    "coreml_8a_c4w",
+    "coreml_baseline_8a_c8w",
+    "coreml_baseline_8a_c4w",
+    "vulkan_8w",
+]
 
-    XNNPACK_DYNAMIC = "xnnpack_dynamic"
-    XNNPACK_DYNAMIC_QC4 = "xnnpack_dynamic_qc4"
-    QNN_8A8W = "qnn_8a8w"
-    QNN_16A16W = "qnn_16a16w"
-    QNN_16A4W = "qnn_16a4w"
-    COREML_C4W = "coreml_c4w"
-    COREML_8A_C8W = "coreml_8a_c8w"
-    COREML_8A_C4W = "coreml_8a_c4w"
-    COREML_BASELINE_8A_C8W = "coreml_baseline_8a_c8w"
-    COREML_BASELINE_8A_C4W = "coreml_baseline_8a_c4w"
-    VULKAN_8W = "vulkan_8w"
 
-
-class SpinQuant(str, Enum):
-    CUDA = "cuda"
-    NATIVE = "native"
+SPIN_QUANT_OPTIONS = [
+    "cuda",
+    "native",
+]
 
 
 @dataclass
@@ -320,9 +313,9 @@ class QuantizationConfig:
 
     qmode: Optional[str] = None
     embedding_quantize: Optional[str] = None
-    pt2e_quantize: Optional[Pt2eQuantize] = None
+    pt2e_quantize: Optional[str] = None
     group_size: Optional[int] = None
-    use_spin_quant: Optional[SpinQuant] = None
+    use_spin_quant: Optional[str] = None
     use_qat: bool = False
     calibration_tasks: Optional[List[str]] = None
     calibration_limit: Optional[int] = None
@@ -330,6 +323,12 @@
     calibration_data: str = "Once upon a time"
 
     def __post_init__(self):
+        if self.pt2e_quantize is not None and self.pt2e_quantize not in PT2E_QUANTIZE_OPTIONS:
+            raise ValueError(f"pt2e_quantize must be one of {PT2E_QUANTIZE_OPTIONS}, got '{self.pt2e_quantize}'")
+
+        if self.use_spin_quant is not None and self.use_spin_quant not in SPIN_QUANT_OPTIONS:
+            raise ValueError(f"use_spin_quant must be one of {SPIN_QUANT_OPTIONS}, got '{self.use_spin_quant}'")
+
         if self.qmode:
             self._validate_qmode()
 
@@ -377,16 +376,18 @@ class XNNPackConfig:
     extended_ops: bool = False
 
 
-class CoreMLQuantize(str, Enum):
-    B4W = "b4w"
-    C4W = "c4w"
+COREML_QUANTIZE_OPTIONS = [
+    "b4w",
+    "c4w",
+]
 
 
-class CoreMLComputeUnit(str, Enum):
-    CPU_ONLY = "cpu_only"
-    CPU_AND_GPU = "cpu_and_gpu"
-    CPU_AND_NE = "cpu_and_ne"
-    ALL = "all"
+COREML_COMPUTE_UNIT_OPTIONS = [
+    "cpu_only",
+    "cpu_and_gpu",
+    "cpu_and_ne",
+    "all",
+]
 
 
 @dataclass
@@ -398,11 +399,17 @@ class CoreMLConfig:
     enabled: bool = False
     enable_state: bool = False
     preserve_sdpa: bool = False
-    quantize: Optional[CoreMLQuantize] = None
+    quantize: Optional[str] = None
     ios: int = 15
-    compute_units: CoreMLComputeUnit = CoreMLComputeUnit.CPU_ONLY
+    compute_units: str = "cpu_only"
 
     def __post_init__(self):
+        if self.quantize is not None and self.quantize not in COREML_QUANTIZE_OPTIONS:
+            raise ValueError(f"quantize must be one of {COREML_QUANTIZE_OPTIONS}, got '{self.quantize}'")
+
+        if self.compute_units not in COREML_COMPUTE_UNIT_OPTIONS:
+            raise ValueError(f"compute_units must be one of {COREML_COMPUTE_UNIT_OPTIONS}, got '{self.compute_units}'")
+
         if self.ios not in (15, 16, 17, 18):
             raise ValueError(f"Invalid coreml ios version: {self.ios}")
 
@@ -481,7 +488,7 @@ def from_args(cls, args: argparse.Namespace) -> "LlmConfig":  # noqa: C901
 
         # BaseConfig
         if hasattr(args, "model"):
-            llm_config.base.model_class = ModelType(args.model)
+            llm_config.base.model_class = args.model
         if hasattr(args, "params"):
             llm_config.base.params = args.params
         if hasattr(args, "checkpoint"):
@@ -499,7 +506,7 @@ def from_args(cls, args: argparse.Namespace) -> "LlmConfig":  # noqa: C901
 
         # PreqMode settings
         if hasattr(args, "preq_mode") and args.preq_mode:
-            llm_config.base.preq_mode = PreqMode(args.preq_mode)
+            llm_config.base.preq_mode = args.preq_mode
         if hasattr(args, "preq_group_size"):
             llm_config.base.preq_group_size = args.preq_group_size
         if hasattr(args, "preq_embedding_quantize"):
@@ -507,7 +514,7 @@ def from_args(cls, args: argparse.Namespace) -> "LlmConfig":  # noqa: C901
 
         # ModelConfig
        if hasattr(args, "dtype_override"):
-            llm_config.model.dtype_override = DtypeOverride(args.dtype_override)
+            llm_config.model.dtype_override = args.dtype_override
         if hasattr(args, "enable_dynamic_shape"):
             llm_config.model.enable_dynamic_shape = args.enable_dynamic_shape
         if hasattr(args, "use_shared_embedding"):
@@ -549,11 +556,11 @@ def from_args(cls, args: argparse.Namespace) -> "LlmConfig":  # noqa: C901
         if hasattr(args, "embedding_quantize"):
             llm_config.quantization.embedding_quantize = args.embedding_quantize
         if hasattr(args, "pt2e_quantize") and args.pt2e_quantize:
-            llm_config.quantization.pt2e_quantize = Pt2eQuantize(args.pt2e_quantize)
+            llm_config.quantization.pt2e_quantize = args.pt2e_quantize
         if hasattr(args, "group_size"):
             llm_config.quantization.group_size = args.group_size
         if hasattr(args, "use_spin_quant") and args.use_spin_quant:
-            llm_config.quantization.use_spin_quant = SpinQuant(args.use_spin_quant)
+            llm_config.quantization.use_spin_quant = args.use_spin_quant
         if hasattr(args, "use_qat"):
             llm_config.quantization.use_qat = args.use_qat
         if hasattr(args, "calibration_tasks"):
@@ -581,13 +588,11 @@ def from_args(cls, args: argparse.Namespace) -> "LlmConfig":  # noqa: C901
                 args, "coreml_preserve_sdpa", False
             )
         if hasattr(args, "coreml_quantize") and args.coreml_quantize:
-            llm_config.backend.coreml.quantize = CoreMLQuantize(args.coreml_quantize)
+            llm_config.backend.coreml.quantize = args.coreml_quantize
         if hasattr(args, "coreml_ios"):
             llm_config.backend.coreml.ios = args.coreml_ios
         if hasattr(args, "coreml_compute_units"):
-            llm_config.backend.coreml.compute_units = CoreMLComputeUnit(
-                args.coreml_compute_units
-            )
+            llm_config.backend.coreml.compute_units = args.coreml_compute_units
 
         # Vulkan
         if hasattr(args, "vulkan"):
diff --git a/examples/models/llama/config/test_llm_config.py b/examples/models/llama/config/test_llm_config.py
index 0853e9dbbd8..15513bcd6f2 100644
--- a/examples/models/llama/config/test_llm_config.py
+++ b/examples/models/llama/config/test_llm_config.py
@@ -11,7 +11,6 @@
 from executorch.examples.models.llama.config.llm_config import (
     BackendConfig,
     BaseConfig,
-    CoreMLComputeUnit,
     CoreMLConfig,
     DebugConfig,
     ExportConfig,
@@ -66,6 +65,34 @@ def test_shared_embedding_without_lowbit(self):
         with self.assertRaises(ValueError):
             LlmConfig(model=model_cfg, quantization=qcfg)
 
+    def test_invalid_model_type(self):
+        with self.assertRaises(ValueError):
+            BaseConfig(model_class="invalid_model")
+
+    def test_invalid_dtype_override(self):
+        with self.assertRaises(ValueError):
+            ModelConfig(dtype_override="invalid_dtype")
+
+    def test_invalid_preq_mode(self):
+        with self.assertRaises(ValueError):
+            BaseConfig(preq_mode="invalid_preq")
+
+    def test_invalid_pt2e_quantize(self):
+        with self.assertRaises(ValueError):
+            QuantizationConfig(pt2e_quantize="invalid_pt2e")
+
+    def test_invalid_spin_quant(self):
+        with self.assertRaises(ValueError):
+            QuantizationConfig(use_spin_quant="invalid_spin")
+
+    def test_invalid_coreml_quantize(self):
+        with self.assertRaises(ValueError):
+            CoreMLConfig(quantize="invalid_quantize")
+
+    def test_invalid_coreml_compute_units(self):
+        with self.assertRaises(ValueError):
+            CoreMLConfig(compute_units="invalid_compute_units")
+
 
 class TestValidConstruction(unittest.TestCase):
 
@@ -94,7 +121,7 @@ def test_valid_llm_config(self):
             backend=BackendConfig(
                 xnnpack=XNNPackConfig(enabled=False),
                 coreml=CoreMLConfig(
-                    enabled=True, ios=17, compute_units=CoreMLComputeUnit.ALL
+                    enabled=True, ios=17, compute_units="all"
                 ),
             ),
         )
diff --git a/extension/llm/export/test/test_export_llm.py b/extension/llm/export/test/test_export_llm.py
index 970a32c9606..258a867dc6b 100644
--- a/extension/llm/export/test/test_export_llm.py
+++ b/extension/llm/export/test/test_export_llm.py
@@ -47,9 +47,20 @@ def test_with_config(self, mock_export_llama: MagicMock) -> None:
         with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
             f.write("""
 base:
+  model_class: llama3
   tokenizer_path: /path/to/tokenizer.json
+  preq_mode: 8da4w
+model:
+  dtype_override: fp32
 export:
   max_seq_length: 256
+quantization:
+  pt2e_quantize: xnnpack_dynamic
+  use_spin_quant: cuda
+backend:
+  coreml:
+    quantize: c4w
+    compute_units: cpu_and_gpu
 """)
             config_file = f.name
 
@@ -61,8 +72,15 @@ def test_with_config(self, mock_export_llama: MagicMock) -> None:
             # Verify export_llama was called with config
             mock_export_llama.assert_called_once()
             called_config = mock_export_llama.call_args[0][0]
+            self.assertEqual(called_config["base"]["model_class"], "llama3")
             self.assertEqual(called_config["base"]["tokenizer_path"], "/path/to/tokenizer.json")
+            self.assertEqual(called_config["base"]["preq_mode"], "8da4w")
+            self.assertEqual(called_config["model"]["dtype_override"], "fp32")
             self.assertEqual(called_config["export"]["max_seq_length"], 256)
+            self.assertEqual(called_config["quantization"]["pt2e_quantize"], "xnnpack_dynamic")
+            self.assertEqual(called_config["quantization"]["use_spin_quant"], "cuda")
+            self.assertEqual(called_config["backend"]["coreml"]["quantize"], "c4w")
+            self.assertEqual(called_config["backend"]["coreml"]["compute_units"], "cpu_and_gpu")
         finally:
             os.unlink(config_file)
 
@@ -78,7 +96,13 @@ def test_config_with_cli_args_error(self) -> None:
         """Test that --config rejects additional CLI arguments to prevent mixing approaches."""
         # Create a temporary config file
         with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
-            f.write("base:\n  checkpoint: /path/to/checkpoint.pth")
+            f.write("""
+base:
+  model_class: llama2
+  checkpoint: /path/to/checkpoint.pth
+model:
+  dtype_override: bf16
+""")
             config_file = f.name
 
         try:
@@ -95,7 +119,14 @@ def test_config_rejects_multiple_cli_args(self) -> None:
         """Test that --config rejects multiple CLI arguments (not just single ones)."""
         with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
-            f.write("export:\n  max_seq_length: 128")
+            f.write("""
+base:
+  model_class: qwen2_5
+export:
+  max_seq_length: 128
+quantization:
+  pt2e_quantize: qnn_8a8w
+""")
             config_file = f.name
 
         try: