Commit dc33c3b
Fix LlmConfig enum usage
ghstack-source-id: 6416111
ghstack-comment-id: 2992966686
Pull-Request: #11833

4 files changed: +80, -62 lines

examples/models/llama/config/llm_config.py

Lines changed: 40 additions & 40 deletions
@@ -26,19 +26,19 @@


 class ModelType(str, Enum):
-    STORIES110M = "stories110m"
-    LLAMA2 = "llama2"
-    LLAMA3 = "llama3"
-    LLAMA3_1 = "llama3_1"
-    LLAMA3_2 = "llama3_2"
-    LLAMA3_2_VISION = "llama3_2_vision"
-    STATIC_LLAMA = "static_llama"
-    QWEN2_5 = "qwen2_5"
-    QWEN3_0_6B = "qwen3-0_6b"
-    QWEN3_1_7B = "qwen3-1_7b"
-    QWEN3_4B = "qwen3-4b"
-    PHI_4_MINI = "phi_4_mini"
-    SMOLLM2 = "smollm2"
+    stories110m = "stories110m"
+    llama2 = "llama2"
+    llama3 = "llama3"
+    llama3_1 = "llama3_1"
+    llama3_2 = "llama3_2"
+    llama3_2_vision = "llama3_2_vision"
+    static_llama = "static_llama"
+    qwen2_5 = "qwen2_5"
+    qwen3_0_6b = "qwen3-0_6b"
+    qwen3_1_7b = "qwen3-1_7b"
+    qwen3_4b = "qwen3-4b"
+    phi_4_mini = "phi_4_mini"
+    smollm2 = "smollm2"


 class PreqMode(str, Enum):
@@ -49,8 +49,8 @@ class PreqMode(str, Enum):
     are still around to preserve backward compatibility.
     """

-    PREQ_8DA4W = "8da4w"
-    PREQ_8DA4W_OUT_8DA8W = "8da4w_output_8da8w"
+    preq_8da4w = "8da4w"
+    preq_8da4w_out_8da8w = "8da4w_output_8da8w"


 @dataclass
@@ -82,7 +82,7 @@ class BaseConfig:
     are loaded.
     """

-    model_class: ModelType = ModelType.LLAMA3
+    model_class: ModelType = ModelType.llama3
     params: Optional[str] = None
     checkpoint: Optional[str] = None
     checkpoint_dir: Optional[str] = None
@@ -107,9 +107,9 @@ class DtypeOverride(str, Enum):
     is not recommended.
     """

-    FP32 = "fp32"
-    FP16 = "fp16"
-    BF16 = "bf16"
+    fp32 = "fp32"
+    fp16 = "fp16"
+    bf16 = "bf16"


 @dataclass
@@ -147,7 +147,7 @@ class ModelConfig:
     [16] pattern specifies all layers have a sliding window of 16.
     """

-    dtype_override: DtypeOverride = DtypeOverride.FP32
+    dtype_override: DtypeOverride = DtypeOverride.fp32
     enable_dynamic_shape: bool = True
     use_shared_embedding: bool = False
     use_sdpa_with_kv_cache: bool = False
@@ -270,22 +270,22 @@ class Pt2eQuantize(str, Enum):
     and is source transform-based.
     """

-    XNNPACK_DYNAMIC = "xnnpack_dynamic"
-    XNNPACK_DYNAMIC_QC4 = "xnnpack_dynamic_qc4"
-    QNN_8A8W = "qnn_8a8w"
-    QNN_16A16W = "qnn_16a16w"
-    QNN_16A4W = "qnn_16a4w"
-    COREML_C4W = "coreml_c4w"
-    COREML_8A_C8W = "coreml_8a_c8w"
-    COREML_8A_C4W = "coreml_8a_c4w"
-    COREML_BASELINE_8A_C8W = "coreml_baseline_8a_c8w"
-    COREML_BASELINE_8A_C4W = "coreml_baseline_8a_c4w"
-    VULKAN_8W = "vulkan_8w"
+    xnnpack_dynamic = "xnnpack_dynamic"
+    xnnpack_dynamic_qc4 = "xnnpack_dynamic_qc4"
+    qnn_8a8w = "qnn_8a8w"
+    qnn_16a16w = "qnn_16a16w"
+    qnn_16a4w = "qnn_16a4w"
+    coreml_c4w = "coreml_c4w"
+    coreml_8a_c8w = "coreml_8a_c8w"
+    coreml_8a_c4w = "coreml_8a_c4w"
+    coreml_baseline_8a_c8w = "coreml_baseline_8a_c8w"
+    coreml_baseline_8a_c4w = "coreml_baseline_8a_c4w"
+    vulkan_8w = "vulkan_8w"


 class SpinQuant(str, Enum):
-    CUDA = "cuda"
-    NATIVE = "native"
+    cuda = "cuda"
+    native = "native"


 @dataclass
@@ -378,15 +378,15 @@ class XNNPackConfig:


 class CoreMLQuantize(str, Enum):
-    B4W = "b4w"
-    C4W = "c4w"
+    b4w = "b4w"
+    c4w = "c4w"


 class CoreMLComputeUnit(str, Enum):
-    CPU_ONLY = "cpu_only"
-    CPU_AND_GPU = "cpu_and_gpu"
-    CPU_AND_NE = "cpu_and_ne"
-    ALL = "all"
+    cpu_only = "cpu_only"
+    cpu_and_gpu = "cpu_and_gpu"
+    cpu_and_ne = "cpu_and_ne"
+    all = "all"


 @dataclass
@@ -400,7 +400,7 @@ class CoreMLConfig:
     preserve_sdpa: bool = False
     quantize: Optional[CoreMLQuantize] = None
     ios: int = 15
-    compute_units: CoreMLComputeUnit = CoreMLComputeUnit.CPU_ONLY
+    compute_units: CoreMLComputeUnit = CoreMLComputeUnit.cpu_only

     def __post_init__(self):
         if self.ios not in (15, 16, 17, 18):
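The rename from UPPER_CASE to lowercase member names keeps the enum names aligned with the strings users write in YAML configs and on the CLI, which matters when a config loader resolves members by name. A minimal sketch of name vs. value lookup, using only the stdlib enum module rather than the project's actual config machinery:

from enum import Enum


class ModelType(str, Enum):
    # Member names now mirror the user-facing config strings.
    llama3 = "llama3"
    qwen3_0_6b = "qwen3-0_6b"


# Lookup by *name*, which is how structured-config loaders typically
# turn a YAML or CLI string into an enum member:
assert ModelType["llama3"] is ModelType.llama3

# Lookup by *value* still works where name and value differ:
assert ModelType("qwen3-0_6b") is ModelType.qwen3_0_6b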

examples/models/llama/export_llama_lib.py

Lines changed: 15 additions & 15 deletions
@@ -590,7 +590,7 @@ def export_llama(

     # If a checkpoint isn't provided for an HF OSS model, download and convert the
     # weights first.
-    model_name = llm_config.base.model_class
+    model_name = llm_config.base.model_class.value
     if not llm_config.base.checkpoint and model_name in HUGGING_FACE_REPO_IDS:
         repo_id = HUGGING_FACE_REPO_IDS[model_name]
         if model_name == "qwen2_5":
@@ -668,7 +668,7 @@ def _prepare_for_llama_export(llm_config: LlmConfig) -> LLMEdgeManager:
     llm_config.export.output_dir = output_dir_path

     # Convert dtype override string to actual type.
-    dtype_override = DType[llm_config.model.dtype_override]
+    dtype_override = DType[llm_config.model.dtype_override.value]

     edge_manager = _load_llama_model(llm_config)

@@ -702,7 +702,7 @@ def _prepare_for_llama_export(llm_config: LlmConfig) -> LLMEdgeManager:
             checkpoint=llm_config.base.checkpoint,
             checkpoint_dtype=DType.from_torch_dtype(checkpoint_dtype),  # type: ignore
             tokenizer_path=llm_config.base.tokenizer_path,
-            use_spin_quant=llm_config.quantization.use_spin_quant,
+            use_spin_quant=llm_config.quantization.use_spin_quant.value if llm_config.quantization.use_spin_quant else None,
             embedding_quantize=llm_config.quantization.embedding_quantize,
             use_shared_embedding=llm_config.model.use_shared_embedding,
             quantization_mode=llm_config.quantization.qmode,
@@ -726,7 +726,7 @@ def _prepare_for_llama_export(llm_config: LlmConfig) -> LLMEdgeManager:
             vulkan=llm_config.backend.vulkan.enabled,
             use_qat=llm_config.quantization.use_qat,
             use_lora=llm_config.base.use_lora,
-            preq_mode=llm_config.base.preq_mode,
+            preq_mode=llm_config.base.preq_mode.value if llm_config.base.preq_mode else None,
             preq_group_size=llm_config.base.preq_group_size,
             preq_embedding_quantize=llm_config.base.preq_embedding_quantize,
             local_global_attention=llm_config.model.local_global_attention,
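Several of these config fields are Optional enums, so the diff repeatedly uses the `x.value if x else None` idiom to hand downstream helpers either a plain string or None. A small sketch of that unwrapping, with a hypothetical helper name:

from enum import Enum
from typing import Optional


class PreqMode(str, Enum):
    preq_8da4w = "8da4w"
    preq_8da4w_out_8da8w = "8da4w_output_8da8w"


def unwrap(field: Optional[Enum]) -> Optional[str]:
    """Return the enum's string value, or None if the field is unset."""
    return field.value if field is not None else None


assert unwrap(PreqMode.preq_8da4w) == "8da4w"
assert unwrap(None) is None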
@@ -738,25 +738,25 @@ def _prepare_for_llama_export(llm_config: LlmConfig) -> LLMEdgeManager:


 def get_quantizer_and_quant_params(llm_config):
     pt2e_quant_params = get_pt2e_quantization_params(
-        llm_config.quantization.pt2e_quantize, llm_config.quantization.qmode
+        llm_config.quantization.pt2e_quantize.value if llm_config.quantization.pt2e_quantize else None, llm_config.quantization.qmode
     )
     quantizers = get_pt2e_quantizers(pt2e_quant_params, llm_config.export.so_library)
     quant_dtype = None
     if llm_config.backend.qnn.enabled and llm_config.quantization.pt2e_quantize:
         assert len(quantizers) == 0, "Should not enable both xnnpack and qnn"
         qnn_quantizer, quant_dtype = get_qnn_quantizer(
-            llm_config.quantization.pt2e_quantize, llm_config.quantization.qmode
+            llm_config.quantization.pt2e_quantize.value, llm_config.quantization.qmode
         )
         quantizers.append(qnn_quantizer)
     if llm_config.backend.coreml.enabled and llm_config.quantization.pt2e_quantize:
         assert len(quantizers) == 0, "Should not enable both xnnpack / qnn and coreml"
-        coreml_quantizer = get_coreml_quantizer(llm_config.quantization.pt2e_quantize)
+        coreml_quantizer = get_coreml_quantizer(llm_config.quantization.pt2e_quantize.value)
         quantizers.append(coreml_quantizer)
     if llm_config.backend.vulkan.enabled and llm_config.quantization.pt2e_quantize:
         assert (
             len(quantizers) == 0
         ), "Should not enable both vulkan and other quantizers"
-        vulkan_quantizer = get_vulkan_quantizer(llm_config.quantization.pt2e_quantize)
+        vulkan_quantizer = get_vulkan_quantizer(llm_config.quantization.pt2e_quantize.value)
         quantizers.append(vulkan_quantizer)
     logging.info(f"Applying quantizers: {quantizers}")
     return pt2e_quant_params, quantizers, quant_dtype
@@ -1033,7 +1033,7 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager:  # noqa: C901
     )

     additional_passes = []
-    if llm_config.base.model_class in TORCHTUNE_DEFINED_MODELS:
+    if llm_config.base.model_class.value in TORCHTUNE_DEFINED_MODELS:
         additional_passes = [InitializedMutableBufferPass(["kv_cache_pos"])]

     # export_to_edge
@@ -1072,14 +1072,14 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager:  # noqa: C901
             mps=llm_config.backend.mps.enabled,
             coreml=llm_config.backend.coreml.enabled,
             qnn=llm_config.backend.qnn.enabled,
-            dtype_override=llm_config.model.dtype_override,
+            dtype_override=llm_config.model.dtype_override.value,
             enable_dynamic_shape=llm_config.model.enable_dynamic_shape,
             use_kv_cache=llm_config.model.use_kv_cache,
             embedding_quantize=llm_config.quantization.embedding_quantize,
-            pt2e_quantize=llm_config.quantization.pt2e_quantize,
+            pt2e_quantize=llm_config.quantization.pt2e_quantize.value if llm_config.quantization.pt2e_quantize else None,
             coreml_ios=llm_config.backend.coreml.ios,
-            coreml_quantize=llm_config.backend.coreml.quantize,
-            coreml_compute_units=llm_config.backend.coreml.compute_units,
+            coreml_quantize=llm_config.backend.coreml.quantize.value if llm_config.backend.coreml.quantize else None,
+            coreml_compute_units=llm_config.backend.coreml.compute_units.value,
             use_qnn_sha=llm_config.backend.qnn.use_sha,
             num_sharding=llm_config.backend.qnn.num_sharding,
             soc_model=llm_config.backend.qnn.soc_model,
@@ -1152,7 +1152,7 @@ def _load_llama_model(llm_config: LlmConfig) -> "LLMEdgeManager":
         An instance of LLMEdgeManager which contains the eager mode model.
     """

-    modelname = llm_config.base.model_class
+    modelname = llm_config.base.model_class.value
     if modelname in EXECUTORCH_DEFINED_MODELS:
         module_name = "llama"
         model_class_name = "Llama2Model"  # TODO: Change to "LlamaModel" in examples/models/llama/model.py.
@@ -1173,7 +1173,7 @@ def _load_llama_model(llm_config: LlmConfig) -> "LLMEdgeManager":
         )
     )
     # Convert dtype override string to actual type.
-    dtype_override = DType[llm_config.model.dtype_override]
+    dtype_override = DType[llm_config.model.dtype_override.value]

     return LLMEdgeManager(
         model=model,
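In both places where `DType[...]` is indexed, the lookup is by member name, so `.value` hands it the plain string key. A hedged sketch of the pattern; the `DType` enum below is a stand-in with assumed members, not copied from the exporter:

from enum import Enum


class DtypeOverride(str, Enum):
    fp32 = "fp32"
    fp16 = "fp16"
    bf16 = "bf16"


# Stand-in for the exporter's DType enum; its members are assumed here
# to share the same names as the override strings.
class DType(Enum):
    fp32 = "fp32"
    fp16 = "fp16"
    bf16 = "bf16"


override = DtypeOverride.fp16
# Enum indexing resolves by member *name*, so pass the plain string.
dtype = DType[override.value]
assert dtype is DType.fp16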

examples/models/llama/model.py

Lines changed: 7 additions & 7 deletions
@@ -157,7 +157,7 @@ def __init__(self, llm_config: Optional[LlmConfig] = None):

         if model_args.use_scaled_rope:
             # Older models don't have use_scaled_rope configuration
-            model_name = str(self.llm_config.base.model_class)
+            model_name = self.llm_config.base.model_class.value
             assert model_name not in ["llama2", "stories110m"]

             # Llama3_2 and newer models in ExecuTorch repo should set larger scale factor
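This hunk is more than cosmetic: with `model_class` now an enum member, `str()` on a `(str, Enum)` member returns the qualified member name, not the raw value, so the old guard could never trip. A minimal sketch of the difference, assuming the lowercase `ModelType` members introduced in this commit:

from enum import Enum


class ModelType(str, Enum):
    llama2 = "llama2"
    stories110m = "stories110m"


model_class = ModelType.llama2

# str() keeps the Enum repr, so a check against "llama2" would always miss.
assert str(model_class) == "ModelType.llama2"

# .value yields the bare string, so the guard can actually fire.
model_name = model_class.value
assert model_name in ["llama2", "stories110m"]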
@@ -328,10 +328,10 @@ def get_example_inputs_kvcache_sdpa(self):

     def _transform_for_pre_quantization(self, checkpoint, model_args):
         assert self.llm_config.base.preq_mode, "preq_mode must be specified"
-        assert self.llm_config.base.preq_mode in [
+        assert self.llm_config.base.preq_mode.value in [
             "8da4w",
             "8da4w_output_8da8w",
-        ], f"Quantization mode {self.llm_config.base.preq_mode} is not compatible with SpinQuant."
+        ], f"Quantization mode {self.llm_config.base.preq_mode.value} is not compatible with SpinQuant."
         assert self.llm_config.base.preq_group_size, "preq_group_size must be specified"
         assert self.llm_config.model.dtype_override, "dtype_override must be specified"

@@ -351,22 +351,22 @@ def _transform_for_pre_quantization(self, checkpoint, model_args):
         }

         # Transform the output layer first if needed.
-        if self.llm_config.base.preq_mode == "8da4w_output_8da8w":
+        if self.llm_config.base.preq_mode.value == "8da4w_output_8da8w":
             from .source_transformation.pre_quantization import (
                 transform_output_linear_for_pre_quantization,
             )

             self.model_ = transform_output_linear_for_pre_quantization(
                 module=self.model_,
                 checkpoint=checkpoint,
-                dtype=mapping[self.llm_config.model.dtype_override],
+                dtype=mapping[self.llm_config.model.dtype_override.value],
             )

         self.model_ = transform_linear_for_pre_quantization(
             self.model_,
             checkpoint,
             self.llm_config.base.preq_group_size,
-            mapping[self.llm_config.model.dtype_override],
+            mapping[self.llm_config.model.dtype_override.value],
         )

         embedding_bit_width, embedding_group_size = None, None
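The `mapping` indexed in this hunk (defined just above it in `_transform_for_pre_quantization`) presumably keys torch dtypes on the plain override strings; taking `.value` keeps the lookup on ordinary str keys instead of leaning on the str-mixin hashing of the enum member. A sketch under that assumption, not the exact mapping from model.py:

from enum import Enum

import torch


class DtypeOverride(str, Enum):
    fp32 = "fp32"
    fp16 = "fp16"
    bf16 = "bf16"


# Assumed shape of the dtype mapping: plain strings -> torch dtypes.
mapping = {
    "fp32": torch.float32,
    "fp16": torch.float16,
}

dtype_override = DtypeOverride.fp16
assert mapping[dtype_override.value] is torch.float16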
@@ -390,7 +390,7 @@ def _transform_for_pre_quantization(self, checkpoint, model_args):
             self.model_ = transform_embedding_for_pre_quantization(
                 self.model_,
                 checkpoint,
-                mapping[self.llm_config.model.dtype_override],
+                mapping[self.llm_config.model.dtype_override.value],
                 int(embedding_bit_width),
                 embedding_group_size,
             )

extension/llm/export/test/test_export_llm.py

Lines changed: 18 additions & 0 deletions
@@ -47,9 +47,20 @@ def test_with_config(self, mock_export_llama: MagicMock) -> None:
         with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
             f.write("""
 base:
+  model_class: llama2
   tokenizer_path: /path/to/tokenizer.json
+  preq_mode: preq_8da4w
+model:
+  dtype_override: fp16
 export:
   max_seq_length: 256
+quantization:
+  pt2e_quantize: xnnpack_dynamic
+  use_spin_quant: cuda
+backend:
+  coreml:
+    quantize: c4w
+    compute_units: cpu_and_gpu
 """)
             config_file = f.name
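Note that the strings in this YAML are the enum member names (for example `preq_mode: preq_8da4w`, the name rather than the "8da4w" value), which is what a name-based config loader needs. A quick sketch of that correspondence, using yaml.safe_load from PyYAML purely for illustration rather than the project's actual loader:

from enum import Enum

import yaml


class PreqMode(str, Enum):
    preq_8da4w = "8da4w"
    preq_8da4w_out_8da8w = "8da4w_output_8da8w"


cfg = yaml.safe_load("""
base:
  preq_mode: preq_8da4w
""")

# The YAML string is the member *name*; indexing resolves it to the member,
# whose .value is the wire string "8da4w".
mode = PreqMode[cfg["base"]["preq_mode"]]
assert mode is PreqMode.preq_8da4w
assert mode.value == "8da4w"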

@@ -62,7 +73,14 @@ def test_with_config(self, mock_export_llama: MagicMock) -> None:
             mock_export_llama.assert_called_once()
             called_config = mock_export_llama.call_args[0][0]
             self.assertEqual(called_config["base"]["tokenizer_path"], "/path/to/tokenizer.json")
+            self.assertEqual(called_config["base"]["model_class"], "llama2")
+            self.assertEqual(called_config["base"]["preq_mode"], "preq_8da4w")
+            self.assertEqual(called_config["model"]["dtype_override"], "fp16")
             self.assertEqual(called_config["export"]["max_seq_length"], 256)
+            self.assertEqual(called_config["quantization"]["pt2e_quantize"], "xnnpack_dynamic")
+            self.assertEqual(called_config["quantization"]["use_spin_quant"], "cuda")
+            self.assertEqual(called_config["backend"]["coreml"]["quantize"], "c4w")
+            self.assertEqual(called_config["backend"]["coreml"]["compute_units"], "cpu_and_gpu")
         finally:
             os.unlink(config_file)