Commit 535d133

Update
[ghstack-poisoned]

2 parents: c318411 + 199ff95

File tree: 3 files changed, +74 -80 lines

.github/workflows/android-perf.yml
Lines changed: 4 additions & 4 deletions

@@ -228,7 +228,7 @@ jobs:
           export.output_name="${OUT_ET_MODEL_NAME}.pte" \
           model.use_kv_cache=true \
           model.dtype_override=fp32 \
-          base.preq_embedding_quantize='8,0' \
+          base.preq_embedding_quantize=\'8,0\' \
           quantization.use_spin_quant=native \
           base.metadata='{"get_bos_id":128000,"get_eos_ids":[128009,128001]}'
           ls -lh "${OUT_ET_MODEL_NAME}.pte"
@@ -249,7 +249,7 @@ jobs:
           base.use_lora=16 \
           base.preq_mode="8da4w_output_8da8w" \
           base.preq_group_size=32 \
-          base.preq_embedding_quantize='8,0' \
+          base.preq_embedding_quantize=\'8,0\' \
           model.use_sdpa_with_kv_cache=true \
           model.use_kv_cache=true \
           backend.xnnpack.enabled=true \
@@ -287,7 +287,7 @@ jobs:
           backend.xnnpack.extended_ops=true \
           quantization.qmode=8da4w \
           quantization.group_size=32 \
-          quantization.embedding_quantize='8,0' \
+          quantization.embedding_quantize=\'8,0\' \
           base.metadata='{"get_bos_id":128000,"get_eos_ids":[128009,128001]}' \
           export.output_name="${OUT_ET_MODEL_NAME}.pte"
           ls -lh "${OUT_ET_MODEL_NAME}.pte"
@@ -325,7 +325,7 @@ jobs:
           backend.xnnpack.extended_ops=true \
           quantization.qmode=8da4w \
           quantization.group_size=32 \
-          quantization.embedding_quantize='8,0' \
+          quantization.embedding_quantize=\'8,0\' \
           base.metadata='{"get_bos_id":151644,"get_eos_ids":[151645]}' \
           export.output_name="${OUT_ET_MODEL_NAME}.pte"
           ls -lh "${OUT_ET_MODEL_NAME}.pte"
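The only substantive change in both workflow files is the quoting of the embedding-quantize value. Without the backslashes, bash strips the single quotes during quote removal and the export command receives the bare token 8,0; with them, the literal quotes survive into argv, presumably so the Hydra-style parser sees '8,0' as a single quoted string. A minimal sketch of the difference, using Python's shlex (which applies POSIX quoting rules, like bash) as a stand-in:

import shlex

# How POSIX word splitting treats the two spellings of the override.
unescaped = r"base.preq_embedding_quantize='8,0'"
escaped = r"base.preq_embedding_quantize=\'8,0\'"

print(shlex.split(unescaped))  # ['base.preq_embedding_quantize=8,0']
print(shlex.split(escaped))    # ["base.preq_embedding_quantize='8,0'"]

The same fix is applied in apple-perf.yml below.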

.github/workflows/apple-perf.yml
Lines changed: 4 additions & 4 deletions

@@ -237,7 +237,7 @@ jobs:
           export.output_name="${OUT_ET_MODEL_NAME}.pte" \
           model.use_kv_cache=true \
           model.dtype_override=fp32 \
-          base.preq_embedding_quantize='8,0' \
+          base.preq_embedding_quantize=\'8,0\' \
           quantization.use_spin_quant=native \
           base.metadata='{"get_bos_id":128000,"get_eos_ids":[128009,128001]}'
           ls -lh "${OUT_ET_MODEL_NAME}.pte"
@@ -258,7 +258,7 @@ jobs:
           base.use_lora=16 \
           base.preq_mode="8da4w_output_8da8w" \
           base.preq_group_size=32 \
-          base.preq_embedding_quantize='8,0' \
+          base.preq_embedding_quantize=\'8,0\' \
           model.use_sdpa_with_kv_cache=true \
           model.use_kv_cache=true \
           backend.xnnpack.enabled=true \
@@ -296,7 +296,7 @@ jobs:
           backend.xnnpack.extended_ops=true \
           quantization.qmode=8da4w \
           quantization.group_size=32 \
-          quantization.embedding_quantize='8,0' \
+          quantization.embedding_quantize=\'8,0\' \
           base.metadata='{"get_bos_id":128000,"get_eos_ids":[128009,128001]}' \
           export.output_name="${OUT_ET_MODEL_NAME}.pte"
           ls -lh "${OUT_ET_MODEL_NAME}.pte"
@@ -330,7 +330,7 @@ jobs:
           backend.xnnpack.extended_ops=true \
           quantization.qmode=8da4w \
           quantization.group_size=32 \
-          quantization.embedding_quantize='8,0' \
+          quantization.embedding_quantize=\'8,0\' \
           base.metadata='{"get_bos_id":151644,"get_eos_ids":[151645]}' \
           export.output_name="${OUT_ET_MODEL_NAME}.pte"
           ls -lh "${OUT_ET_MODEL_NAME}.pte"

extension/llm/export/config/llm_config.py
Lines changed: 66 additions & 72 deletions
@@ -45,11 +45,16 @@ class ModelType(str, Enum):
     smollm2 = "smollm2"
 
 
+class PreqMode(str, Enum):
+    """
+    If you are dealing with pre-quantized checkpoints, this used to
+    be the way to specify them. Now you don't need to specify these
+    options if you use a TorchAo-prequantized checkpoint, but they
+    are still around to preserve backward compatibility.
+    """
 
-PREQ_MODE_OPTIONS = [
-    "8da4w",
-    "8da4w_output_8da8w",
-]
+    preq_8da4w = "8da4w"
+    preq_8da4w_out_8da8w = "8da4w_output_8da8w"
 
 
 @dataclass
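A note on the pattern: because these enums mix in str, members compare equal to their plain-string values, so existing code that matches on "8da4w" keeps working, which is the backward compatibility the docstring promises. A quick self-contained sketch of the standard-library behavior (not repo code):

from enum import Enum

class PreqMode(str, Enum):
    preq_8da4w = "8da4w"
    preq_8da4w_out_8da8w = "8da4w_output_8da8w"

# Lookup by value: CLI strings convert directly to members.
assert PreqMode("8da4w") is PreqMode.preq_8da4w

# The str mixin makes members compare (and serialize) as their values,
# so downstream string comparisons are unaffected by the refactor.
assert PreqMode.preq_8da4w == "8da4w"
assert isinstance(PreqMode.preq_8da4w, str)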
@@ -81,36 +81,34 @@ class BaseConfig:
     are loaded.
     """
 
-    model_class: str = "llama3"
+    model_class: ModelType = ModelType.llama3
     params: Optional[str] = None
     checkpoint: Optional[str] = None
     checkpoint_dir: Optional[str] = None
     tokenizer_path: Optional[str] = None
     metadata: Optional[str] = None
     use_lora: int = 0
     fairseq2: bool = False
-    preq_mode: Optional[str] = None
+    preq_mode: Optional[PreqMode] = None
     preq_group_size: int = 32
     preq_embedding_quantize: str = "8,0"
 
-    def __post_init__(self):
-        if self.model_class not in MODEL_TYPE_OPTIONS:
-            raise ValueError(f"model_class must be one of {MODEL_TYPE_OPTIONS}, got '{self.model_class}'")
-
-        if self.preq_mode is not None and self.preq_mode not in PREQ_MODE_OPTIONS:
-            raise ValueError(f"preq_mode must be one of {PREQ_MODE_OPTIONS}, got '{self.preq_mode}'")
-
 
 ################################################################################
 ################################# ModelConfig ##################################
 ################################################################################
 
 
-DTYPE_OVERRIDE_OPTIONS = [
-    "fp32",
-    "fp16",
-    "bf16",
-]
+class DtypeOverride(str, Enum):
+    """
+    DType of the model. Highly recommended to use "fp32", unless you want to
+    export without a backend, in which case you can also use "bf16". "fp16"
+    is not recommended.
+    """
+
+    fp32 = "fp32"
+    fp16 = "fp16"
+    bf16 = "bf16"
 
 
 @dataclass
@@ -148,7 +151,7 @@ class ModelConfig:
     [16] pattern specifies all layers have a sliding window of 16.
     """
 
-    dtype_override: str = "fp32"
+    dtype_override: DtypeOverride = DtypeOverride.fp32
     enable_dynamic_shape: bool = True
     use_shared_embedding: bool = False
     use_sdpa_with_kv_cache: bool = False
@@ -161,9 +164,6 @@
     local_global_attention: Optional[List[int]] = None
 
     def __post_init__(self):
-        if self.dtype_override not in DTYPE_OVERRIDE_OPTIONS:
-            raise ValueError(f"dtype_override must be one of {DTYPE_OVERRIDE_OPTIONS}, got '{self.dtype_override}'")
-
         self._validate_attention_sink()
         self._validate_local_global_attention()
 
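The deleted __post_init__ checks are not lost: constructing an Enum from an unknown value already raises ValueError, so validation moves from the dataclass to the point of conversion. A self-contained sketch of the equivalent behavior:

from enum import Enum

class DtypeOverride(str, Enum):
    fp32 = "fp32"
    fp16 = "fp16"
    bf16 = "bf16"

assert DtypeOverride("fp32") is DtypeOverride.fp32

try:
    DtypeOverride("int8")  # not a member: rejected at construction time
except ValueError as err:
    print(err)  # 'int8' is not a valid DtypeOverride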
@@ -265,25 +265,31 @@ class DebugConfig:
 ################################################################################
 
 
-PT2E_QUANTIZE_OPTIONS = [
-    "xnnpack_dynamic",
-    "xnnpack_dynamic_qc4",
-    "qnn_8a8w",
-    "qnn_16a16w",
-    "qnn_16a4w",
-    "coreml_c4w",
-    "coreml_8a_c8w",
-    "coreml_8a_c4w",
-    "coreml_baseline_8a_c8w",
-    "coreml_baseline_8a_c4w",
-    "vulkan_8w",
-]
+class Pt2eQuantize(str, Enum):
+    """
+    Type of backend-specific Pt2e quantization strategy to use.
+
+    Pt2e uses a different quantization library that is graph-based
+    compared to `qmode`, which is also specified in the QuantizationConfig
+    and is source transform-based.
+    """
 
+    xnnpack_dynamic = "xnnpack_dynamic"
+    xnnpack_dynamic_qc4 = "xnnpack_dynamic_qc4"
+    qnn_8a8w = "qnn_8a8w"
+    qnn_16a16w = "qnn_16a16w"
+    qnn_16a4w = "qnn_16a4w"
+    coreml_c4w = "coreml_c4w"
+    coreml_8a_c8w = "coreml_8a_c8w"
+    coreml_8a_c4w = "coreml_8a_c4w"
+    coreml_baseline_8a_c8w = "coreml_baseline_8a_c8w"
+    coreml_baseline_8a_c4w = "coreml_baseline_8a_c4w"
+    vulkan_8w = "vulkan_8w"
 
-SPIN_QUANT_OPTIONS = [
-    "cuda",
-    "native",
-]
+
+class SpinQuant(str, Enum):
+    cuda = "cuda"
+    native = "native"
 
 
 @dataclass
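Anywhere the old flat lists are still wanted (help text, error messages), iterating the enum reproduces them, so nothing has to be kept in sync by hand. A small sketch, assuming the Pt2eQuantize and SpinQuant definitions from the hunk above are in scope:

# Recover the old option lists from the enums; these match the deleted
# PT2E_QUANTIZE_OPTIONS and SPIN_QUANT_OPTIONS entry for entry, since
# Enum iteration follows definition order.
pt2e_options = [e.value for e in Pt2eQuantize]
assert pt2e_options[0] == "xnnpack_dynamic"
assert pt2e_options[-1] == "vulkan_8w"

spin_quant_options = [e.value for e in SpinQuant]
assert spin_quant_options == ["cuda", "native"]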
@@ -318,22 +324,16 @@ class QuantizationConfig:
 
     qmode: Optional[str] = None
     embedding_quantize: Optional[str] = None
-    pt2e_quantize: Optional[str] = None
+    pt2e_quantize: Optional[Pt2eQuantize] = None
     group_size: Optional[int] = None
-    use_spin_quant: Optional[str] = None
+    use_spin_quant: Optional[SpinQuant] = None
     use_qat: bool = False
     calibration_tasks: Optional[List[str]] = None
     calibration_limit: Optional[int] = None
     calibration_seq_length: Optional[int] = None
     calibration_data: str = "Once upon a time"
 
     def __post_init__(self):
-        if self.pt2e_quantize is not None and self.pt2e_quantize not in PT2E_QUANTIZE_OPTIONS:
-            raise ValueError(f"pt2e_quantize must be one of {PT2E_QUANTIZE_OPTIONS}, got '{self.pt2e_quantize}'")
-
-        if self.use_spin_quant is not None and self.use_spin_quant not in SPIN_QUANT_OPTIONS:
-            raise ValueError(f"use_spin_quant must be one of {SPIN_QUANT_OPTIONS}, got '{self.use_spin_quant}'")
-
         if self.qmode:
             self._validate_qmode()
@@ -381,18 +381,16 @@ class XNNPackConfig:
     extended_ops: bool = False
 
 
-COREML_QUANTIZE_OPTIONS = [
-    "b4w",
-    "c4w",
-]
+class CoreMLQuantize(str, Enum):
+    b4w = "b4w"
+    c4w = "c4w"
 
 
-COREML_COMPUTE_UNIT_OPTIONS = [
-    "cpu_only",
-    "cpu_and_gpu",
-    "cpu_and_ne",
-    "all",
-]
+class CoreMLComputeUnit(str, Enum):
+    cpu_only = "cpu_only"
+    cpu_and_gpu = "cpu_and_gpu"
+    cpu_and_ne = "cpu_and_ne"
+    all = "all"
 
 
 @dataclass
@@ -404,17 +402,11 @@ class CoreMLConfig:
     enabled: bool = False
     enable_state: bool = False
     preserve_sdpa: bool = False
-    quantize: Optional[str] = None
+    quantize: Optional[CoreMLQuantize] = None
     ios: int = 15
-    compute_units: str = "cpu_only"
+    compute_units: CoreMLComputeUnit = CoreMLComputeUnit.cpu_only
 
     def __post_init__(self):
-        if self.quantize is not None and self.quantize not in COREML_QUANTIZE_OPTIONS:
-            raise ValueError(f"quantize must be one of {COREML_QUANTIZE_OPTIONS}, got '{self.quantize}'")
-
-        if self.compute_units not in COREML_COMPUTE_UNIT_OPTIONS:
-            raise ValueError(f"compute_units must be one of {COREML_COMPUTE_UNIT_OPTIONS}, got '{self.compute_units}'")
-
         if self.ios not in (15, 16, 17, 18):
             raise ValueError(f"Invalid coreml ios version: {self.ios}")
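Enum members are immutable singletons, so they are safe as dataclass defaults, and the str mixin keeps the defaults equal to the old strings. A sketch with a hypothetical stand-in for CoreMLConfig, showing that only the non-enum invariant (the ios version) still needs a manual check:

from dataclasses import dataclass
from enum import Enum

class CoreMLComputeUnit(str, Enum):
    cpu_only = "cpu_only"
    cpu_and_gpu = "cpu_and_gpu"
    cpu_and_ne = "cpu_and_ne"
    all = "all"

@dataclass
class CoreMLConfigSketch:  # hypothetical stand-in, not the repo class
    ios: int = 15
    compute_units: CoreMLComputeUnit = CoreMLComputeUnit.cpu_only

    def __post_init__(self):
        # Callers convert strings via CoreMLComputeUnit(...), as from_args
        # does below; that is where bad values raise. Only the integer
        # version check remains here.
        if self.ios not in (15, 16, 17, 18):
            raise ValueError(f"Invalid coreml ios version: {self.ios}")

cfg = CoreMLConfigSketch(compute_units=CoreMLComputeUnit("cpu_and_ne"))
assert cfg.compute_units == "cpu_and_ne"  # str mixin keeps old comparisons working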
@@ -493,7 +485,7 @@ def from_args(cls, args: argparse.Namespace) -> "LlmConfig": # noqa: C901
 
         # BaseConfig
         if hasattr(args, "model"):
-            llm_config.base.model_class = args.model
+            llm_config.base.model_class = ModelType(args.model)
         if hasattr(args, "params"):
             llm_config.base.params = args.params
         if hasattr(args, "checkpoint"):
@@ -511,15 +503,15 @@ def from_args(cls, args: argparse.Namespace) -> "LlmConfig": # noqa: C901
 
         # PreqMode settings
         if hasattr(args, "preq_mode") and args.preq_mode:
-            llm_config.base.preq_mode = args.preq_mode
+            llm_config.base.preq_mode = PreqMode(args.preq_mode)
         if hasattr(args, "preq_group_size"):
             llm_config.base.preq_group_size = args.preq_group_size
         if hasattr(args, "preq_embedding_quantize"):
             llm_config.base.preq_embedding_quantize = args.preq_embedding_quantize
 
         # ModelConfig
         if hasattr(args, "dtype_override"):
-            llm_config.model.dtype_override = args.dtype_override
+            llm_config.model.dtype_override = DtypeOverride(args.dtype_override)
         if hasattr(args, "enable_dynamic_shape"):
             llm_config.model.enable_dynamic_shape = args.enable_dynamic_shape
         if hasattr(args, "use_shared_embedding"):
@@ -561,11 +553,11 @@ def from_args(cls, args: argparse.Namespace) -> "LlmConfig": # noqa: C901
         if hasattr(args, "embedding_quantize"):
             llm_config.quantization.embedding_quantize = args.embedding_quantize
         if hasattr(args, "pt2e_quantize") and args.pt2e_quantize:
-            llm_config.quantization.pt2e_quantize = args.pt2e_quantize
+            llm_config.quantization.pt2e_quantize = Pt2eQuantize(args.pt2e_quantize)
         if hasattr(args, "group_size"):
             llm_config.quantization.group_size = args.group_size
         if hasattr(args, "use_spin_quant") and args.use_spin_quant:
-            llm_config.quantization.use_spin_quant = args.use_spin_quant
+            llm_config.quantization.use_spin_quant = SpinQuant(args.use_spin_quant)
         if hasattr(args, "use_qat"):
             llm_config.quantization.use_qat = args.use_qat
         if hasattr(args, "calibration_tasks"):
@@ -593,11 +585,13 @@ def from_args(cls, args: argparse.Namespace) -> "LlmConfig": # noqa: C901
                 args, "coreml_preserve_sdpa", False
             )
         if hasattr(args, "coreml_quantize") and args.coreml_quantize:
-            llm_config.backend.coreml.quantize = args.coreml_quantize
+            llm_config.backend.coreml.quantize = CoreMLQuantize(args.coreml_quantize)
         if hasattr(args, "coreml_ios"):
             llm_config.backend.coreml.ios = args.coreml_ios
         if hasattr(args, "coreml_compute_units"):
-            llm_config.backend.coreml.compute_units = args.coreml_compute_units
+            llm_config.backend.coreml.compute_units = CoreMLComputeUnit(
+                args.coreml_compute_units
+            )
 
         # Vulkan
         if hasattr(args, "vulkan"):