@@ -590,7 +590,7 @@ def export_llama(
 
     # If a checkpoint isn't provided for an HF OSS model, download and convert the
     # weights first.
-    model_name = llm_config.base.model_class
+    model_name = llm_config.base.model_class.value
     if not llm_config.base.checkpoint and model_name in HUGGING_FACE_REPO_IDS:
         repo_id = HUGGING_FACE_REPO_IDS[model_name]
         if model_name == "qwen2_5":
@@ -668,7 +668,7 @@ def _prepare_for_llama_export(llm_config: LlmConfig) -> LLMEdgeManager:
     llm_config.export.output_dir = output_dir_path
 
     # Convert dtype override string to actual type.
-    dtype_override = DType[llm_config.model.dtype_override]
+    dtype_override = DType[llm_config.model.dtype_override.value]
 
     edge_manager = _load_llama_model(llm_config)
 
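The `.value` unwrap matters here because `Enum` subscripting looks up members by *name*: `DType["fp32"]` resolves, while indexing with another enum member raises a `KeyError`. A minimal sketch of that lookup, using stand-in enums rather than ExecuTorch's real `DType` and dtype-override types:

```python
from enum import Enum

# Stand-ins for illustration only; the real DType enum in ExecuTorch has
# more members, and the config-side dtype enum may be named differently.
class DType(Enum):
    fp32 = "fp32"
    fp16 = "fp16"

class DtypeOverride(Enum):
    FP32 = "fp32"

override = DtypeOverride.FP32

# Enum[...] indexes by member *name*, so the raw string must be pulled out
# of the config enum first; DType[override] would raise a KeyError.
dtype = DType[override.value]
assert dtype is DType.fp32
```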
@@ -702,7 +702,11 @@ def _prepare_for_llama_export(llm_config: LlmConfig) -> LLMEdgeManager:
             checkpoint=llm_config.base.checkpoint,
             checkpoint_dtype=DType.from_torch_dtype(checkpoint_dtype),  # type: ignore
             tokenizer_path=llm_config.base.tokenizer_path,
-            use_spin_quant=llm_config.quantization.use_spin_quant,
+            use_spin_quant=(
+                llm_config.quantization.use_spin_quant.value
+                if llm_config.quantization.use_spin_quant
+                else None
+            ),
             embedding_quantize=llm_config.quantization.embedding_quantize,
             use_shared_embedding=llm_config.model.use_shared_embedding,
             quantization_mode=llm_config.quantization.qmode,
@@ -726,7 +730,9 @@ def _prepare_for_llama_export(llm_config: LlmConfig) -> LLMEdgeManager:
             vulkan=llm_config.backend.vulkan.enabled,
             use_qat=llm_config.quantization.use_qat,
             use_lora=llm_config.base.use_lora,
-            preq_mode=llm_config.base.preq_mode,
+            preq_mode=(
+                llm_config.base.preq_mode.value if llm_config.base.preq_mode else None
+            ),
             preq_group_size=llm_config.base.preq_group_size,
             preq_embedding_quantize=llm_config.base.preq_embedding_quantize,
             local_global_attention=llm_config.model.local_global_attention,
@@ -738,25 +744,34 @@ def _prepare_for_llama_export(llm_config: LlmConfig) -> LLMEdgeManager:
 
 def get_quantizer_and_quant_params(llm_config):
     pt2e_quant_params = get_pt2e_quantization_params(
-        llm_config.quantization.pt2e_quantize, llm_config.quantization.qmode
+        (
+            llm_config.quantization.pt2e_quantize.value
+            if llm_config.quantization.pt2e_quantize
+            else None
+        ),
+        llm_config.quantization.qmode,
     )
     quantizers = get_pt2e_quantizers(pt2e_quant_params, llm_config.export.so_library)
     quant_dtype = None
     if llm_config.backend.qnn.enabled and llm_config.quantization.pt2e_quantize:
         assert len(quantizers) == 0, "Should not enable both xnnpack and qnn"
         qnn_quantizer, quant_dtype = get_qnn_quantizer(
-            llm_config.quantization.pt2e_quantize, llm_config.quantization.qmode
+            llm_config.quantization.pt2e_quantize.value, llm_config.quantization.qmode
         )
         quantizers.append(qnn_quantizer)
     if llm_config.backend.coreml.enabled and llm_config.quantization.pt2e_quantize:
         assert len(quantizers) == 0, "Should not enable both xnnpack / qnn and coreml"
-        coreml_quantizer = get_coreml_quantizer(llm_config.quantization.pt2e_quantize)
+        coreml_quantizer = get_coreml_quantizer(
+            llm_config.quantization.pt2e_quantize.value
+        )
         quantizers.append(coreml_quantizer)
     if llm_config.backend.vulkan.enabled and llm_config.quantization.pt2e_quantize:
         assert (
             len(quantizers) == 0
         ), "Should not enable both vulkan and other quantizers"
-        vulkan_quantizer = get_vulkan_quantizer(llm_config.quantization.pt2e_quantize)
+        vulkan_quantizer = get_vulkan_quantizer(
+            llm_config.quantization.pt2e_quantize.value
+        )
         quantizers.append(vulkan_quantizer)
     logging.info(f"Applying quantizers: {quantizers}")
     return pt2e_quant_params, quantizers, quant_dtype
@@ -1033,7 +1048,7 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager:  # noqa: C901
         )
 
     additional_passes = []
-    if llm_config.base.model_class in TORCHTUNE_DEFINED_MODELS:
+    if llm_config.base.model_class.value in TORCHTUNE_DEFINED_MODELS:
         additional_passes = [InitializedMutableBufferPass(["kv_cache_pos"])]
 
     # export_to_edge
@@ -1072,14 +1087,22 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager:  # noqa: C901
             mps=llm_config.backend.mps.enabled,
             coreml=llm_config.backend.coreml.enabled,
             qnn=llm_config.backend.qnn.enabled,
-            dtype_override=llm_config.model.dtype_override,
+            dtype_override=llm_config.model.dtype_override.value,
             enable_dynamic_shape=llm_config.model.enable_dynamic_shape,
             use_kv_cache=llm_config.model.use_kv_cache,
             embedding_quantize=llm_config.quantization.embedding_quantize,
-            pt2e_quantize=llm_config.quantization.pt2e_quantize,
+            pt2e_quantize=(
+                llm_config.quantization.pt2e_quantize.value
+                if llm_config.quantization.pt2e_quantize
+                else None
+            ),
             coreml_ios=llm_config.backend.coreml.ios,
-            coreml_quantize=llm_config.backend.coreml.quantize,
-            coreml_compute_units=llm_config.backend.coreml.compute_units,
+            coreml_quantize=(
+                llm_config.backend.coreml.quantize.value
+                if llm_config.backend.coreml.quantize
+                else None
+            ),
+            coreml_compute_units=llm_config.backend.coreml.compute_units.value,
             use_qnn_sha=llm_config.backend.qnn.use_sha,
             num_sharding=llm_config.backend.qnn.num_sharding,
             soc_model=llm_config.backend.qnn.soc_model,
@@ -1152,7 +1175,7 @@ def _load_llama_model(llm_config: LlmConfig) -> "LLMEdgeManager":
         An instance of LLMEdgeManager which contains the eager mode model.
     """
 
-    modelname = llm_config.base.model_class
+    modelname = llm_config.base.model_class.value
     if modelname in EXECUTORCH_DEFINED_MODELS:
         module_name = "llama"
         model_class_name = "Llama2Model"  # TODO: Change to "LlamaModel" in examples/models/llama/model.py.
@@ -1173,7 +1196,7 @@ def _load_llama_model(llm_config: LlmConfig) -> "LLMEdgeManager":
         )
     )
     # Convert dtype override string to actual type.
-    dtype_override = DType[llm_config.model.dtype_override]
+    dtype_override = DType[llm_config.model.dtype_override.value]
 
     return LLMEdgeManager(
         model=model,
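The pattern repeated throughout this commit: string-valued `LlmConfig` fields have become enums, so call sites that feed raw strings (or `None` for unset optionals) into downstream helpers now unwrap with `.value`, guarding optional fields with a conditional. A runnable sketch of the guarded unwrap, using hypothetical enum and config names rather than the real `LlmConfig` classes:

```python
from dataclasses import dataclass
from enum import Enum
from typing import Optional

# Hypothetical stand-ins for the enum-backed LlmConfig fields; the actual
# class names in ExecuTorch's LlmConfig definition may differ.
class Pt2eQuantize(Enum):
    XNNPACK_DYNAMIC = "xnnpack_dynamic"
    QNN_8A8W = "qnn_8a8w"

@dataclass
class QuantizationConfig:
    pt2e_quantize: Optional[Pt2eQuantize] = None

config = QuantizationConfig(pt2e_quantize=Pt2eQuantize.XNNPACK_DYNAMIC)

# Optional enum fields can be None, so the unwrap is guarded; passing the
# enum member itself would break downstream string comparisons.
pt2e_quantize = config.pt2e_quantize.value if config.pt2e_quantize else None
assert pt2e_quantize == "xnnpack_dynamic"

unset = QuantizationConfig()
assert (unset.pt2e_quantize.value if unset.pt2e_quantize else None) is None
```

Required fields such as `model_class` and `dtype_override` skip the guard and unwrap unconditionally, which is why those call sites read simply `llm_config.base.model_class.value`.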