
Commit c8c58db

Add new export LLM config

Pull Request resolved: #11028
@imported-using-ghimport
Differential Revision: [D75263991](https://our.internmc.facebook.com/intern/diff/D75263991/)
ghstack-source-id: 288422344

1 parent aa67f08

3 files changed: 336 additions & 0 deletions

Lines changed: 9 additions & 0 deletions
# Any targets that should be shared between fbcode and xplat must be defined in
# targets.bzl. This file can contain fbcode-only targets.

load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
load(":targets.bzl", "define_common_targets")

oncall("executorch")

define_common_targets()
Lines changed: 312 additions & 0 deletions
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

"""
Configurations for exporting Llama.

Uses dataclasses, which integrate with OmegaConf and Hydra.
"""

import re
from dataclasses import dataclass, field
from enum import Enum
from typing import List, Optional

################################################################################
################################## BaseConfig ##################################
################################################################################


class ModelType(str, Enum):
    STORIES110M = "stories110m"
    LLAMA2 = "llama2"
    LLAMA3 = "llama3"
    LLAMA3_1 = "llama3_1"
    LLAMA3_2 = "llama3_2"
    LLAMA3_2_VISION = "llama3_2_vision"
    STATIC_LLAMA = "static_llama"
    QWEN2_5 = "qwen2_5"
    QWEN3_0_6B = "qwen3-0_6b"
    QWEN3_1_7B = "qwen3-1_7b"
    QWEN3_4B = "qwen3-4b"
    PHI_4_MINI = "phi_4_mini"
    SMOLLM2 = "smollm2"


class PreqMode(str, Enum):
    """
    If you are dealing with pre-quantized checkpoints, this used to
    be the way to specify them. These options are no longer needed if
    you use a TorchAo-prequantized checkpoint, but they are still
    around to preserve backward compatibility.
    """

    PREQ_8DA4W = "8da4w"
    PREQ_8DA4W_OUT_8DA8W = "8da4w_output_8da8w"

@dataclass
class BaseConfig:
    """
    Configurations specific to the model, e.g. whether it's Qwen3 or Phi-4-mini.
    These are the minimal set of parameters needed to load the pretrained
    eager model and its weights.
    """

    model_class: ModelType = ModelType.LLAMA3
    params: Optional[str] = None
    checkpoint: Optional[str] = None
    checkpoint_dir: Optional[str] = None  # For sharded checkpoint.
    tokenizer_path: Optional[str] = None
    metadata: Optional[str] = None
    use_lora: bool = False
    fairseq2: bool = False  # For legacy internal use cases.

    # Legacy pre-quantization options that happen during model weight loading.
    preq_mode: Optional[PreqMode] = None
    preq_group_size: int = 32
    preq_embedding_quantize: str = "8,0"

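# Example (illustrative, not part of the committed file; paths are
# hypothetical): loading a sharded Llama 3 checkpoint needs only a few
# of these fields, with everything else left at its defaults:
#
#     BaseConfig(
#         model_class=ModelType.LLAMA3,
#         checkpoint_dir="/path/to/sharded_checkpoint/",
#         tokenizer_path="/path/to/tokenizer.model",
#     )
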
################################################################################
################################# ModelConfig ##################################
################################################################################


class DtypeOverride(str, Enum):
    """
    DType of the model. Highly recommended to use "fp32", unless you want to
    export without a backend, in which case you can also use "bf16". "fp16"
    is not recommended.
    """

    FP32 = "fp32"
    FP16 = "fp16"
    BF16 = "bf16"


@dataclass
class ModelConfig:
    """
    Configurations not necessarily specific to the model, but needed to
    finish off the rest of the model configuration in eager mode. You can
    think of these as optimizations / actual configuration choices. The
    same ModelConfig can be applied to multiple models.
    """

    dtype_override: DtypeOverride = DtypeOverride.FP32
    enable_dynamic_shape: bool = True
    use_shared_embedding: bool = False
    use_sdpa_with_kv_cache: bool = False
    expand_rope_table: bool = False
    use_attention_sink: Optional[str] = None
    output_prune_map: Optional[str] = None
    input_prune_map: Optional[str] = None

    # Below are config options relating to the kv cache.
    use_kv_cache: bool = False
    quantize_kv_cache: bool = False
    local_global_attention: Optional[List[int]] = None

################################################################################
################################ ExportConfig ##################################
################################################################################


@dataclass
class ExportConfig:
    """
    Configures properties relevant to the export process.
    """

    max_seq_length: int = 128
    max_context_length: int = 128
    output_dir: Optional[str] = None
    output_name: Optional[str] = None
    so_library: Optional[str] = None
    export_only: bool = False


################################################################################
################################# DebugConfig ##################################
################################################################################


@dataclass
class DebugConfig:
    """
    Configures options to debug the export process.
    """

    profile_memory: bool = False
    profile_path: Optional[str] = None
    generate_etrecord: bool = False
    generate_full_logits: bool = False
    verbose: bool = False

################################################################################
############################# QuantizationConfig ###############################
################################################################################


class Pt2eQuantize(str, Enum):
    """
    Type of backend-specific Pt2e quantization strategy to use.

    Pt2e uses a graph-based quantization library, in contrast to `qmode`,
    which is also specified in the QuantizationConfig and is source
    transform-based.
    """

    XNNPACK_DYNAMIC = "xnnpack_dynamic"
    XNNPACK_DYNAMIC_QC4 = "xnnpack_dynamic_qc4"
    QNN_8A8W = "qnn_8a8w"
    QNN_16A16W = "qnn_16a16w"
    QNN_16A4W = "qnn_16a4w"
    COREML_C4W = "coreml_c4w"
    COREML_8A_C8W = "coreml_8a_c8w"
    COREML_8A_C4W = "coreml_8a_c4w"
    COREML_BASELINE_8A_C8W = "coreml_baseline_8a_c8w"
    COREML_BASELINE_8A_C4W = "coreml_baseline_8a_c4w"
    VULKAN_8W = "vulkan_8w"


class SpinQuant(str, Enum):
    CUDA = "cuda"
    NATIVE = "native"


@dataclass
class QuantizationConfig:
    """
    Configures how the model should be quantized (PTQ).
    """

    qmode: Optional[str] = None
    embedding_quantize: Optional[str] = None
    pt2e_quantize: Optional[Pt2eQuantize] = None
    group_size: Optional[int] = None
    use_spin_quant: Optional[SpinQuant] = None
    use_qat: bool = False
    calibration_tasks: Optional[List[str]] = None
    calibration_limit: Optional[int] = None
    calibration_seq_length: Optional[int] = None
    calibration_data: str = "Once upon a time"

    def __post_init__(self):
        if self.qmode:
            self._validate_qmode()

    def _validate_qmode(self) -> None:
        choices = ["int8", "8da4w", "8da4w-gptq", "vulkan_4w"]
        patterns = [r"torchao:8da(\d+)w", r"torchao:fpa(\d+)w"]

        if self.qmode in choices:
            return

        for pattern in patterns:
            matches = re.findall(pattern, self.qmode)
            if len(matches) == 1:
                return

        raise ValueError(
            f"Got qmode {self.qmode}, but expected one of {choices}, "
            f"or one of the regex patterns {patterns}."
        )

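# Example (illustrative, not part of the committed file): qmode accepts
# either a fixed choice or a TorchAo pattern-based mode, so both of these
# pass _validate_qmode():
#
#     QuantizationConfig(qmode="8da4w")          # listed in `choices`
#     QuantizationConfig(qmode="torchao:8da4w")  # matches r"torchao:8da(\d+)w"
#
# while QuantizationConfig(qmode="int4") raises a ValueError naming the
# accepted choices and patterns.
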
################################################################################
############################### BackendConfig ##################################
################################################################################


@dataclass
class XNNPackConfig:
    enabled: bool = False
    extended_ops: bool = False


class CoreMLQuantize(str, Enum):
    B4W = "b4w"
    C4W = "c4w"


class CoreMLComputeUnit(str, Enum):
    CPU_ONLY = "cpu_only"
    CPU_AND_GPU = "cpu_and_gpu"
    CPU_AND_NE = "cpu_and_ne"
    ALL = "all"


@dataclass
class CoreMLConfig:
    enabled: bool = False
    enable_state: bool = False
    preserve_sdpa: bool = False
    quantize: Optional[CoreMLQuantize] = None
    ios: int = 15
    compute_units: CoreMLComputeUnit = CoreMLComputeUnit.CPU_ONLY

    def __post_init__(self):
        if self.ios not in (15, 16, 17, 18):
            raise ValueError(f"Invalid coreml ios version: {self.ios}")


@dataclass
class VulkanConfig:
    enabled: bool = False


@dataclass
class QNNConfig:
    enabled: bool = False
    use_sha: bool = False
    soc_model: str = "SM8650"
    use_qnn_sha: bool = False
    optimized_rotation_path: Optional[str] = None
    num_sharding: int = 0


@dataclass
class MPSConfig:
    enabled: bool = False


@dataclass
class BackendConfig:
    """
    Configures which backends should be used and how the backends
    should be set up.
    """

    xnnpack: XNNPackConfig = field(default_factory=XNNPackConfig)
    coreml: CoreMLConfig = field(default_factory=CoreMLConfig)
    vulkan: VulkanConfig = field(default_factory=VulkanConfig)
    qnn: QNNConfig = field(default_factory=QNNConfig)
    mps: MPSConfig = field(default_factory=MPSConfig)

################################################################################
################################## LlmConfig ###################################
################################################################################


@dataclass
class LlmConfig:
    """
    The overall configuration for customizing the LLM export process.
    """

    base: BaseConfig = field(default_factory=BaseConfig)
    model: ModelConfig = field(default_factory=ModelConfig)
    export: ExportConfig = field(default_factory=ExportConfig)
    debug: DebugConfig = field(default_factory=DebugConfig)
    quantization: QuantizationConfig = field(default_factory=QuantizationConfig)
    backend: BackendConfig = field(default_factory=BackendConfig)
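As the module docstring says, these dataclasses integrate with OmegaConf and Hydra. A minimal sketch of how LlmConfig could be instantiated as a structured config and then overridden (assuming omegaconf is installed; the import path is derived from the base_module in targets.bzl below, and the override values are illustrative):

from omegaconf import OmegaConf

from executorch.examples.models.llama.config.llm_config import LlmConfig

# Build a typed config tree from the dataclass defaults.
default_cfg = OmegaConf.structured(LlmConfig)

# Merge user overrides (e.g. parsed from YAML or the CLI). OmegaConf
# validates field names and types against the dataclasses, converting
# "qwen3-0_6b" into ModelType.QWEN3_0_6B along the way.
overrides = OmegaConf.create(
    {"base": {"model_class": "qwen3-0_6b"}, "export": {"max_seq_length": 256}}
)
cfg = OmegaConf.merge(default_cfg, overrides)

assert cfg.export.max_seq_length == 256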
Lines changed: 15 additions & 0 deletions
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")

def define_common_targets():
    runtime.python_library(
        name = "llm_config",
        srcs = [
            "llm_config.py",
        ],
        _is_external_target = True,
        base_module = "executorch.examples.models.llama.config",
        visibility = [
            "//executorch/...",
            "@EXECUTORCH_CLIENTS",
        ],
    )
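With base_module set as above, downstream Buck targets can import the config as executorch.examples.models.llama.config.llm_config. A hypothetical dependent target (the name, srcs, and package path are illustrative, not from this commit):

runtime.python_library(
    name = "my_export_lib",
    srcs = ["my_export_lib.py"],
    deps = [
        # Assumed package path, mirroring base_module above.
        "//executorch/examples/models/llama/config:llm_config",
    ],
)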
