Commit 37016c2

Get rid of llm_config enums
ghstack-source-id: df9d181
ghstack-comment-id: 2986530962
Pull-Request: pytorch/executorch#11810
1 parent a147d30 commit 37016c2

File tree

3 files changed: +149 −86 lines changed
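The pattern repeated throughout this commit: each `str`-backed enum becomes a plain module-level list of allowed strings, and the membership check that enum conversion used to perform moves into each dataclass's `__post_init__`. A minimal self-contained sketch of that pattern (the option list here is abridged; the full lists are in the diff below):

import itertools  # not required; stdlib-only example
from dataclasses import dataclass

# Abridged from MODEL_TYPE_OPTIONS in the diff below.
MODEL_TYPE_OPTIONS = [
    "stories110m",
    "llama2",
    "llama3",
]


@dataclass
class BaseConfig:
    model_class: str = "llama3"

    def __post_init__(self):
        # A membership test replaces the ValueError that ModelType("bad")
        # used to raise when converting argparse strings to enum members.
        if self.model_class not in MODEL_TYPE_OPTIONS:
            raise ValueError(
                f"model_class must be one of {MODEL_TYPE_OPTIONS}, got '{self.model_class}'"
            )


BaseConfig(model_class="llama2")  # constructs fine
try:
    BaseConfig(model_class="lama3")  # a typo is still caught, now in __post_init__
except ValueError as e:
    print(e)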

examples/models/llama/config/llm_config.py

Lines changed: 87 additions & 82 deletions
@@ -16,7 +16,6 @@
 import ast
 import re
 from dataclasses import dataclass, field
-from enum import Enum
 from typing import ClassVar, List, Optional
 
 
@@ -25,32 +24,27 @@
 ################################################################################
 
 
-class ModelType(str, Enum):
-    STORIES110M = "stories110m"
-    LLAMA2 = "llama2"
-    LLAMA3 = "llama3"
-    LLAMA3_1 = "llama3_1"
-    LLAMA3_2 = "llama3_2"
-    LLAMA3_2_VISION = "llama3_2_vision"
-    STATIC_LLAMA = "static_llama"
-    QWEN2_5 = "qwen2_5"
-    QWEN3_0_6B = "qwen3-0_6b"
-    QWEN3_1_7B = "qwen3-1_7b"
-    QWEN3_4B = "qwen3-4b"
-    PHI_4_MINI = "phi_4_mini"
-    SMOLLM2 = "smollm2"
+MODEL_TYPE_OPTIONS = [
+    "stories110m",
+    "llama2",
+    "llama3",
+    "llama3_1",
+    "llama3_2",
+    "llama3_2_vision",
+    "static_llama",
+    "qwen2_5",
+    "qwen3-0_6b",
+    "qwen3-1_7b",
+    "qwen3-4b",
+    "phi_4_mini",
+    "smollm2",
+]
 
 
-class PreqMode(str, Enum):
-    """
-    If you are dealing with pre-quantized checkpoints, this used to
-    be the way to specify them. Now you don't need to specify these
-    options if you use a TorchAo-prequantized checkpoint, but they
-    are still around to preserve backward compatibility.
-    """
-
-    PREQ_8DA4W = "8da4w"
-    PREQ_8DA4W_OUT_8DA8W = "8da4w_output_8da8w"
+PREQ_MODE_OPTIONS = [
+    "8da4w",
+    "8da4w_output_8da8w",
+]
 
 
 @dataclass
@@ -82,34 +76,36 @@ class BaseConfig:
     are loaded.
     """
 
-    model_class: ModelType = ModelType.LLAMA3
+    model_class: str = "llama3"
     params: Optional[str] = None
     checkpoint: Optional[str] = None
     checkpoint_dir: Optional[str] = None
     tokenizer_path: Optional[str] = None
     metadata: Optional[str] = None
     use_lora: int = 0
     fairseq2: bool = False
-    preq_mode: Optional[PreqMode] = None
+    preq_mode: Optional[str] = None
     preq_group_size: int = 32
     preq_embedding_quantize: str = "8,0"
 
+    def __post_init__(self):
+        if self.model_class not in MODEL_TYPE_OPTIONS:
+            raise ValueError(f"model_class must be one of {MODEL_TYPE_OPTIONS}, got '{self.model_class}'")
+
+        if self.preq_mode is not None and self.preq_mode not in PREQ_MODE_OPTIONS:
+            raise ValueError(f"preq_mode must be one of {PREQ_MODE_OPTIONS}, got '{self.preq_mode}'")
+
 
 ################################################################################
 ################################# ModelConfig ##################################
 ################################################################################
 
 
-class DtypeOverride(str, Enum):
-    """
-    DType of the model. Highly recommended to use "fp32", unless you want to
-    export without a backend, in which case you can also use "bf16". "fp16"
-    is not recommended.
-    """
-
-    FP32 = "fp32"
-    FP16 = "fp16"
-    BF16 = "bf16"
+DTYPE_OVERRIDE_OPTIONS = [
+    "fp32",
+    "fp16",
+    "bf16",
+]
@@ -147,7 +143,7 @@ class ModelConfig:
     [16] pattern specifies all layers have a sliding window of 16.
     """
 
-    dtype_override: DtypeOverride = DtypeOverride.FP32
+    dtype_override: str = "fp32"
     enable_dynamic_shape: bool = True
     use_shared_embedding: bool = False
     use_sdpa_with_kv_cache: bool = False
@@ -160,6 +156,9 @@ class ModelConfig:
     local_global_attention: Optional[List[int]] = None
 
     def __post_init__(self):
+        if self.dtype_override not in DTYPE_OVERRIDE_OPTIONS:
+            raise ValueError(f"dtype_override must be one of {DTYPE_OVERRIDE_OPTIONS}, got '{self.dtype_override}'")
+
         self._validate_attention_sink()
         self._validate_local_global_attention()
 
@@ -261,31 +260,25 @@ class DebugConfig:
 ################################################################################
 
 
-class Pt2eQuantize(str, Enum):
-    """
-    Type of backend-specific Pt2e quantization strategy to use.
-
-    Pt2e uses a different quantization library that is graph-based
-    compared to `qmode`, which is also specified in the QuantizationConfig
-    and is source transform-based.
-    """
-
-    XNNPACK_DYNAMIC = "xnnpack_dynamic"
-    XNNPACK_DYNAMIC_QC4 = "xnnpack_dynamic_qc4"
-    QNN_8A8W = "qnn_8a8w"
-    QNN_16A16W = "qnn_16a16w"
-    QNN_16A4W = "qnn_16a4w"
-    COREML_C4W = "coreml_c4w"
-    COREML_8A_C8W = "coreml_8a_c8w"
-    COREML_8A_C4W = "coreml_8a_c4w"
-    COREML_BASELINE_8A_C8W = "coreml_baseline_8a_c8w"
-    COREML_BASELINE_8A_C4W = "coreml_baseline_8a_c4w"
-    VULKAN_8W = "vulkan_8w"
+PT2E_QUANTIZE_OPTIONS = [
+    "xnnpack_dynamic",
+    "xnnpack_dynamic_qc4",
+    "qnn_8a8w",
+    "qnn_16a16w",
+    "qnn_16a4w",
+    "coreml_c4w",
+    "coreml_8a_c8w",
+    "coreml_8a_c4w",
+    "coreml_baseline_8a_c8w",
+    "coreml_baseline_8a_c4w",
+    "vulkan_8w",
+]
 
 
-class SpinQuant(str, Enum):
-    CUDA = "cuda"
-    NATIVE = "native"
+SPIN_QUANT_OPTIONS = [
+    "cuda",
+    "native",
+]
 
 
 @dataclass
@@ -320,16 +313,22 @@ class QuantizationConfig:
 
     qmode: Optional[str] = None
     embedding_quantize: Optional[str] = None
-    pt2e_quantize: Optional[Pt2eQuantize] = None
+    pt2e_quantize: Optional[str] = None
     group_size: Optional[int] = None
-    use_spin_quant: Optional[SpinQuant] = None
+    use_spin_quant: Optional[str] = None
     use_qat: bool = False
     calibration_tasks: Optional[List[str]] = None
     calibration_limit: Optional[int] = None
     calibration_seq_length: Optional[int] = None
     calibration_data: str = "Once upon a time"
 
     def __post_init__(self):
+        if self.pt2e_quantize is not None and self.pt2e_quantize not in PT2E_QUANTIZE_OPTIONS:
+            raise ValueError(f"pt2e_quantize must be one of {PT2E_QUANTIZE_OPTIONS}, got '{self.pt2e_quantize}'")
+
+        if self.use_spin_quant is not None and self.use_spin_quant not in SPIN_QUANT_OPTIONS:
+            raise ValueError(f"use_spin_quant must be one of {SPIN_QUANT_OPTIONS}, got '{self.use_spin_quant}'")
+
         if self.qmode:
             self._validate_qmode()
 
@@ -377,16 +376,18 @@ class XNNPackConfig:
     extended_ops: bool = False
 
 
-class CoreMLQuantize(str, Enum):
-    B4W = "b4w"
-    C4W = "c4w"
+COREML_QUANTIZE_OPTIONS = [
+    "b4w",
+    "c4w",
+]
 
 
-class CoreMLComputeUnit(str, Enum):
-    CPU_ONLY = "cpu_only"
-    CPU_AND_GPU = "cpu_and_gpu"
-    CPU_AND_NE = "cpu_and_ne"
-    ALL = "all"
+COREML_COMPUTE_UNIT_OPTIONS = [
+    "cpu_only",
+    "cpu_and_gpu",
+    "cpu_and_ne",
+    "all",
+]
 
 
 @dataclass
@@ -398,11 +399,17 @@ class CoreMLConfig:
     enabled: bool = False
     enable_state: bool = False
     preserve_sdpa: bool = False
-    quantize: Optional[CoreMLQuantize] = None
+    quantize: Optional[str] = None
     ios: int = 15
-    compute_units: CoreMLComputeUnit = CoreMLComputeUnit.CPU_ONLY
+    compute_units: str = "cpu_only"
 
     def __post_init__(self):
+        if self.quantize is not None and self.quantize not in COREML_QUANTIZE_OPTIONS:
+            raise ValueError(f"quantize must be one of {COREML_QUANTIZE_OPTIONS}, got '{self.quantize}'")
+
+        if self.compute_units not in COREML_COMPUTE_UNIT_OPTIONS:
+            raise ValueError(f"compute_units must be one of {COREML_COMPUTE_UNIT_OPTIONS}, got '{self.compute_units}'")
+
         if self.ios not in (15, 16, 17, 18):
             raise ValueError(f"Invalid coreml ios version: {self.ios}")
 
@@ -481,7 +488,7 @@ def from_args(cls, args: argparse.Namespace) -> "LlmConfig": # noqa: C901
 
         # BaseConfig
         if hasattr(args, "model"):
-            llm_config.base.model_class = ModelType(args.model)
+            llm_config.base.model_class = args.model
         if hasattr(args, "params"):
             llm_config.base.params = args.params
         if hasattr(args, "checkpoint"):
@@ -499,15 +506,15 @@ def from_args(cls, args: argparse.Namespace) -> "LlmConfig": # noqa: C901
 
         # PreqMode settings
         if hasattr(args, "preq_mode") and args.preq_mode:
-            llm_config.base.preq_mode = PreqMode(args.preq_mode)
+            llm_config.base.preq_mode = args.preq_mode
         if hasattr(args, "preq_group_size"):
             llm_config.base.preq_group_size = args.preq_group_size
         if hasattr(args, "preq_embedding_quantize"):
             llm_config.base.preq_embedding_quantize = args.preq_embedding_quantize
 
         # ModelConfig
         if hasattr(args, "dtype_override"):
-            llm_config.model.dtype_override = DtypeOverride(args.dtype_override)
+            llm_config.model.dtype_override = args.dtype_override
         if hasattr(args, "enable_dynamic_shape"):
             llm_config.model.enable_dynamic_shape = args.enable_dynamic_shape
         if hasattr(args, "use_shared_embedding"):
@@ -549,11 +556,11 @@ def from_args(cls, args: argparse.Namespace) -> "LlmConfig": # noqa: C901
         if hasattr(args, "embedding_quantize"):
             llm_config.quantization.embedding_quantize = args.embedding_quantize
         if hasattr(args, "pt2e_quantize") and args.pt2e_quantize:
-            llm_config.quantization.pt2e_quantize = Pt2eQuantize(args.pt2e_quantize)
+            llm_config.quantization.pt2e_quantize = args.pt2e_quantize
         if hasattr(args, "group_size"):
             llm_config.quantization.group_size = args.group_size
         if hasattr(args, "use_spin_quant") and args.use_spin_quant:
-            llm_config.quantization.use_spin_quant = SpinQuant(args.use_spin_quant)
+            llm_config.quantization.use_spin_quant = args.use_spin_quant
         if hasattr(args, "use_qat"):
             llm_config.quantization.use_qat = args.use_qat
         if hasattr(args, "calibration_tasks"):
@@ -581,13 +588,11 @@ def from_args(cls, args: argparse.Namespace) -> "LlmConfig": # noqa: C901
             args, "coreml_preserve_sdpa", False
         )
         if hasattr(args, "coreml_quantize") and args.coreml_quantize:
-            llm_config.backend.coreml.quantize = CoreMLQuantize(args.coreml_quantize)
+            llm_config.backend.coreml.quantize = args.coreml_quantize
         if hasattr(args, "coreml_ios"):
             llm_config.backend.coreml.ios = args.coreml_ios
         if hasattr(args, "coreml_compute_units"):
-            llm_config.backend.coreml.compute_units = CoreMLComputeUnit(
-                args.coreml_compute_units
-            )
+            llm_config.backend.coreml.compute_units = args.coreml_compute_units
 
         # Vulkan
         if hasattr(args, "vulkan"):

examples/models/llama/config/test_llm_config.py

Lines changed: 29 additions & 2 deletions
@@ -11,7 +11,6 @@
 from executorch.examples.models.llama.config.llm_config import (
     BackendConfig,
     BaseConfig,
-    CoreMLComputeUnit,
     CoreMLConfig,
     DebugConfig,
     ExportConfig,
@@ -66,6 +65,34 @@ def test_shared_embedding_without_lowbit(self):
         with self.assertRaises(ValueError):
             LlmConfig(model=model_cfg, quantization=qcfg)
 
+    def test_invalid_model_type(self):
+        with self.assertRaises(ValueError):
+            BaseConfig(model_class="invalid_model")
+
+    def test_invalid_dtype_override(self):
+        with self.assertRaises(ValueError):
+            ModelConfig(dtype_override="invalid_dtype")
+
+    def test_invalid_preq_mode(self):
+        with self.assertRaises(ValueError):
+            BaseConfig(preq_mode="invalid_preq")
+
+    def test_invalid_pt2e_quantize(self):
+        with self.assertRaises(ValueError):
+            QuantizationConfig(pt2e_quantize="invalid_pt2e")
+
+    def test_invalid_spin_quant(self):
+        with self.assertRaises(ValueError):
+            QuantizationConfig(use_spin_quant="invalid_spin")
+
+    def test_invalid_coreml_quantize(self):
+        with self.assertRaises(ValueError):
+            CoreMLConfig(quantize="invalid_quantize")
+
+    def test_invalid_coreml_compute_units(self):
+        with self.assertRaises(ValueError):
+            CoreMLConfig(compute_units="invalid_compute_units")
+
 
 class TestValidConstruction(unittest.TestCase):
 
@@ -94,7 +121,7 @@ def test_valid_llm_config(self):
             backend=BackendConfig(
                 xnnpack=XNNPackConfig(enabled=False),
                 coreml=CoreMLConfig(
-                    enabled=True, ios=17, compute_units=CoreMLComputeUnit.ALL
+                    enabled=True, ios=17, compute_units="all"
                 ),
             ),
         )
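The new negative tests exercise each __post_init__ check once. They can be run with the stock unittest loader, assuming the executorch repo root is on PYTHONPATH (an illustrative sketch, not part of the commit):

import unittest

# Load the test module touched by this commit and run everything in it,
# including the seven new invalid-value tests added above.
suite = unittest.defaultTestLoader.loadTestsFromName(
    "executorch.examples.models.llama.config.test_llm_config"
)
unittest.TextTestRunner(verbosity=2).run(suite)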
