diff --git a/examples/models/llama/config/llm_config.py b/examples/models/llama/config/llm_config.py
index 201e3a5414a..9acd633fb21 100644
--- a/examples/models/llama/config/llm_config.py
+++ b/examples/models/llama/config/llm_config.py
@@ -26,19 +26,19 @@ class ModelType(str, Enum):
-    STORIES110M = "stories110m"
-    LLAMA2 = "llama2"
-    LLAMA3 = "llama3"
-    LLAMA3_1 = "llama3_1"
-    LLAMA3_2 = "llama3_2"
-    LLAMA3_2_VISION = "llama3_2_vision"
-    STATIC_LLAMA = "static_llama"
-    QWEN2_5 = "qwen2_5"
-    QWEN3_0_6B = "qwen3-0_6b"
-    QWEN3_1_7B = "qwen3-1_7b"
-    QWEN3_4B = "qwen3-4b"
-    PHI_4_MINI = "phi_4_mini"
-    SMOLLM2 = "smollm2"
+    stories110m = "stories110m"
+    llama2 = "llama2"
+    llama3 = "llama3"
+    llama3_1 = "llama3_1"
+    llama3_2 = "llama3_2"
+    llama3_2_vision = "llama3_2_vision"
+    static_llama = "static_llama"
+    qwen2_5 = "qwen2_5"
+    qwen3_0_6b = "qwen3-0_6b"
+    qwen3_1_7b = "qwen3-1_7b"
+    qwen3_4b = "qwen3-4b"
+    phi_4_mini = "phi_4_mini"
+    smollm2 = "smollm2"
 
 
 class PreqMode(str, Enum):
@@ -49,8 +49,8 @@ class PreqMode(str, Enum):
     are still around to preserve backward compatibility.
     """
 
-    PREQ_8DA4W = "8da4w"
-    PREQ_8DA4W_OUT_8DA8W = "8da4w_output_8da8w"
+    preq_8da4w = "8da4w"
+    preq_8da4w_out_8da8w = "8da4w_output_8da8w"
 
 
 @dataclass
@@ -82,7 +82,7 @@ class BaseConfig:
     are loaded.
     """
 
-    model_class: ModelType = ModelType.LLAMA3
+    model_class: ModelType = ModelType.llama3
     params: Optional[str] = None
     checkpoint: Optional[str] = None
    checkpoint_dir: Optional[str] = None
@@ -107,9 +107,9 @@ class DtypeOverride(str, Enum):
     is not recommended.
     """
 
-    FP32 = "fp32"
-    FP16 = "fp16"
-    BF16 = "bf16"
+    fp32 = "fp32"
+    fp16 = "fp16"
+    bf16 = "bf16"
 
 
 @dataclass
@@ -147,7 +147,7 @@ class ModelConfig:
     [16] pattern specifies all layers have a sliding window of 16.
     """
 
-    dtype_override: DtypeOverride = DtypeOverride.FP32
+    dtype_override: DtypeOverride = DtypeOverride.fp32
     enable_dynamic_shape: bool = True
     use_shared_embedding: bool = False
     use_sdpa_with_kv_cache: bool = False
@@ -270,22 +270,22 @@ class Pt2eQuantize(str, Enum):
     and is source transform-based.
     """
 
-    XNNPACK_DYNAMIC = "xnnpack_dynamic"
-    XNNPACK_DYNAMIC_QC4 = "xnnpack_dynamic_qc4"
-    QNN_8A8W = "qnn_8a8w"
-    QNN_16A16W = "qnn_16a16w"
-    QNN_16A4W = "qnn_16a4w"
-    COREML_C4W = "coreml_c4w"
-    COREML_8A_C8W = "coreml_8a_c8w"
-    COREML_8A_C4W = "coreml_8a_c4w"
-    COREML_BASELINE_8A_C8W = "coreml_baseline_8a_c8w"
-    COREML_BASELINE_8A_C4W = "coreml_baseline_8a_c4w"
-    VULKAN_8W = "vulkan_8w"
+    xnnpack_dynamic = "xnnpack_dynamic"
+    xnnpack_dynamic_qc4 = "xnnpack_dynamic_qc4"
+    qnn_8a8w = "qnn_8a8w"
+    qnn_16a16w = "qnn_16a16w"
+    qnn_16a4w = "qnn_16a4w"
+    coreml_c4w = "coreml_c4w"
+    coreml_8a_c8w = "coreml_8a_c8w"
+    coreml_8a_c4w = "coreml_8a_c4w"
+    coreml_baseline_8a_c8w = "coreml_baseline_8a_c8w"
+    coreml_baseline_8a_c4w = "coreml_baseline_8a_c4w"
+    vulkan_8w = "vulkan_8w"
 
 
 class SpinQuant(str, Enum):
-    CUDA = "cuda"
-    NATIVE = "native"
+    cuda = "cuda"
+    native = "native"
 
 
 @dataclass
@@ -378,15 +378,15 @@ class XNNPackConfig:
 
 
 class CoreMLQuantize(str, Enum):
-    B4W = "b4w"
-    C4W = "c4w"
+    b4w = "b4w"
+    c4w = "c4w"
 
 
 class CoreMLComputeUnit(str, Enum):
-    CPU_ONLY = "cpu_only"
-    CPU_AND_GPU = "cpu_and_gpu"
-    CPU_AND_NE = "cpu_and_ne"
-    ALL = "all"
+    cpu_only = "cpu_only"
+    cpu_and_gpu = "cpu_and_gpu"
+    cpu_and_ne = "cpu_and_ne"
+    all = "all"
 
 
 @dataclass
@@ -400,7 +400,7 @@ class CoreMLConfig:
     preserve_sdpa: bool = False
     quantize: Optional[CoreMLQuantize] = None
     ios: int = 15
-    compute_units: CoreMLComputeUnit = CoreMLComputeUnit.CPU_ONLY
+    compute_units: CoreMLComputeUnit = CoreMLComputeUnit.cpu_only
 
     def __post_init__(self):
         if self.ios not in (15, 16, 17, 18):
diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py
index 88b79d30eb2..334f3ace712 100644
--- a/examples/models/llama/export_llama_lib.py
+++ b/examples/models/llama/export_llama_lib.py
@@ -590,7 +590,7 @@ def export_llama(
 
     # If a checkpoint isn't provided for an HF OSS model, download and convert the
     # weights first.
-    model_name = llm_config.base.model_class
+    model_name = llm_config.base.model_class.value
     if not llm_config.base.checkpoint and model_name in HUGGING_FACE_REPO_IDS:
         repo_id = HUGGING_FACE_REPO_IDS[model_name]
         if model_name == "qwen2_5":
@@ -668,7 +668,7 @@ def _prepare_for_llama_export(llm_config: LlmConfig) -> LLMEdgeManager:
         llm_config.export.output_dir = output_dir_path
 
     # Convert dtype override string to actual type.
-    dtype_override = DType[llm_config.model.dtype_override]
+    dtype_override = DType[llm_config.model.dtype_override.value]
 
     edge_manager = _load_llama_model(llm_config)
 
@@ -702,7 +702,11 @@ def _prepare_for_llama_export(llm_config: LlmConfig) -> LLMEdgeManager:
         checkpoint=llm_config.base.checkpoint,
         checkpoint_dtype=DType.from_torch_dtype(checkpoint_dtype),  # type: ignore
         tokenizer_path=llm_config.base.tokenizer_path,
-        use_spin_quant=llm_config.quantization.use_spin_quant,
+        use_spin_quant=(
+            llm_config.quantization.use_spin_quant.value
+            if llm_config.quantization.use_spin_quant
+            else None
+        ),
         embedding_quantize=llm_config.quantization.embedding_quantize,
         use_shared_embedding=llm_config.model.use_shared_embedding,
         quantization_mode=llm_config.quantization.qmode,
@@ -726,7 +730,9 @@ def _prepare_for_llama_export(llm_config: LlmConfig) -> LLMEdgeManager:
         vulkan=llm_config.backend.vulkan.enabled,
         use_qat=llm_config.quantization.use_qat,
         use_lora=llm_config.base.use_lora,
-        preq_mode=llm_config.base.preq_mode,
+        preq_mode=(
+            llm_config.base.preq_mode.value if llm_config.base.preq_mode else None
+        ),
         preq_group_size=llm_config.base.preq_group_size,
         preq_embedding_quantize=llm_config.base.preq_embedding_quantize,
         local_global_attention=llm_config.model.local_global_attention,
@@ -738,25 +744,34 @@ def get_quantizer_and_quant_params(llm_config):
     pt2e_quant_params = get_pt2e_quantization_params(
-        llm_config.quantization.pt2e_quantize, llm_config.quantization.qmode
+        (
+            llm_config.quantization.pt2e_quantize.value
+            if llm_config.quantization.pt2e_quantize
+            else None
+        ),
+        llm_config.quantization.qmode,
     )
     quantizers = get_pt2e_quantizers(pt2e_quant_params, llm_config.export.so_library)
     quant_dtype = None
     if llm_config.backend.qnn.enabled and llm_config.quantization.pt2e_quantize:
         assert len(quantizers) == 0, "Should not enable both xnnpack and qnn"
         qnn_quantizer, quant_dtype = get_qnn_quantizer(
-            llm_config.quantization.pt2e_quantize, llm_config.quantization.qmode
+            llm_config.quantization.pt2e_quantize.value, llm_config.quantization.qmode
         )
         quantizers.append(qnn_quantizer)
     if llm_config.backend.coreml.enabled and llm_config.quantization.pt2e_quantize:
         assert len(quantizers) == 0, "Should not enable both xnnpack / qnn and coreml"
-        coreml_quantizer = get_coreml_quantizer(llm_config.quantization.pt2e_quantize)
+        coreml_quantizer = get_coreml_quantizer(
+            llm_config.quantization.pt2e_quantize.value
+        )
         quantizers.append(coreml_quantizer)
     if llm_config.backend.vulkan.enabled and llm_config.quantization.pt2e_quantize:
         assert (
             len(quantizers) == 0
         ), "Should not enable both vulkan and other quantizers"
-        vulkan_quantizer = get_vulkan_quantizer(llm_config.quantization.pt2e_quantize)
+        vulkan_quantizer = get_vulkan_quantizer(
+            llm_config.quantization.pt2e_quantize.value
+        )
         quantizers.append(vulkan_quantizer)
     logging.info(f"Applying quantizers: {quantizers}")
     return pt2e_quant_params, quantizers, quant_dtype
@@ -1035,7 +1050,7 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager:  # noqa: C901
     )
 
     additional_passes = []
-    if llm_config.base.model_class in TORCHTUNE_DEFINED_MODELS:
+    if llm_config.base.model_class.value in TORCHTUNE_DEFINED_MODELS:
         additional_passes = [InitializedMutableBufferPass(["kv_cache_pos"])]
 
     # export_to_edge
@@ -1074,14 +1089,22 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager:  # noqa: C901
         mps=llm_config.backend.mps.enabled,
         coreml=llm_config.backend.coreml.enabled,
         qnn=llm_config.backend.qnn.enabled,
-        dtype_override=llm_config.model.dtype_override,
+        dtype_override=llm_config.model.dtype_override.value,
         enable_dynamic_shape=llm_config.model.enable_dynamic_shape,
         use_kv_cache=llm_config.model.use_kv_cache,
         embedding_quantize=llm_config.quantization.embedding_quantize,
-        pt2e_quantize=llm_config.quantization.pt2e_quantize,
+        pt2e_quantize=(
+            llm_config.quantization.pt2e_quantize.value
+            if llm_config.quantization.pt2e_quantize
+            else None
+        ),
         coreml_ios=llm_config.backend.coreml.ios,
-        coreml_quantize=llm_config.backend.coreml.quantize,
-        coreml_compute_units=llm_config.backend.coreml.compute_units,
+        coreml_quantize=(
+            llm_config.backend.coreml.quantize.value
+            if llm_config.backend.coreml.quantize
+            else None
+        ),
+        coreml_compute_units=llm_config.backend.coreml.compute_units.value,
         use_qnn_sha=llm_config.backend.qnn.use_sha,
         num_sharding=llm_config.backend.qnn.num_sharding,
         soc_model=llm_config.backend.qnn.soc_model,
@@ -1154,7 +1177,7 @@ def _load_llama_model(llm_config: LlmConfig) -> "LLMEdgeManager":
         An instance of LLMEdgeManager which contains the eager mode model.
     """
 
-    modelname = llm_config.base.model_class
+    modelname = llm_config.base.model_class.value
     if modelname in EXECUTORCH_DEFINED_MODELS:
         module_name = "llama"
         model_class_name = "Llama2Model"  # TODO: Change to "LlamaModel" in examples/models/llama/model.py.
@@ -1175,7 +1198,7 @@ def _load_llama_model(llm_config: LlmConfig) -> "LLMEdgeManager":
         )
     )
     # Convert dtype override string to actual type.
-    dtype_override = DType[llm_config.model.dtype_override]
+    dtype_override = DType[llm_config.model.dtype_override.value]
 
     return LLMEdgeManager(
         model=model,
diff --git a/examples/models/llama/model.py b/examples/models/llama/model.py
index ec9646be6f4..efea80dde2f 100644
--- a/examples/models/llama/model.py
+++ b/examples/models/llama/model.py
@@ -157,7 +157,7 @@ def __init__(self, llm_config: Optional[LlmConfig] = None):
 
         if model_args.use_scaled_rope:
             # Older models don't have use_scaled_rope configuration
-            model_name = str(self.llm_config.base.model_class)
+            model_name = self.llm_config.base.model_class.value
             assert model_name not in ["llama2", "stories110m"]
 
             # Llama3_2 and newer models in ExecuTorch repo should set larger scale factor
@@ -328,10 +328,10 @@ def get_example_inputs_kvcache_sdpa(self):
 
     def _transform_for_pre_quantization(self, checkpoint, model_args):
         assert self.llm_config.base.preq_mode, "preq_mode must be specified"
-        assert self.llm_config.base.preq_mode in [
+        assert self.llm_config.base.preq_mode.value in [
             "8da4w",
             "8da4w_output_8da8w",
-        ], f"Quantization mode {self.llm_config.base.preq_mode} is not compatible with SpinQuant."
+        ], f"Quantization mode {self.llm_config.base.preq_mode.value} is not compatible with SpinQuant."
         assert self.llm_config.base.preq_group_size, "preq_group_size must be specified"
         assert self.llm_config.model.dtype_override, "dtype_override must be specified"
 
@@ -351,7 +351,7 @@ def _transform_for_pre_quantization(self, checkpoint, model_args):
         }
 
         # Transform the output layer first if needed.
-        if self.llm_config.base.preq_mode == "8da4w_output_8da8w":
+        if self.llm_config.base.preq_mode.value == "8da4w_output_8da8w":
             from .source_transformation.pre_quantization import (
                 transform_output_linear_for_pre_quantization,
             )
@@ -359,14 +359,14 @@ def _transform_for_pre_quantization(self, checkpoint, model_args):
             self.model_ = transform_output_linear_for_pre_quantization(
                 module=self.model_,
                 checkpoint=checkpoint,
-                dtype=mapping[self.llm_config.model.dtype_override],
+                dtype=mapping[self.llm_config.model.dtype_override.value],
             )
 
         self.model_ = transform_linear_for_pre_quantization(
             self.model_,
             checkpoint,
             self.llm_config.base.preq_group_size,
-            mapping[self.llm_config.model.dtype_override],
+            mapping[self.llm_config.model.dtype_override.value],
         )
 
         embedding_bit_width, embedding_group_size = None, None
@@ -390,7 +390,7 @@ def _transform_for_pre_quantization(self, checkpoint, model_args):
             self.model_ = transform_embedding_for_pre_quantization(
                 self.model_,
                 checkpoint,
-                mapping[self.llm_config.model.dtype_override],
+                mapping[self.llm_config.model.dtype_override.value],
                 int(embedding_bit_width),
                 embedding_group_size,
             )
diff --git a/extension/llm/export/test/test_export_llm.py b/extension/llm/export/test/test_export_llm.py
index 1f230233867..7d17b7819d3 100644
--- a/extension/llm/export/test/test_export_llm.py
+++ b/extension/llm/export/test/test_export_llm.py
@@ -51,9 +51,20 @@ def test_with_config(self, mock_export_llama: MagicMock) -> None:
             f.write(
                 """
 base:
+  model_class: llama2
   tokenizer_path: /path/to/tokenizer.json
+  preq_mode: preq_8da4w
+model:
+  dtype_override: fp16
 export:
   max_seq_length: 256
+quantization:
+  pt2e_quantize: xnnpack_dynamic
+  use_spin_quant: cuda
+backend:
+  coreml:
+    quantize: c4w
+    compute_units: cpu_and_gpu
 """
             )
             config_file = f.name
@@ -69,7 +80,22 @@ def test_with_config(self, mock_export_llama: MagicMock) -> None:
             self.assertEqual(
                 called_config["base"]["tokenizer_path"], "/path/to/tokenizer.json"
             )
+            self.assertEqual(called_config["base"]["model_class"], "llama2")
+            self.assertEqual(called_config["base"]["preq_mode"].value, "8da4w")
+            self.assertEqual(called_config["model"]["dtype_override"].value, "fp16")
             self.assertEqual(called_config["export"]["max_seq_length"], 256)
+            self.assertEqual(
+                called_config["quantization"]["pt2e_quantize"].value, "xnnpack_dynamic"
+            )
+            self.assertEqual(
+                called_config["quantization"]["use_spin_quant"].value, "cuda"
+            )
+            self.assertEqual(
+                called_config["backend"]["coreml"]["quantize"].value, "c4w"
+            )
+            self.assertEqual(
+                called_config["backend"]["coreml"]["compute_units"].value, "cpu_and_gpu"
+            )
         finally:
             os.unlink(config_file)
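Note (not part of the patch): the config enums above mix in `str`, so each member carries its value, but the patch still passes `member.value` wherever a bare string is expected, for example in `Enum[...]` name lookups such as `DType[...]`, in membership checks against string-keyed tables like `HUGGING_FACE_REPO_IDS`, and in f-strings, whose rendering of mixed-in enum members varies across Python versions. The sketch below is illustrative only; the `DType` class in it is a hypothetical stand-in for the real ExecuTorch enum of that name.

```python
# Illustrative sketch only -- not part of the patch.
from enum import Enum


class DtypeOverride(str, Enum):
    fp32 = "fp32"
    fp16 = "fp16"
    bf16 = "bf16"


class DType(Enum):  # hypothetical stand-in for ExecuTorch's DType enum
    fp32 = "fp32"
    fp16 = "fp16"
    bf16 = "bf16"


member = DtypeOverride.fp16
assert isinstance(member, str)            # str mixin: the member is a str subclass
assert member.value == "fp16"             # .value is the plain string
assert DType[member.value] is DType.fp16  # Enum[...] looks up members by *name*,
                                          # which is why the plain string is passed
```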