@@ -590,7 +590,7 @@ def export_llama(
 
     # If a checkpoint isn't provided for an HF OSS model, download and convert the
     # weights first.
-    model_name = llm_config.base.model_class
+    model_name = llm_config.base.model_class.value
     if not llm_config.base.checkpoint and model_name in HUGGING_FACE_REPO_IDS:
         repo_id = HUGGING_FACE_REPO_IDS[model_name]
         if model_name == "qwen2_5":
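
Why the .value unwrap matters in this hunk: after the config migration, model_class is an enum member rather than a plain string, while HUGGING_FACE_REPO_IDS is keyed by strings, so both the membership check and the equality test against "qwen2_5" only hold on the unwrapped value. A minimal sketch of the failure mode (the enum and dict below are illustrative stand-ins, not the actual definitions):

    from enum import Enum

    class ModelClass(Enum):  # stand-in for the real config enum
        QWEN2_5 = "qwen2_5"

    repo_ids = {"qwen2_5": "some/repo"}  # stand-in for HUGGING_FACE_REPO_IDS
    assert ModelClass.QWEN2_5 not in repo_ids    # plain enum member != string key
    assert ModelClass.QWEN2_5.value in repo_ids  # "qwen2_5" matches the key

Note that a str-subclassing enum would compare equal to its string value; the explicit .value here assumes a plain Enum.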
@@ -668,7 +668,7 @@ def _prepare_for_llama_export(llm_config: LlmConfig) -> LLMEdgeManager:
     llm_config.export.output_dir = output_dir_path
 
     # Convert dtype override string to actual type.
-    dtype_override = DType[llm_config.model.dtype_override]
+    dtype_override = DType[llm_config.model.dtype_override.value]
 
     edge_manager = _load_llama_model(llm_config)
 
@@ -702,7 +702,11 @@ def _prepare_for_llama_export(llm_config: LlmConfig) -> LLMEdgeManager:
             checkpoint=llm_config.base.checkpoint,
             checkpoint_dtype=DType.from_torch_dtype(checkpoint_dtype),  # type: ignore
             tokenizer_path=llm_config.base.tokenizer_path,
-            use_spin_quant=llm_config.quantization.use_spin_quant,
+            use_spin_quant=(
+                llm_config.quantization.use_spin_quant.value
+                if llm_config.quantization.use_spin_quant
+                else None
+            ),
             embedding_quantize=llm_config.quantization.embedding_quantize,
             use_shared_embedding=llm_config.model.use_shared_embedding,
             quantization_mode=llm_config.quantization.qmode,
@@ -726,7 +730,9 @@ def _prepare_for_llama_export(llm_config: LlmConfig) -> LLMEdgeManager:
             vulkan=llm_config.backend.vulkan.enabled,
             use_qat=llm_config.quantization.use_qat,
             use_lora=llm_config.base.use_lora,
-            preq_mode=llm_config.base.preq_mode,
+            preq_mode=(
+                llm_config.base.preq_mode.value if llm_config.base.preq_mode else None
+            ),
             preq_group_size=llm_config.base.preq_group_size,
             preq_embedding_quantize=llm_config.base.preq_embedding_quantize,
             local_global_attention=llm_config.model.local_global_attention,
@@ -738,25 +744,34 @@ def _prepare_for_llama_export(llm_config: LlmConfig) -> LLMEdgeManager:
 
 def get_quantizer_and_quant_params(llm_config):
     pt2e_quant_params = get_pt2e_quantization_params(
-        llm_config.quantization.pt2e_quantize, llm_config.quantization.qmode
+        (
+            llm_config.quantization.pt2e_quantize.value
+            if llm_config.quantization.pt2e_quantize
+            else None
+        ),
+        llm_config.quantization.qmode,
     )
     quantizers = get_pt2e_quantizers(pt2e_quant_params, llm_config.export.so_library)
     quant_dtype = None
     if llm_config.backend.qnn.enabled and llm_config.quantization.pt2e_quantize:
         assert len(quantizers) == 0, "Should not enable both xnnpack and qnn"
         qnn_quantizer, quant_dtype = get_qnn_quantizer(
-            llm_config.quantization.pt2e_quantize, llm_config.quantization.qmode
+            llm_config.quantization.pt2e_quantize.value, llm_config.quantization.qmode
         )
         quantizers.append(qnn_quantizer)
     if llm_config.backend.coreml.enabled and llm_config.quantization.pt2e_quantize:
         assert len(quantizers) == 0, "Should not enable both xnnpack / qnn and coreml"
-        coreml_quantizer = get_coreml_quantizer(llm_config.quantization.pt2e_quantize)
+        coreml_quantizer = get_coreml_quantizer(
+            llm_config.quantization.pt2e_quantize.value
+        )
         quantizers.append(coreml_quantizer)
     if llm_config.backend.vulkan.enabled and llm_config.quantization.pt2e_quantize:
         assert (
             len(quantizers) == 0
         ), "Should not enable both vulkan and other quantizers"
-        vulkan_quantizer = get_vulkan_quantizer(llm_config.quantization.pt2e_quantize)
+        vulkan_quantizer = get_vulkan_quantizer(
+            llm_config.quantization.pt2e_quantize.value
+        )
         quantizers.append(vulkan_quantizer)
     logging.info(f"Applying quantizers: {quantizers}")
     return pt2e_quant_params, quantizers, quant_dtype
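
The recurring (x.value if x else None) guard in these hunks exists because the quantization fields are optional: when unset they are None, and calling .value on None raises AttributeError. A hedged sketch of a helper that captures the pattern (unwrap_enum is a hypothetical name, not part of this codebase):

    from enum import Enum
    from typing import Optional

    def unwrap_enum(member: Optional[Enum]) -> Optional[str]:
        # None passes through unchanged; enum members yield their raw value.
        return member.value if member is not None else None

    # e.g. unwrap_enum(llm_config.quantization.pt2e_quantize)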
@@ -1033,7 +1048,7 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager:  # noqa: C901
     )
 
     additional_passes = []
-    if llm_config.base.model_class in TORCHTUNE_DEFINED_MODELS:
+    if llm_config.base.model_class.value in TORCHTUNE_DEFINED_MODELS:
         additional_passes = [InitializedMutableBufferPass(["kv_cache_pos"])]
 
     # export_to_edge
@@ -1072,14 +1087,22 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager:  # noqa: C901
             mps=llm_config.backend.mps.enabled,
             coreml=llm_config.backend.coreml.enabled,
             qnn=llm_config.backend.qnn.enabled,
-            dtype_override=llm_config.model.dtype_override,
+            dtype_override=llm_config.model.dtype_override.value,
             enable_dynamic_shape=llm_config.model.enable_dynamic_shape,
             use_kv_cache=llm_config.model.use_kv_cache,
             embedding_quantize=llm_config.quantization.embedding_quantize,
-            pt2e_quantize=llm_config.quantization.pt2e_quantize,
+            pt2e_quantize=(
+                llm_config.quantization.pt2e_quantize.value
+                if llm_config.quantization.pt2e_quantize
+                else None
+            ),
             coreml_ios=llm_config.backend.coreml.ios,
-            coreml_quantize=llm_config.backend.coreml.quantize,
-            coreml_compute_units=llm_config.backend.coreml.compute_units,
+            coreml_quantize=(
+                llm_config.backend.coreml.quantize.value
+                if llm_config.backend.coreml.quantize
+                else None
+            ),
+            coreml_compute_units=llm_config.backend.coreml.compute_units.value,
             use_qnn_sha=llm_config.backend.qnn.use_sha,
             num_sharding=llm_config.backend.qnn.num_sharding,
             soc_model=llm_config.backend.qnn.soc_model,
@@ -1152,7 +1175,7 @@ def _load_llama_model(llm_config: LlmConfig) -> "LLMEdgeManager":
         An instance of LLMEdgeManager which contains the eager mode model.
     """
 
-    modelname = llm_config.base.model_class
+    modelname = llm_config.base.model_class.value
     if modelname in EXECUTORCH_DEFINED_MODELS:
         module_name = "llama"
         model_class_name = "Llama2Model"  # TODO: Change to "LlamaModel" in examples/models/llama/model.py.
@@ -1173,7 +1196,7 @@ def _load_llama_model(llm_config: LlmConfig) -> "LLMEdgeManager":
         )
     )
     # Convert dtype override string to actual type.
-    dtype_override = DType[llm_config.model.dtype_override]
+    dtype_override = DType[llm_config.model.dtype_override.value]
 
     return LLMEdgeManager(
         model=model,
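
A note on the DType[...] lookups above: indexing an Enum class with a string returns the member whose name matches, so DType[llm_config.model.dtype_override.value] assumes the dtype-override enum's values line up with DType member names. A minimal sketch of that contract (the member names below are assumed for illustration, not copied from the real DType):

    from enum import Enum

    class DType(Enum):  # assumed members for illustration
        fp32 = "fp32"
        fp16 = "fp16"

    assert DType["fp32"] is DType.fp32  # Enum.__getitem__ looks up by member name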