
Commit 3b02c99

Fix LlmConfig enum usage (#11833)
Fixes bugs in how LlmConfig enum fields are used: member names are lowercased so they match the strings written in config files, and call sites now read .value (guarding Optional fields) wherever a plain string is expected.
1 parent ff3c3b6

File tree

4 files changed: +111 -62 lines

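The diffs below all follow one pattern: the LlmConfig enums mix in str, but downstream code expects plain strings (dict keys, equality checks, DType[...] lookups), so accesses now go through .value. A minimal standalone sketch of the distinction; the ModelType members are copied from the diff, while the repo-id table is a made-up stand-in for HUGGING_FACE_REPO_IDS:

```python
from enum import Enum


class ModelType(str, Enum):
    # Two members copied from the diff, for illustration only.
    llama2 = "llama2"
    qwen3_0_6b = "qwen3-0_6b"


member = ModelType.qwen3_0_6b

# .value is always the plain string the rest of the stack expects.
assert member.value == "qwen3-0_6b"

# Rendering the member itself is version-dependent for (str, Enum) mixins:
# depending on the Python version, str() and f-strings can yield
# "ModelType.qwen3_0_6b" instead of "qwen3-0_6b".
print(str(member), f"{member}")

# Lookups in plain-string tables are only reliable via .value.
repo_ids = {"qwen3-0_6b": "some-org/some-repo"}  # hypothetical stand-in
assert member.value in repo_ids
```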

examples/models/llama/config/llm_config.py

Lines changed: 40 additions & 40 deletions
@@ -26,19 +26,19 @@


 class ModelType(str, Enum):
-    STORIES110M = "stories110m"
-    LLAMA2 = "llama2"
-    LLAMA3 = "llama3"
-    LLAMA3_1 = "llama3_1"
-    LLAMA3_2 = "llama3_2"
-    LLAMA3_2_VISION = "llama3_2_vision"
-    STATIC_LLAMA = "static_llama"
-    QWEN2_5 = "qwen2_5"
-    QWEN3_0_6B = "qwen3-0_6b"
-    QWEN3_1_7B = "qwen3-1_7b"
-    QWEN3_4B = "qwen3-4b"
-    PHI_4_MINI = "phi_4_mini"
-    SMOLLM2 = "smollm2"
+    stories110m = "stories110m"
+    llama2 = "llama2"
+    llama3 = "llama3"
+    llama3_1 = "llama3_1"
+    llama3_2 = "llama3_2"
+    llama3_2_vision = "llama3_2_vision"
+    static_llama = "static_llama"
+    qwen2_5 = "qwen2_5"
+    qwen3_0_6b = "qwen3-0_6b"
+    qwen3_1_7b = "qwen3-1_7b"
+    qwen3_4b = "qwen3-4b"
+    phi_4_mini = "phi_4_mini"
+    smollm2 = "smollm2"


 class PreqMode(str, Enum):
@@ -49,8 +49,8 @@ class PreqMode(str, Enum):
     are still around to preserve backward compatibility.
     """

-    PREQ_8DA4W = "8da4w"
-    PREQ_8DA4W_OUT_8DA8W = "8da4w_output_8da8w"
+    preq_8da4w = "8da4w"
+    preq_8da4w_out_8da8w = "8da4w_output_8da8w"


 @dataclass
@@ -82,7 +82,7 @@ class BaseConfig:
     are loaded.
     """

-    model_class: ModelType = ModelType.LLAMA3
+    model_class: ModelType = ModelType.llama3
     params: Optional[str] = None
     checkpoint: Optional[str] = None
     checkpoint_dir: Optional[str] = None
@@ -107,9 +107,9 @@ class DtypeOverride(str, Enum):
     is not recommended.
     """

-    FP32 = "fp32"
-    FP16 = "fp16"
-    BF16 = "bf16"
+    fp32 = "fp32"
+    fp16 = "fp16"
+    bf16 = "bf16"


 @dataclass
@@ -147,7 +147,7 @@ class ModelConfig:
     [16] pattern specifies all layers have a sliding window of 16.
     """

-    dtype_override: DtypeOverride = DtypeOverride.FP32
+    dtype_override: DtypeOverride = DtypeOverride.fp32
     enable_dynamic_shape: bool = True
     use_shared_embedding: bool = False
     use_sdpa_with_kv_cache: bool = False
@@ -270,22 +270,22 @@ class Pt2eQuantize(str, Enum):
     and is source transform-based.
     """

-    XNNPACK_DYNAMIC = "xnnpack_dynamic"
-    XNNPACK_DYNAMIC_QC4 = "xnnpack_dynamic_qc4"
-    QNN_8A8W = "qnn_8a8w"
-    QNN_16A16W = "qnn_16a16w"
-    QNN_16A4W = "qnn_16a4w"
-    COREML_C4W = "coreml_c4w"
-    COREML_8A_C8W = "coreml_8a_c8w"
-    COREML_8A_C4W = "coreml_8a_c4w"
-    COREML_BASELINE_8A_C8W = "coreml_baseline_8a_c8w"
-    COREML_BASELINE_8A_C4W = "coreml_baseline_8a_c4w"
-    VULKAN_8W = "vulkan_8w"
+    xnnpack_dynamic = "xnnpack_dynamic"
+    xnnpack_dynamic_qc4 = "xnnpack_dynamic_qc4"
+    qnn_8a8w = "qnn_8a8w"
+    qnn_16a16w = "qnn_16a16w"
+    qnn_16a4w = "qnn_16a4w"
+    coreml_c4w = "coreml_c4w"
+    coreml_8a_c8w = "coreml_8a_c8w"
+    coreml_8a_c4w = "coreml_8a_c4w"
+    coreml_baseline_8a_c8w = "coreml_baseline_8a_c8w"
+    coreml_baseline_8a_c4w = "coreml_baseline_8a_c4w"
+    vulkan_8w = "vulkan_8w"


 class SpinQuant(str, Enum):
-    CUDA = "cuda"
-    NATIVE = "native"
+    cuda = "cuda"
+    native = "native"


 @dataclass
@@ -378,15 +378,15 @@ class XNNPackConfig:


 class CoreMLQuantize(str, Enum):
-    B4W = "b4w"
-    C4W = "c4w"
+    b4w = "b4w"
+    c4w = "c4w"


 class CoreMLComputeUnit(str, Enum):
-    CPU_ONLY = "cpu_only"
-    CPU_AND_GPU = "cpu_and_gpu"
-    CPU_AND_NE = "cpu_and_ne"
-    ALL = "all"
+    cpu_only = "cpu_only"
+    cpu_and_gpu = "cpu_and_gpu"
+    cpu_and_ne = "cpu_and_ne"
+    all = "all"


 @dataclass
@@ -400,7 +400,7 @@ class CoreMLConfig:
     preserve_sdpa: bool = False
     quantize: Optional[CoreMLQuantize] = None
     ios: int = 15
-    compute_units: CoreMLComputeUnit = CoreMLComputeUnit.CPU_ONLY
+    compute_units: CoreMLComputeUnit = CoreMLComputeUnit.cpu_only

     def __post_init__(self):
         if self.ios not in (15, 16, 17, 18):
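Lowercasing the member names matters because a member's name and value can differ (preq_8da4w vs "8da4w", qwen3_0_6b vs "qwen3-0_6b"). A small sketch, assuming the config machinery resolves YAML strings to members by name, which is what the test added later in this commit (preq_mode: preq_8da4w asserting a value of "8da4w") suggests:

```python
from enum import Enum


class PreqMode(str, Enum):
    # Members as renamed in this commit: lowercase names, unchanged values.
    preq_8da4w = "8da4w"
    preq_8da4w_out_8da8w = "8da4w_output_8da8w"


# Name-based lookup, the form a config string like "preq_8da4w" would take.
member = PreqMode["preq_8da4w"]
assert member.value == "8da4w"        # the plain string quantization code expects
assert member is PreqMode("8da4w")    # value-based lookup resolves to the same member
```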

examples/models/llama/export_llama_lib.py

Lines changed: 38 additions & 15 deletions
@@ -590,7 +590,7 @@ def export_llama(

     # If a checkpoint isn't provided for an HF OSS model, download and convert the
     # weights first.
-    model_name = llm_config.base.model_class
+    model_name = llm_config.base.model_class.value
     if not llm_config.base.checkpoint and model_name in HUGGING_FACE_REPO_IDS:
         repo_id = HUGGING_FACE_REPO_IDS[model_name]
         if model_name == "qwen2_5":
@@ -668,7 +668,7 @@ def _prepare_for_llama_export(llm_config: LlmConfig) -> LLMEdgeManager:
     llm_config.export.output_dir = output_dir_path

     # Convert dtype override string to actual type.
-    dtype_override = DType[llm_config.model.dtype_override]
+    dtype_override = DType[llm_config.model.dtype_override.value]

     edge_manager = _load_llama_model(llm_config)

@@ -702,7 +702,11 @@ def _prepare_for_llama_export(llm_config: LlmConfig) -> LLMEdgeManager:
         checkpoint=llm_config.base.checkpoint,
         checkpoint_dtype=DType.from_torch_dtype(checkpoint_dtype),  # type: ignore
         tokenizer_path=llm_config.base.tokenizer_path,
-        use_spin_quant=llm_config.quantization.use_spin_quant,
+        use_spin_quant=(
+            llm_config.quantization.use_spin_quant.value
+            if llm_config.quantization.use_spin_quant
+            else None
+        ),
         embedding_quantize=llm_config.quantization.embedding_quantize,
         use_shared_embedding=llm_config.model.use_shared_embedding,
         quantization_mode=llm_config.quantization.qmode,
@@ -726,7 +730,9 @@ def _prepare_for_llama_export(llm_config: LlmConfig) -> LLMEdgeManager:
         vulkan=llm_config.backend.vulkan.enabled,
         use_qat=llm_config.quantization.use_qat,
         use_lora=llm_config.base.use_lora,
-        preq_mode=llm_config.base.preq_mode,
+        preq_mode=(
+            llm_config.base.preq_mode.value if llm_config.base.preq_mode else None
+        ),
         preq_group_size=llm_config.base.preq_group_size,
         preq_embedding_quantize=llm_config.base.preq_embedding_quantize,
         local_global_attention=llm_config.model.local_global_attention,
@@ -738,25 +744,34 @@ def _prepare_for_llama_export(llm_config: LlmConfig) -> LLMEdgeManager:

 def get_quantizer_and_quant_params(llm_config):
     pt2e_quant_params = get_pt2e_quantization_params(
-        llm_config.quantization.pt2e_quantize, llm_config.quantization.qmode
+        (
+            llm_config.quantization.pt2e_quantize.value
+            if llm_config.quantization.pt2e_quantize
+            else None
+        ),
+        llm_config.quantization.qmode,
     )
     quantizers = get_pt2e_quantizers(pt2e_quant_params, llm_config.export.so_library)
     quant_dtype = None
     if llm_config.backend.qnn.enabled and llm_config.quantization.pt2e_quantize:
         assert len(quantizers) == 0, "Should not enable both xnnpack and qnn"
         qnn_quantizer, quant_dtype = get_qnn_quantizer(
-            llm_config.quantization.pt2e_quantize, llm_config.quantization.qmode
+            llm_config.quantization.pt2e_quantize.value, llm_config.quantization.qmode
         )
         quantizers.append(qnn_quantizer)
     if llm_config.backend.coreml.enabled and llm_config.quantization.pt2e_quantize:
         assert len(quantizers) == 0, "Should not enable both xnnpack / qnn and coreml"
-        coreml_quantizer = get_coreml_quantizer(llm_config.quantization.pt2e_quantize)
+        coreml_quantizer = get_coreml_quantizer(
+            llm_config.quantization.pt2e_quantize.value
+        )
         quantizers.append(coreml_quantizer)
     if llm_config.backend.vulkan.enabled and llm_config.quantization.pt2e_quantize:
         assert (
             len(quantizers) == 0
         ), "Should not enable both vulkan and other quantizers"
-        vulkan_quantizer = get_vulkan_quantizer(llm_config.quantization.pt2e_quantize)
+        vulkan_quantizer = get_vulkan_quantizer(
+            llm_config.quantization.pt2e_quantize.value
+        )
         quantizers.append(vulkan_quantizer)
     logging.info(f"Applying quantizers: {quantizers}")
     return pt2e_quant_params, quantizers, quant_dtype
@@ -1035,7 +1050,7 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager:  # noqa: C901
     )

     additional_passes = []
-    if llm_config.base.model_class in TORCHTUNE_DEFINED_MODELS:
+    if llm_config.base.model_class.value in TORCHTUNE_DEFINED_MODELS:
         additional_passes = [InitializedMutableBufferPass(["kv_cache_pos"])]

     # export_to_edge
@@ -1074,14 +1089,22 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager:  # noqa: C901
         mps=llm_config.backend.mps.enabled,
         coreml=llm_config.backend.coreml.enabled,
         qnn=llm_config.backend.qnn.enabled,
-        dtype_override=llm_config.model.dtype_override,
+        dtype_override=llm_config.model.dtype_override.value,
         enable_dynamic_shape=llm_config.model.enable_dynamic_shape,
         use_kv_cache=llm_config.model.use_kv_cache,
         embedding_quantize=llm_config.quantization.embedding_quantize,
-        pt2e_quantize=llm_config.quantization.pt2e_quantize,
+        pt2e_quantize=(
+            llm_config.quantization.pt2e_quantize.value
+            if llm_config.quantization.pt2e_quantize
+            else None
+        ),
         coreml_ios=llm_config.backend.coreml.ios,
-        coreml_quantize=llm_config.backend.coreml.quantize,
-        coreml_compute_units=llm_config.backend.coreml.compute_units,
+        coreml_quantize=(
+            llm_config.backend.coreml.quantize.value
+            if llm_config.backend.coreml.quantize
+            else None
+        ),
+        coreml_compute_units=llm_config.backend.coreml.compute_units.value,
         use_qnn_sha=llm_config.backend.qnn.use_sha,
         num_sharding=llm_config.backend.qnn.num_sharding,
         soc_model=llm_config.backend.qnn.soc_model,
@@ -1154,7 +1177,7 @@ def _load_llama_model(llm_config: LlmConfig) -> "LLMEdgeManager":
         An instance of LLMEdgeManager which contains the eager mode model.
     """

-    modelname = llm_config.base.model_class
+    modelname = llm_config.base.model_class.value
     if modelname in EXECUTORCH_DEFINED_MODELS:
         module_name = "llama"
         model_class_name = "Llama2Model"  # TODO: Change to "LlamaModel" in examples/models/llama/model.py.
@@ -1175,7 +1198,7 @@ def _load_llama_model(llm_config: LlmConfig) -> "LLMEdgeManager":
         )
     )
     # Convert dtype override string to actual type.
-    dtype_override = DType[llm_config.model.dtype_override]
+    dtype_override = DType[llm_config.model.dtype_override.value]

     return LLMEdgeManager(
         model=model,
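Several call sites above repeat the same guard, x.value if x else None, because fields such as pt2e_quantize, preq_mode, use_spin_quant, and coreml.quantize are Optional enums and the downstream APIs accept either a plain string or None. A hypothetical helper (not part of this commit) that captures the pattern in one place:

```python
from enum import Enum
from typing import Optional


def enum_value_or_none(member: Optional[Enum]) -> Optional[str]:
    """Return the member's plain string value, or None if the field is unset.

    Hypothetical helper; the commit inlines this expression at each call site,
    e.g. use_spin_quant=(llm_config.quantization.use_spin_quant.value
    if llm_config.quantization.use_spin_quant else None).
    """
    return member.value if member is not None else None


assert enum_value_or_none(None) is None
```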

examples/models/llama/model.py

Lines changed: 7 additions & 7 deletions
@@ -157,7 +157,7 @@ def __init__(self, llm_config: Optional[LlmConfig] = None):

         if model_args.use_scaled_rope:
             # Older models don't have use_scaled_rope configuration
-            model_name = str(self.llm_config.base.model_class)
+            model_name = self.llm_config.base.model_class.value
             assert model_name not in ["llama2", "stories110m"]

             # Llama3_2 and newer models in ExecuTorch repo should set larger scale factor
@@ -328,10 +328,10 @@ def get_example_inputs_kvcache_sdpa(self):

     def _transform_for_pre_quantization(self, checkpoint, model_args):
         assert self.llm_config.base.preq_mode, "preq_mode must be specified"
-        assert self.llm_config.base.preq_mode in [
+        assert self.llm_config.base.preq_mode.value in [
             "8da4w",
             "8da4w_output_8da8w",
-        ], f"Quantization mode {self.llm_config.base.preq_mode} is not compatible with SpinQuant."
+        ], f"Quantization mode {self.llm_config.base.preq_mode.value} is not compatible with SpinQuant."
         assert self.llm_config.base.preq_group_size, "preq_group_size must be specified"
         assert self.llm_config.model.dtype_override, "dtype_override must be specified"

@@ -351,22 +351,22 @@ def _transform_for_pre_quantization(self, checkpoint, model_args):
         }

         # Transform the output layer first if needed.
-        if self.llm_config.base.preq_mode == "8da4w_output_8da8w":
+        if self.llm_config.base.preq_mode.value == "8da4w_output_8da8w":
             from .source_transformation.pre_quantization import (
                 transform_output_linear_for_pre_quantization,
             )

             self.model_ = transform_output_linear_for_pre_quantization(
                 module=self.model_,
                 checkpoint=checkpoint,
-                dtype=mapping[self.llm_config.model.dtype_override],
+                dtype=mapping[self.llm_config.model.dtype_override.value],
             )

         self.model_ = transform_linear_for_pre_quantization(
             self.model_,
             checkpoint,
             self.llm_config.base.preq_group_size,
-            mapping[self.llm_config.model.dtype_override],
+            mapping[self.llm_config.model.dtype_override.value],
         )

         embedding_bit_width, embedding_group_size = None, None
@@ -390,7 +390,7 @@ def _transform_for_pre_quantization(self, checkpoint, model_args):
            self.model_ = transform_embedding_for_pre_quantization(
                 self.model_,
                 checkpoint,
-                mapping[self.llm_config.model.dtype_override],
+                mapping[self.llm_config.model.dtype_override.value],
                 int(embedding_bit_width),
                 embedding_group_size,
             )
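The model.py changes index a dtype table with the override's .value. The table itself is outside this diff, so the sketch below assumes a string-keyed mapping along those lines; the point is that the plain value is the unambiguous key, independent of how the enum member prints:

```python
from enum import Enum

import torch


class DtypeOverride(str, Enum):
    fp32 = "fp32"
    fp16 = "fp16"
    bf16 = "bf16"


# Assumed shape of the mapping used by _transform_for_pre_quantization;
# its real contents are not shown in this diff.
mapping = {
    "fp32": torch.float32,
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
}

dtype_override = DtypeOverride.fp16
assert mapping[dtype_override.value] is torch.float16
```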

extension/llm/export/test/test_export_llm.py

Lines changed: 26 additions & 0 deletions
@@ -51,9 +51,20 @@ def test_with_config(self, mock_export_llama: MagicMock) -> None:
             f.write(
                 """
 base:
+  model_class: llama2
   tokenizer_path: /path/to/tokenizer.json
+  preq_mode: preq_8da4w
+model:
+  dtype_override: fp16
 export:
   max_seq_length: 256
+quantization:
+  pt2e_quantize: xnnpack_dynamic
+  use_spin_quant: cuda
+backend:
+  coreml:
+    quantize: c4w
+    compute_units: cpu_and_gpu
 """
             )
             config_file = f.name
@@ -69,7 +80,22 @@ def test_with_config(self, mock_export_llama: MagicMock) -> None:
             self.assertEqual(
                 called_config["base"]["tokenizer_path"], "/path/to/tokenizer.json"
             )
+            self.assertEqual(called_config["base"]["model_class"], "llama2")
+            self.assertEqual(called_config["base"]["preq_mode"].value, "8da4w")
+            self.assertEqual(called_config["model"]["dtype_override"].value, "fp16")
             self.assertEqual(called_config["export"]["max_seq_length"], 256)
+            self.assertEqual(
+                called_config["quantization"]["pt2e_quantize"].value, "xnnpack_dynamic"
+            )
+            self.assertEqual(
+                called_config["quantization"]["use_spin_quant"].value, "cuda"
+            )
+            self.assertEqual(
+                called_config["backend"]["coreml"]["quantize"].value, "c4w"
+            )
+            self.assertEqual(
+                called_config["backend"]["coreml"]["compute_units"].value, "cpu_and_gpu"
+            )
         finally:
             os.unlink(config_file)
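One detail behind the new assertions: because these enums mix in str, a member already compares equal to its raw value, so the test could compare members to strings directly; reading .value makes the expected plain string explicit. A quick illustration with CoreMLComputeUnit from this commit:

```python
from enum import Enum


class CoreMLComputeUnit(str, Enum):
    cpu_only = "cpu_only"
    cpu_and_gpu = "cpu_and_gpu"
    cpu_and_ne = "cpu_and_ne"
    all = "all"


member = CoreMLComputeUnit.cpu_and_gpu

assert member == "cpu_and_gpu"        # str mixin: equality against the raw value holds
assert member.value == "cpu_and_gpu"  # what the new test assertions check explicitly
```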
