@@ -590,7 +590,7 @@ def export_llama(
 
     # If a checkpoint isn't provided for an HF OSS model, download and convert the
     # weights first.
-    model_name = llm_config.base.model_class
+    model_name = llm_config.base.model_class.value
     if not llm_config.base.checkpoint and model_name in HUGGING_FACE_REPO_IDS:
         repo_id = HUGGING_FACE_REPO_IDS[model_name]
         if model_name == "qwen2_5":
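
Every hunk in this patch follows the same pattern: `LlmConfig` fields that used to be plain strings are now enums, so call sites that compare against strings or index string-keyed dicts have to unwrap the member with `.value`. A minimal sketch of why the unwrap is needed, assuming a plain `Enum` field (the class and the repo-id mapping below are illustrative stand-ins, not the actual `LlmConfig` definitions):

from enum import Enum

class ModelClass(Enum):  # hypothetical stand-in for model_class's enum type
    qwen2_5 = "qwen2_5"
    llama3 = "llama3"

HUGGING_FACE_REPO_IDS = {"qwen2_5": "some-org/some-repo"}  # illustrative contents

model_class = ModelClass.qwen2_5

# A plain Enum member is not equal to its string value, so membership
# tests against a string-keyed dict silently fail without .value:
assert model_class not in HUGGING_FACE_REPO_IDS
assert model_class != "qwen2_5"

# Unwrapping restores the old string-based behavior:
assert model_class.value in HUGGING_FACE_REPO_IDS
assert model_class.value == "qwen2_5"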
@@ -668,7 +668,7 @@ def _prepare_for_llama_export(llm_config: LlmConfig) -> LLMEdgeManager:
     llm_config.export.output_dir = output_dir_path
 
     # Convert dtype override string to actual type.
-    dtype_override = DType[llm_config.model.dtype_override]
+    dtype_override = DType[llm_config.model.dtype_override.value]
 
     edge_manager = _load_llama_model(llm_config)
 
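`DType[...]` is an `Enum` lookup by member *name*, and `Enum.__getitem__` only accepts that name as a string, so the override enum's member has to be unwrapped first. A sketch under the same assumption of plain `Enum` fields (both classes below are illustrative):

from enum import Enum

class DType(Enum):  # stand-in for the exporter's DType enum
    fp32 = "fp32"
    fp16 = "fp16"

class DtypeOverride(Enum):  # hypothetical stand-in for dtype_override's type
    fp32 = "fp32"
    fp16 = "fp16"

override = DtypeOverride.fp16

# Enum.__getitem__ indexes the member map by name (a str):
assert DType[override.value] is DType.fp16

# Passing the member itself raises KeyError, which is what the
# pre-.value code would hit once the field became an enum:
try:
    DType[override]
except KeyError:
    pass
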
@@ -702,7 +702,7 @@ def _prepare_for_llama_export(llm_config: LlmConfig) -> LLMEdgeManager:
         checkpoint=llm_config.base.checkpoint,
         checkpoint_dtype=DType.from_torch_dtype(checkpoint_dtype),  # type: ignore
         tokenizer_path=llm_config.base.tokenizer_path,
-        use_spin_quant=llm_config.quantization.use_spin_quant,
+        use_spin_quant=llm_config.quantization.use_spin_quant.value if llm_config.quantization.use_spin_quant else None,
         embedding_quantize=llm_config.quantization.embedding_quantize,
         use_shared_embedding=llm_config.model.use_shared_embedding,
         quantization_mode=llm_config.quantization.qmode,
@@ -726,7 +726,7 @@ def _prepare_for_llama_export(llm_config: LlmConfig) -> LLMEdgeManager:
         vulkan=llm_config.backend.vulkan.enabled,
         use_qat=llm_config.quantization.use_qat,
         use_lora=llm_config.base.use_lora,
-        preq_mode=llm_config.base.preq_mode,
+        preq_mode=llm_config.base.preq_mode.value if llm_config.base.preq_mode else None,
         preq_group_size=llm_config.base.preq_group_size,
         preq_embedding_quantize=llm_config.base.preq_embedding_quantize,
         local_global_attention=llm_config.model.local_global_attention,
@@ -738,25 +738,25 @@ def _prepare_for_llama_export(llm_config: LlmConfig) -> LLMEdgeManager:
 
 def get_quantizer_and_quant_params(llm_config):
     pt2e_quant_params = get_pt2e_quantization_params(
-        llm_config.quantization.pt2e_quantize, llm_config.quantization.qmode
+        llm_config.quantization.pt2e_quantize.value if llm_config.quantization.pt2e_quantize else None, llm_config.quantization.qmode
     )
     quantizers = get_pt2e_quantizers(pt2e_quant_params, llm_config.export.so_library)
     quant_dtype = None
     if llm_config.backend.qnn.enabled and llm_config.quantization.pt2e_quantize:
         assert len(quantizers) == 0, "Should not enable both xnnpack and qnn"
         qnn_quantizer, quant_dtype = get_qnn_quantizer(
-            llm_config.quantization.pt2e_quantize, llm_config.quantization.qmode
+            llm_config.quantization.pt2e_quantize.value, llm_config.quantization.qmode
         )
         quantizers.append(qnn_quantizer)
     if llm_config.backend.coreml.enabled and llm_config.quantization.pt2e_quantize:
         assert len(quantizers) == 0, "Should not enable both xnnpack / qnn and coreml"
-        coreml_quantizer = get_coreml_quantizer(llm_config.quantization.pt2e_quantize)
+        coreml_quantizer = get_coreml_quantizer(llm_config.quantization.pt2e_quantize.value)
         quantizers.append(coreml_quantizer)
     if llm_config.backend.vulkan.enabled and llm_config.quantization.pt2e_quantize:
         assert (
             len(quantizers) == 0
         ), "Should not enable both vulkan and other quantizers"
-        vulkan_quantizer = get_vulkan_quantizer(llm_config.quantization.pt2e_quantize)
+        vulkan_quantizer = get_vulkan_quantizer(llm_config.quantization.pt2e_quantize.value)
         quantizers.append(vulkan_quantizer)
     logging.info(f"Applying quantizers: {quantizers}")
     return pt2e_quant_params, quantizers, quant_dtype
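
For the optional fields (`use_spin_quant`, `preq_mode`, `pt2e_quantize`, `coreml.quantize`), the patch guards the unwrap with `x.value if x else None`, since an unset field is `None` and `.value` on `None` raises `AttributeError`. The pattern, as a small sketch with a hypothetical helper name:

from enum import Enum
from typing import Optional

class Pt2eQuantize(Enum):  # hypothetical stand-in for pt2e_quantize's type
    xnnpack_dynamic = "xnnpack_dynamic"

def unwrap_value(field: Optional[Enum]) -> Optional[str]:
    # Unset optional fields stay None; set ones hand their string
    # value to the string-based quantizer helpers.
    return field.value if field is not None else None

assert unwrap_value(None) is None
assert unwrap_value(Pt2eQuantize.xnnpack_dynamic) == "xnnpack_dynamic"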
@@ -1033,7 +1033,7 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager: # noqa: C901
     )
 
     additional_passes = []
-    if llm_config.base.model_class in TORCHTUNE_DEFINED_MODELS:
+    if llm_config.base.model_class.value in TORCHTUNE_DEFINED_MODELS:
         additional_passes = [InitializedMutableBufferPass(["kv_cache_pos"])]
 
     # export_to_edge
@@ -1072,14 +1072,14 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager: # noqa: C901
         mps=llm_config.backend.mps.enabled,
         coreml=llm_config.backend.coreml.enabled,
         qnn=llm_config.backend.qnn.enabled,
-        dtype_override=llm_config.model.dtype_override,
+        dtype_override=llm_config.model.dtype_override.value,
         enable_dynamic_shape=llm_config.model.enable_dynamic_shape,
         use_kv_cache=llm_config.model.use_kv_cache,
         embedding_quantize=llm_config.quantization.embedding_quantize,
-        pt2e_quantize=llm_config.quantization.pt2e_quantize,
+        pt2e_quantize=llm_config.quantization.pt2e_quantize.value if llm_config.quantization.pt2e_quantize else None,
         coreml_ios=llm_config.backend.coreml.ios,
-        coreml_quantize=llm_config.backend.coreml.quantize,
-        coreml_compute_units=llm_config.backend.coreml.compute_units,
+        coreml_quantize=llm_config.backend.coreml.quantize.value if llm_config.backend.coreml.quantize else None,
+        coreml_compute_units=llm_config.backend.coreml.compute_units.value,
         use_qnn_sha=llm_config.backend.qnn.use_sha,
         num_sharding=llm_config.backend.qnn.num_sharding,
         soc_model=llm_config.backend.qnn.soc_model,
@@ -1152,7 +1152,7 @@ def _load_llama_model(llm_config: LlmConfig) -> "LLMEdgeManager":
         An instance of LLMEdgeManager which contains the eager mode model.
     """
 
-    modelname = llm_config.base.model_class
+    modelname = llm_config.base.model_class.value
     if modelname in EXECUTORCH_DEFINED_MODELS:
         module_name = "llama"
         model_class_name = "Llama2Model"  # TODO: Change to "LlamaModel" in examples/models/llama/model.py.
@@ -1173,7 +1173,7 @@ def _load_llama_model(llm_config: LlmConfig) -> "LLMEdgeManager":
         )
     )
     # Convert dtype override string to actual type.
-    dtype_override = DType[llm_config.model.dtype_override]
+    dtype_override = DType[llm_config.model.dtype_override.value]
 
     return LLMEdgeManager(
         model=model,