Commit 1b9117c

Update
[ghstack-poisoned]
2 parents f31059b + 49d56c4 · commit 1b9117c

5 files changed: +208 -172 lines changed


examples/models/llama/config/llm_config.py

Lines changed: 82 additions & 87 deletions
@@ -16,6 +16,7 @@
 import ast
 import re
 from dataclasses import dataclass, field
+from enum import Enum
 from typing import ClassVar, List, Optional
 
 
@@ -24,27 +25,32 @@
 ################################################################################
 
 
-MODEL_TYPE_OPTIONS = [
-    "stories110m",
-    "llama2",
-    "llama3",
-    "llama3_1",
-    "llama3_2",
-    "llama3_2_vision",
-    "static_llama",
-    "qwen2_5",
-    "qwen3-0_6b",
-    "qwen3-1_7b",
-    "qwen3-4b",
-    "phi_4_mini",
-    "smollm2",
-]
+class ModelType(str, Enum):
+    stories110m = "stories110m"
+    llama2 = "llama2"
+    llama3 = "llama3"
+    llama3_1 = "llama3_1"
+    llama3_2 = "llama3_2"
+    llama3_2_vision = "llama3_2_vision"
+    static_llama = "static_llama"
+    qwen2_5 = "qwen2_5"
+    qwen3_0_6b = "qwen3-0_6b"
+    qwen3_1_7b = "qwen3-1_7b"
+    qwen3_4b = "qwen3-4b"
+    phi_4_mini = "phi_4_mini"
+    smollm2 = "smollm2"
 
 
-PREQ_MODE_OPTIONS = [
-    "8da4w",
-    "8da4w_output_8da8w",
-]
+class PreqMode(str, Enum):
+    """
+    If you are dealing with pre-quantized checkpoints, this used to
+    be the way to specify them. Now you don't need to specify these
+    options if you use a TorchAo-prequantized checkpoint, but they
+    are still around to preserve backward compatibility.
+    """
+
+    preq_8da4w = "8da4w"
+    preq_8da4w_out_8da8w = "8da4w_output_8da8w"
 
 
 @dataclass
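Note: a minimal sketch (not part of this commit) of why the str/Enum mixin keeps existing string-based call sites working — members compare equal to their raw values, can be rebuilt from them, and expose the plain string via .value.

from enum import Enum


class ModelType(str, Enum):
    # Subset of the members added in the diff above.
    llama3 = "llama3"
    qwen2_5 = "qwen2_5"


# Members are also str instances, so code that compared plain strings keeps working.
assert ModelType.llama3 == "llama3"
assert isinstance(ModelType.llama3, str)

# Lookup by value turns a raw CLI/config string into the enum member.
assert ModelType("qwen2_5") is ModelType.qwen2_5

# .value recovers the raw string, e.g. for serialization.
assert ModelType.llama3.value == "llama3"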
@@ -76,36 +82,34 @@ class BaseConfig:
     are loaded.
     """
 
-    model_class: str = "llama3"
+    model_class: ModelType = ModelType.llama3
     params: Optional[str] = None
     checkpoint: Optional[str] = None
    checkpoint_dir: Optional[str] = None
     tokenizer_path: Optional[str] = None
     metadata: Optional[str] = None
     use_lora: int = 0
     fairseq2: bool = False
-    preq_mode: Optional[str] = None
+    preq_mode: Optional[PreqMode] = None
     preq_group_size: int = 32
     preq_embedding_quantize: str = "8,0"
 
-    def __post_init__(self):
-        if self.model_class not in MODEL_TYPE_OPTIONS:
-            raise ValueError(f"model_class must be one of {MODEL_TYPE_OPTIONS}, got '{self.model_class}'")
-
-        if self.preq_mode is not None and self.preq_mode not in PREQ_MODE_OPTIONS:
-            raise ValueError(f"preq_mode must be one of {PREQ_MODE_OPTIONS}, got '{self.preq_mode}'")
-
 
 ################################################################################
 ################################# ModelConfig ##################################
 ################################################################################
 
 
-DTYPE_OVERRIDE_OPTIONS = [
-    "fp32",
-    "fp16",
-    "bf16",
-]
+class DtypeOverride(str, Enum):
+    """
+    DType of the model. Highly recommended to use "fp32", unless you want to
+    export without a backend, in which case you can also use "bf16". "fp16"
+    is not recommended.
+    """
+
+    fp32 = "fp32"
+    fp16 = "fp16"
+    bf16 = "bf16"
 
 
 @dataclass
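Note: the hand-written membership checks dropped from __post_init__ are covered by the enum constructor itself — building a member from an unknown value raises ValueError. A small illustration (standard-library behavior, not code from this commit):

from enum import Enum


class DtypeOverride(str, Enum):
    fp32 = "fp32"
    fp16 = "fp16"
    bf16 = "bf16"


# An unknown value fails immediately, so the dataclass no longer needs its own check.
try:
    DtypeOverride("int8")
except ValueError as err:
    print(err)  # roughly: 'int8' is not a valid DtypeOverride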
@@ -143,7 +147,7 @@ class ModelConfig:
     [16] pattern specifies all layers have a sliding window of 16.
     """
 
-    dtype_override: str = "fp32"
+    dtype_override: DtypeOverride = DtypeOverride.fp32
     enable_dynamic_shape: bool = True
     use_shared_embedding: bool = False
     use_sdpa_with_kv_cache: bool = False
@@ -156,9 +160,6 @@ class ModelConfig:
     local_global_attention: Optional[List[int]] = None
 
     def __post_init__(self):
-        if self.dtype_override not in DTYPE_OVERRIDE_OPTIONS:
-            raise ValueError(f"dtype_override must be one of {DTYPE_OVERRIDE_OPTIONS}, got '{self.dtype_override}'")
-
         self._validate_attention_sink()
         self._validate_local_global_attention()
 
@@ -260,25 +261,31 @@ class DebugConfig:
 ################################################################################
 
 
-PT2E_QUANTIZE_OPTIONS = [
-    "xnnpack_dynamic",
-    "xnnpack_dynamic_qc4",
-    "qnn_8a8w",
-    "qnn_16a16w",
-    "qnn_16a4w",
-    "coreml_c4w",
-    "coreml_8a_c8w",
-    "coreml_8a_c4w",
-    "coreml_baseline_8a_c8w",
-    "coreml_baseline_8a_c4w",
-    "vulkan_8w",
-]
+class Pt2eQuantize(str, Enum):
+    """
+    Type of backend-specific Pt2e quantization strategy to use.
+
+    Pt2e uses a different quantization library that is graph-based
+    compared to `qmode`, which is also specified in the QuantizationConfig
+    and is source transform-based.
+    """
 
+    xnnpack_dynamic = "xnnpack_dynamic"
+    xnnpack_dynamic_qc4 = "xnnpack_dynamic_qc4"
+    qnn_8a8w = "qnn_8a8w"
+    qnn_16a16w = "qnn_16a16w"
+    qnn_16a4w = "qnn_16a4w"
+    coreml_c4w = "coreml_c4w"
+    coreml_8a_c8w = "coreml_8a_c8w"
+    coreml_8a_c4w = "coreml_8a_c4w"
+    coreml_baseline_8a_c8w = "coreml_baseline_8a_c8w"
+    coreml_baseline_8a_c4w = "coreml_baseline_8a_c4w"
+    vulkan_8w = "vulkan_8w"
 
-SPIN_QUANT_OPTIONS = [
-    "cuda",
-    "native",
-]
+
+class SpinQuant(str, Enum):
+    cuda = "cuda"
+    native = "native"
 
 
 @dataclass
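Note: a hypothetical sketch (wiring not shown in this diff) of how such an enum can also drive CLI parsing — member values supply the argparse choices, and the parsed string converts back into a member.

import argparse
from enum import Enum


class SpinQuant(str, Enum):
    cuda = "cuda"
    native = "native"


parser = argparse.ArgumentParser()
# Choices come straight from the enum values; the flag name here is illustrative.
parser.add_argument("--use_spin_quant", choices=[m.value for m in SpinQuant], default=None)
args = parser.parse_args(["--use_spin_quant", "native"])
assert SpinQuant(args.use_spin_quant) is SpinQuant.native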
@@ -313,22 +320,16 @@ class QuantizationConfig:
 
     qmode: Optional[str] = None
     embedding_quantize: Optional[str] = None
-    pt2e_quantize: Optional[str] = None
+    pt2e_quantize: Optional[Pt2eQuantize] = None
     group_size: Optional[int] = None
-    use_spin_quant: Optional[str] = None
+    use_spin_quant: Optional[SpinQuant] = None
     use_qat: bool = False
     calibration_tasks: Optional[List[str]] = None
     calibration_limit: Optional[int] = None
     calibration_seq_length: Optional[int] = None
     calibration_data: str = "Once upon a time"
 
     def __post_init__(self):
-        if self.pt2e_quantize is not None and self.pt2e_quantize not in PT2E_QUANTIZE_OPTIONS:
-            raise ValueError(f"pt2e_quantize must be one of {PT2E_QUANTIZE_OPTIONS}, got '{self.pt2e_quantize}'")
-
-        if self.use_spin_quant is not None and self.use_spin_quant not in SPIN_QUANT_OPTIONS:
-            raise ValueError(f"use_spin_quant must be one of {SPIN_QUANT_OPTIONS}, got '{self.use_spin_quant}'")
-
         if self.qmode:
             self._validate_qmode()
 
@@ -376,18 +377,16 @@ class XNNPackConfig:
     extended_ops: bool = False
 
 
-COREML_QUANTIZE_OPTIONS = [
-    "b4w",
-    "c4w",
-]
+class CoreMLQuantize(str, Enum):
+    b4w = "b4w"
+    c4w = "c4w"
 
 
-COREML_COMPUTE_UNIT_OPTIONS = [
-    "cpu_only",
-    "cpu_and_gpu",
-    "cpu_and_ne",
-    "all",
-]
+class CoreMLComputeUnit(str, Enum):
+    cpu_only = "cpu_only"
+    cpu_and_gpu = "cpu_and_gpu"
+    cpu_and_ne = "cpu_and_ne"
+    all = "all"
 
 
 @dataclass
@@ -399,17 +398,11 @@ class CoreMLConfig:
     enabled: bool = False
     enable_state: bool = False
     preserve_sdpa: bool = False
-    quantize: Optional[str] = None
+    quantize: Optional[CoreMLQuantize] = None
     ios: int = 15
-    compute_units: str = "cpu_only"
+    compute_units: CoreMLComputeUnit = CoreMLComputeUnit.cpu_only
 
     def __post_init__(self):
-        if self.quantize is not None and self.quantize not in COREML_QUANTIZE_OPTIONS:
-            raise ValueError(f"quantize must be one of {COREML_QUANTIZE_OPTIONS}, got '{self.quantize}'")
-
-        if self.compute_units not in COREML_COMPUTE_UNIT_OPTIONS:
-            raise ValueError(f"compute_units must be one of {COREML_COMPUTE_UNIT_OPTIONS}, got '{self.compute_units}'")
-
         if self.ios not in (15, 16, 17, 18):
             raise ValueError(f"Invalid coreml ios version: {self.ios}")
 
@@ -488,7 +481,7 @@ def from_args(cls, args: argparse.Namespace) -> "LlmConfig": # noqa: C901
 
         # BaseConfig
         if hasattr(args, "model"):
-            llm_config.base.model_class = args.model
+            llm_config.base.model_class = ModelType(args.model)
         if hasattr(args, "params"):
             llm_config.base.params = args.params
         if hasattr(args, "checkpoint"):
@@ -506,15 +499,15 @@ def from_args(cls, args: argparse.Namespace) -> "LlmConfig": # noqa: C901
 
         # PreqMode settings
         if hasattr(args, "preq_mode") and args.preq_mode:
-            llm_config.base.preq_mode = args.preq_mode
+            llm_config.base.preq_mode = PreqMode(args.preq_mode)
         if hasattr(args, "preq_group_size"):
             llm_config.base.preq_group_size = args.preq_group_size
         if hasattr(args, "preq_embedding_quantize"):
             llm_config.base.preq_embedding_quantize = args.preq_embedding_quantize
 
         # ModelConfig
         if hasattr(args, "dtype_override"):
-            llm_config.model.dtype_override = args.dtype_override
+            llm_config.model.dtype_override = DtypeOverride(args.dtype_override)
         if hasattr(args, "enable_dynamic_shape"):
             llm_config.model.enable_dynamic_shape = args.enable_dynamic_shape
         if hasattr(args, "use_shared_embedding"):
@@ -556,11 +549,11 @@ def from_args(cls, args: argparse.Namespace) -> "LlmConfig": # noqa: C901
         if hasattr(args, "embedding_quantize"):
             llm_config.quantization.embedding_quantize = args.embedding_quantize
         if hasattr(args, "pt2e_quantize") and args.pt2e_quantize:
-            llm_config.quantization.pt2e_quantize = args.pt2e_quantize
+            llm_config.quantization.pt2e_quantize = Pt2eQuantize(args.pt2e_quantize)
         if hasattr(args, "group_size"):
             llm_config.quantization.group_size = args.group_size
         if hasattr(args, "use_spin_quant") and args.use_spin_quant:
-            llm_config.quantization.use_spin_quant = args.use_spin_quant
+            llm_config.quantization.use_spin_quant = SpinQuant(args.use_spin_quant)
         if hasattr(args, "use_qat"):
             llm_config.quantization.use_qat = args.use_qat
         if hasattr(args, "calibration_tasks"):
@@ -588,11 +581,13 @@ def from_args(cls, args: argparse.Namespace) -> "LlmConfig": # noqa: C901
             args, "coreml_preserve_sdpa", False
         )
         if hasattr(args, "coreml_quantize") and args.coreml_quantize:
-            llm_config.backend.coreml.quantize = args.coreml_quantize
+            llm_config.backend.coreml.quantize = CoreMLQuantize(args.coreml_quantize)
         if hasattr(args, "coreml_ios"):
             llm_config.backend.coreml.ios = args.coreml_ios
         if hasattr(args, "coreml_compute_units"):
-            llm_config.backend.coreml.compute_units = args.coreml_compute_units
+            llm_config.backend.coreml.compute_units = CoreMLComputeUnit(
+                args.coreml_compute_units
+            )
 
         # Vulkan
         if hasattr(args, "vulkan"):

examples/models/llama/config/test_llm_config.py

Lines changed: 2 additions & 29 deletions
@@ -11,6 +11,7 @@
 from executorch.examples.models.llama.config.llm_config import (
     BackendConfig,
     BaseConfig,
+    CoreMLComputeUnit,
     CoreMLConfig,
     DebugConfig,
     ExportConfig,
@@ -65,34 +66,6 @@ def test_shared_embedding_without_lowbit(self):
         with self.assertRaises(ValueError):
             LlmConfig(model=model_cfg, quantization=qcfg)
 
-    def test_invalid_model_type(self):
-        with self.assertRaises(ValueError):
-            BaseConfig(model_class="invalid_model")
-
-    def test_invalid_dtype_override(self):
-        with self.assertRaises(ValueError):
-            ModelConfig(dtype_override="invalid_dtype")
-
-    def test_invalid_preq_mode(self):
-        with self.assertRaises(ValueError):
-            BaseConfig(preq_mode="invalid_preq")
-
-    def test_invalid_pt2e_quantize(self):
-        with self.assertRaises(ValueError):
-            QuantizationConfig(pt2e_quantize="invalid_pt2e")
-
-    def test_invalid_spin_quant(self):
-        with self.assertRaises(ValueError):
-            QuantizationConfig(use_spin_quant="invalid_spin")
-
-    def test_invalid_coreml_quantize(self):
-        with self.assertRaises(ValueError):
-            CoreMLConfig(quantize="invalid_quantize")
-
-    def test_invalid_coreml_compute_units(self):
-        with self.assertRaises(ValueError):
-            CoreMLConfig(compute_units="invalid_compute_units")
-
 
 class TestValidConstruction(unittest.TestCase):
 
@@ -121,7 +94,7 @@ def test_valid_llm_config(self):
             backend=BackendConfig(
                 xnnpack=XNNPackConfig(enabled=False),
                 coreml=CoreMLConfig(
-                    enabled=True, ios=17, compute_units="all"
+                    enabled=True, ios=17, compute_units=CoreMLComputeUnit.all
                 ),
             ),
         )
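Note: a hypothetical replacement for the deleted invalid-value tests (not part of this commit) — the same failures now surface as ValueError at enum construction time.

import unittest
from enum import Enum


class ModelType(str, Enum):
    llama3 = "llama3"


class EnumValidationExample(unittest.TestCase):
    # Illustrative only: invalid strings fail when the enum member is built,
    # before any dataclass __post_init__ runs.
    def test_invalid_model_type(self):
        with self.assertRaises(ValueError):
            ModelType("invalid_model")


if __name__ == "__main__":
    unittest.main()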

0 commit comments
