@@ -590,7 +590,7 @@ def export_llama(
 
     # If a checkpoint isn't provided for an HF OSS model, download and convert the
     # weights first.
-    model_name = llm_config.base.model_class
+    model_name = llm_config.base.model_class.value
     if not llm_config.base.checkpoint and model_name in HUGGING_FACE_REPO_IDS:
         repo_id = HUGGING_FACE_REPO_IDS[model_name]
         if model_name == "qwen2_5":
@@ -668,7 +668,7 @@ def _prepare_for_llama_export(llm_config: LlmConfig) -> LLMEdgeManager:
     llm_config.export.output_dir = output_dir_path
 
     # Convert dtype override string to actual type.
-    dtype_override = DType[llm_config.model.dtype_override]
+    dtype_override = DType[llm_config.model.dtype_override.value]
 
     edge_manager = _load_llama_model(llm_config)
 
@@ -702,7 +702,11 @@ def _prepare_for_llama_export(llm_config: LlmConfig) -> LLMEdgeManager:
         checkpoint=llm_config.base.checkpoint,
         checkpoint_dtype=DType.from_torch_dtype(checkpoint_dtype),  # type: ignore
         tokenizer_path=llm_config.base.tokenizer_path,
-        use_spin_quant=llm_config.quantization.use_spin_quant,
+        use_spin_quant=(
+            llm_config.quantization.use_spin_quant.value
+            if llm_config.quantization.use_spin_quant
+            else None
+        ),
         embedding_quantize=llm_config.quantization.embedding_quantize,
         use_shared_embedding=llm_config.model.use_shared_embedding,
         quantization_mode=llm_config.quantization.qmode,
@@ -726,7 +730,9 @@ def _prepare_for_llama_export(llm_config: LlmConfig) -> LLMEdgeManager:
         vulkan=llm_config.backend.vulkan.enabled,
         use_qat=llm_config.quantization.use_qat,
         use_lora=llm_config.base.use_lora,
-        preq_mode=llm_config.base.preq_mode,
+        preq_mode=(
+            llm_config.base.preq_mode.value if llm_config.base.preq_mode else None
+        ),
         preq_group_size=llm_config.base.preq_group_size,
         preq_embedding_quantize=llm_config.base.preq_embedding_quantize,
         local_global_attention=llm_config.model.local_global_attention,
@@ -738,25 +744,34 @@ def _prepare_for_llama_export(llm_config: LlmConfig) -> LLMEdgeManager:
 
 def get_quantizer_and_quant_params(llm_config):
     pt2e_quant_params = get_pt2e_quantization_params(
-        llm_config.quantization.pt2e_quantize, llm_config.quantization.qmode
+        (
+            llm_config.quantization.pt2e_quantize.value
+            if llm_config.quantization.pt2e_quantize
+            else None
+        ),
+        llm_config.quantization.qmode,
     )
     quantizers = get_pt2e_quantizers(pt2e_quant_params, llm_config.export.so_library)
     quant_dtype = None
     if llm_config.backend.qnn.enabled and llm_config.quantization.pt2e_quantize:
         assert len(quantizers) == 0, "Should not enable both xnnpack and qnn"
         qnn_quantizer, quant_dtype = get_qnn_quantizer(
-            llm_config.quantization.pt2e_quantize, llm_config.quantization.qmode
+            llm_config.quantization.pt2e_quantize.value, llm_config.quantization.qmode
         )
         quantizers.append(qnn_quantizer)
     if llm_config.backend.coreml.enabled and llm_config.quantization.pt2e_quantize:
         assert len(quantizers) == 0, "Should not enable both xnnpack / qnn and coreml"
-        coreml_quantizer = get_coreml_quantizer(llm_config.quantization.pt2e_quantize)
+        coreml_quantizer = get_coreml_quantizer(
+            llm_config.quantization.pt2e_quantize.value
+        )
         quantizers.append(coreml_quantizer)
     if llm_config.backend.vulkan.enabled and llm_config.quantization.pt2e_quantize:
         assert (
             len(quantizers) == 0
         ), "Should not enable both vulkan and other quantizers"
-        vulkan_quantizer = get_vulkan_quantizer(llm_config.quantization.pt2e_quantize)
+        vulkan_quantizer = get_vulkan_quantizer(
+            llm_config.quantization.pt2e_quantize.value
+        )
         quantizers.append(vulkan_quantizer)
     logging.info(f"Applying quantizers: {quantizers}")
     return pt2e_quant_params, quantizers, quant_dtype
@@ -1035,7 +1050,7 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager: # noqa: C901
     )
 
     additional_passes = []
-    if llm_config.base.model_class in TORCHTUNE_DEFINED_MODELS:
+    if llm_config.base.model_class.value in TORCHTUNE_DEFINED_MODELS:
         additional_passes = [InitializedMutableBufferPass(["kv_cache_pos"])]
 
     # export_to_edge
@@ -1074,14 +1089,22 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager: # noqa: C901
         mps=llm_config.backend.mps.enabled,
         coreml=llm_config.backend.coreml.enabled,
         qnn=llm_config.backend.qnn.enabled,
-        dtype_override=llm_config.model.dtype_override,
+        dtype_override=llm_config.model.dtype_override.value,
         enable_dynamic_shape=llm_config.model.enable_dynamic_shape,
         use_kv_cache=llm_config.model.use_kv_cache,
         embedding_quantize=llm_config.quantization.embedding_quantize,
-        pt2e_quantize=llm_config.quantization.pt2e_quantize,
+        pt2e_quantize=(
+            llm_config.quantization.pt2e_quantize.value
+            if llm_config.quantization.pt2e_quantize
+            else None
+        ),
         coreml_ios=llm_config.backend.coreml.ios,
-        coreml_quantize=llm_config.backend.coreml.quantize,
-        coreml_compute_units=llm_config.backend.coreml.compute_units,
+        coreml_quantize=(
+            llm_config.backend.coreml.quantize.value
+            if llm_config.backend.coreml.quantize
+            else None
+        ),
+        coreml_compute_units=llm_config.backend.coreml.compute_units.value,
         use_qnn_sha=llm_config.backend.qnn.use_sha,
         num_sharding=llm_config.backend.qnn.num_sharding,
         soc_model=llm_config.backend.qnn.soc_model,
@@ -1154,7 +1177,7 @@ def _load_llama_model(llm_config: LlmConfig) -> "LLMEdgeManager":
         An instance of LLMEdgeManager which contains the eager mode model.
     """
 
-    modelname = llm_config.base.model_class
+    modelname = llm_config.base.model_class.value
     if modelname in EXECUTORCH_DEFINED_MODELS:
         module_name = "llama"
         model_class_name = "Llama2Model"  # TODO: Change to "LlamaModel" in examples/models/llama/model.py.
@@ -1175,7 +1198,7 @@ def _load_llama_model(llm_config: LlmConfig) -> "LLMEdgeManager":
         )
     )
     # Convert dtype override string to actual type.
-    dtype_override = DType[llm_config.model.dtype_override]
+    dtype_override = DType[llm_config.model.dtype_override.value]
 
     return LLMEdgeManager(
         model=model,
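
Taken together, the hunks apply one mechanical pattern: `LlmConfig` fields that are now enums get unwrapped to their string `.value` at the boundary where downstream helpers still expect plain strings, and `Optional` enum fields are guarded so that `None` passes through unchanged. A minimal sketch of that pattern follows; the enum definitions are illustrative stand-ins (only `DType` appears in this diff, and `unwrap` is a hypothetical helper, not part of the codebase):

```python
from enum import Enum
from typing import Optional


class Pt2eQuantize(str, Enum):  # hypothetical stand-in for the config enum
    XNNPACK_DYNAMIC = "xnnpack_dynamic"
    QNN_8A8W = "qnn_8a8w"


class DType(Enum):  # simplified; member *names* match the dtype strings
    fp32 = "fp32"
    fp16 = "fp16"


def unwrap(field: Optional[Enum]) -> Optional[str]:
    # The guard used throughout the diff: .value if set, else None.
    return field.value if field is not None else None


# Optional enum fields: None must stay None for the downstream
# truthiness checks (e.g. `if llm_config.quantization.pt2e_quantize`).
assert unwrap(None) is None
assert unwrap(Pt2eQuantize.QNN_8A8W) == "qnn_8a8w"

# Required enum fields are unwrapped directly. DType[...] indexes by
# member *name*, so the string .value must be passed rather than the
# enum object itself, which would raise a KeyError.
dtype_override = "fp16"
assert DType[dtype_override] is DType.fp16
```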