@@ -590,7 +590,7 @@ def export_llama(
 
     # If a checkpoint isn't provided for an HF OSS model, download and convert the
     # weights first.
-    model_name = llm_config.base.model_class
+    model_name = llm_config.base.model_class.value
     if not llm_config.base.checkpoint and model_name in HUGGING_FACE_REPO_IDS:
         repo_id = HUGGING_FACE_REPO_IDS[model_name]
         if model_name == "qwen2_5":
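
The change above is representative of the whole diff: `LlmConfig` fields that used to be plain strings are now enums, so call sites feeding string-keyed lookups such as `HUGGING_FACE_REPO_IDS` need the member's `.value`. A minimal sketch of why, using a hypothetical `ModelType` enum and a placeholder repo id (the real enum lives in the LlmConfig definitions and may be named differently):

```python
from enum import Enum

class ModelType(Enum):  # hypothetical stand-in for the type of llm_config.base.model_class
    qwen2_5 = "qwen2_5"
    llama3 = "llama3"

HUGGING_FACE_REPO_IDS = {"qwen2_5": "Qwen/Qwen2.5-1.5B"}  # placeholder entry

model_class = ModelType.qwen2_5
assert model_class not in HUGGING_FACE_REPO_IDS    # the member itself is not a dict key
assert model_class.value in HUGGING_FACE_REPO_IDS  # .value is the plain string
```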
@@ -668,7 +668,7 @@ def _prepare_for_llama_export(llm_config: LlmConfig) -> LLMEdgeManager:
     llm_config.export.output_dir = output_dir_path
 
     # Convert dtype override string to actual type.
-    dtype_override = DType[llm_config.model.dtype_override]
+    dtype_override = DType[llm_config.model.dtype_override.value]
 
     edge_manager = _load_llama_model(llm_config)
 
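
`DType[...]` is a lookup by member name (`Enum.__getitem__`), so it must receive the plain string rather than the `dtype_override` enum member. The mechanics, with illustrative members only:

```python
from enum import Enum

class DType(Enum):  # illustrative members; the real DType enum may differ
    fp32 = "fp32"
    fp16 = "fp16"
    bf16 = "bf16"

assert DType["fp32"] is DType.fp32  # name-based lookup wants a str
# Indexing with another enum member instead of a string raises KeyError,
# which is why the call site now unwraps with .value before indexing.
```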
@@ -702,7 +702,11 @@ def _prepare_for_llama_export(llm_config: LlmConfig) -> LLMEdgeManager:
             checkpoint=llm_config.base.checkpoint,
             checkpoint_dtype=DType.from_torch_dtype(checkpoint_dtype),  # type: ignore
             tokenizer_path=llm_config.base.tokenizer_path,
-            use_spin_quant=llm_config.quantization.use_spin_quant,
+            use_spin_quant=(
+                llm_config.quantization.use_spin_quant.value
+                if llm_config.quantization.use_spin_quant
+                else None
+            ),
             embedding_quantize=llm_config.quantization.embedding_quantize,
             use_shared_embedding=llm_config.model.use_shared_embedding,
             quantization_mode=llm_config.quantization.qmode,
@@ -726,7 +730,9 @@ def _prepare_for_llama_export(llm_config: LlmConfig) -> LLMEdgeManager:
             vulkan=llm_config.backend.vulkan.enabled,
             use_qat=llm_config.quantization.use_qat,
             use_lora=llm_config.base.use_lora,
-            preq_mode=llm_config.base.preq_mode,
+            preq_mode=(
+                llm_config.base.preq_mode.value if llm_config.base.preq_mode else None
+            ),
             preq_group_size=llm_config.base.preq_group_size,
             preq_embedding_quantize=llm_config.base.preq_embedding_quantize,
             local_global_attention=llm_config.model.local_global_attention,
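
This hunk and the previous one repeat the same `x.value if x else None` guard for optional enum fields. If the pattern keeps spreading, a small helper would cut the noise; a sketch only, with `_enum_value_or_none` being a hypothetical name, not an existing utility in this codebase:

```python
from enum import Enum
from typing import Optional

def _enum_value_or_none(member: Optional[Enum]) -> Optional[str]:
    """Unwrap an optional enum member to its plain value; pass None through."""
    return member.value if member is not None else None

# The kwarg above could then read:
#     preq_mode=_enum_value_or_none(llm_config.base.preq_mode),
```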
@@ -738,25 +744,34 @@ def _prepare_for_llama_export(llm_config: LlmConfig) -> LLMEdgeManager:
 
 def get_quantizer_and_quant_params(llm_config):
     pt2e_quant_params = get_pt2e_quantization_params(
-        llm_config.quantization.pt2e_quantize, llm_config.quantization.qmode
+        (
+            llm_config.quantization.pt2e_quantize.value
+            if llm_config.quantization.pt2e_quantize
+            else None
+        ),
+        llm_config.quantization.qmode,
     )
     quantizers = get_pt2e_quantizers(pt2e_quant_params, llm_config.export.so_library)
     quant_dtype = None
     if llm_config.backend.qnn.enabled and llm_config.quantization.pt2e_quantize:
         assert len(quantizers) == 0, "Should not enable both xnnpack and qnn"
         qnn_quantizer, quant_dtype = get_qnn_quantizer(
-            llm_config.quantization.pt2e_quantize, llm_config.quantization.qmode
+            llm_config.quantization.pt2e_quantize.value, llm_config.quantization.qmode
         )
         quantizers.append(qnn_quantizer)
     if llm_config.backend.coreml.enabled and llm_config.quantization.pt2e_quantize:
         assert len(quantizers) == 0, "Should not enable both xnnpack / qnn and coreml"
-        coreml_quantizer = get_coreml_quantizer(llm_config.quantization.pt2e_quantize)
+        coreml_quantizer = get_coreml_quantizer(
+            llm_config.quantization.pt2e_quantize.value
+        )
         quantizers.append(coreml_quantizer)
     if llm_config.backend.vulkan.enabled and llm_config.quantization.pt2e_quantize:
         assert (
             len(quantizers) == 0
         ), "Should not enable both vulkan and other quantizers"
-        vulkan_quantizer = get_vulkan_quantizer(llm_config.quantization.pt2e_quantize)
+        vulkan_quantizer = get_vulkan_quantizer(
+            llm_config.quantization.pt2e_quantize.value
+        )
         quantizers.append(vulkan_quantizer)
     logging.info(f"Applying quantizers: {quantizers}")
     return pt2e_quant_params, quantizers, quant_dtype
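
Within the backend branches the bare `.value` is safe because each branch is gated on `llm_config.quantization.pt2e_quantize` being truthy, and enum members are truthy by default. A quick check of that assumption (plain-Python enum semantics, illustrative member names):

```python
from enum import Enum

class Pt2eQuantize(Enum):  # illustrative members only
    xnnpack_dynamic = "xnnpack_dynamic"
    qnn_8a8w = "qnn_8a8w"

member = Pt2eQuantize.qnn_8a8w
if member:                   # a set member is truthy, None is not
    assert member.value == "qnn_8a8w"
```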
@@ -1035,7 +1050,7 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager:  # noqa: C901
     )
 
     additional_passes = []
-    if llm_config.base.model_class in TORCHTUNE_DEFINED_MODELS:
+    if llm_config.base.model_class.value in TORCHTUNE_DEFINED_MODELS:
         additional_passes = [InitializedMutableBufferPass(["kv_cache_pos"])]
 
     # export_to_edge
@@ -1074,14 +1089,22 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager:  # noqa: C901
             mps=llm_config.backend.mps.enabled,
             coreml=llm_config.backend.coreml.enabled,
             qnn=llm_config.backend.qnn.enabled,
-            dtype_override=llm_config.model.dtype_override,
+            dtype_override=llm_config.model.dtype_override.value,
             enable_dynamic_shape=llm_config.model.enable_dynamic_shape,
             use_kv_cache=llm_config.model.use_kv_cache,
             embedding_quantize=llm_config.quantization.embedding_quantize,
-            pt2e_quantize=llm_config.quantization.pt2e_quantize,
+            pt2e_quantize=(
+                llm_config.quantization.pt2e_quantize.value
+                if llm_config.quantization.pt2e_quantize
+                else None
+            ),
             coreml_ios=llm_config.backend.coreml.ios,
-            coreml_quantize=llm_config.backend.coreml.quantize,
-            coreml_compute_units=llm_config.backend.coreml.compute_units,
+            coreml_quantize=(
+                llm_config.backend.coreml.quantize.value
+                if llm_config.backend.coreml.quantize
+                else None
+            ),
+            coreml_compute_units=llm_config.backend.coreml.compute_units.value,
             use_qnn_sha=llm_config.backend.qnn.use_sha,
             num_sharding=llm_config.backend.qnn.num_sharding,
             soc_model=llm_config.backend.qnn.soc_model,
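
Note the asymmetry in this hunk: `dtype_override` and `coreml_compute_units` unwrap unconditionally, while `pt2e_quantize` and `coreml_quantize` are None-guarded. That only works if the unguarded fields always carry a defaulted enum member; a None would surface as an `AttributeError` at export time. A sketch of the assumption, with a hypothetical enum:

```python
from enum import Enum
from typing import Optional

class ComputeUnits(Enum):  # hypothetical stand-in for the type of backend.coreml.compute_units
    cpu_only = "cpu_only"
    cpu_and_gpu = "cpu_and_gpu"

def unwrap(units: Optional[ComputeUnits]) -> str:
    # Mirrors the unguarded call site: fails loudly if the field is ever None.
    return units.value

assert unwrap(ComputeUnits.cpu_only) == "cpu_only"
# unwrap(None) -> AttributeError: 'NoneType' object has no attribute 'value'
```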
@@ -1154,7 +1177,7 @@ def _load_llama_model(llm_config: LlmConfig) -> "LLMEdgeManager":
         An instance of LLMEdgeManager which contains the eager mode model.
     """
 
-    modelname = llm_config.base.model_class
+    modelname = llm_config.base.model_class.value
     if modelname in EXECUTORCH_DEFINED_MODELS:
         module_name = "llama"
         model_class_name = "Llama2Model"  # TODO: Change to "LlamaModel" in examples/models/llama/model.py.
@@ -1175,7 +1198,7 @@ def _load_llama_model(llm_config: LlmConfig) -> "LLMEdgeManager":
         )
     )
     # Convert dtype override string to actual type.
-    dtype_override = DType[llm_config.model.dtype_override]
+    dtype_override = DType[llm_config.model.dtype_override.value]
 
     return LLMEdgeManager(
         model=model,