@@ -590,7 +590,7 @@ def export_llama(
 
     # If a checkpoint isn't provided for an HF OSS model, download and convert the
     # weights first.
-    model_name = llm_config.base.model_class
+    model_name = llm_config.base.model_class.value
    if not llm_config.base.checkpoint and model_name in HUGGING_FACE_REPO_IDS:
        repo_id = HUGGING_FACE_REPO_IDS[model_name]
        if model_name == "qwen2_5":
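Note on the change above: `HUGGING_FACE_REPO_IDS` is keyed by plain strings, so an enum-typed `model_class` must be unwrapped with `.value` before the membership check, the dict lookup, and the string comparison below it. A minimal sketch of the failure mode, with a hypothetical `ModelClass` enum standing in for the config's actual type (the repo id is illustrative):

```python
from enum import Enum

class ModelClass(Enum):  # hypothetical stand-in for the config-side enum
    QWEN2_5 = "qwen2_5"

HUGGING_FACE_REPO_IDS = {"qwen2_5": "Qwen/Qwen2.5-1.5B"}  # illustrative entry

model_name = ModelClass.QWEN2_5
print(model_name in HUGGING_FACE_REPO_IDS)        # False: enum member is not a str key
print(model_name.value in HUGGING_FACE_REPO_IDS)  # True: compares the underlying string
```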
@@ -668,7 +668,7 @@ def _prepare_for_llama_export(llm_config: LlmConfig) -> LLMEdgeManager:
    llm_config.export.output_dir = output_dir_path
 
    # Convert dtype override string to actual type.
-    dtype_override = DType[llm_config.model.dtype_override]
+    dtype_override = DType[llm_config.model.dtype_override.value]
 
    edge_manager = _load_llama_model(llm_config)
 
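The `dtype_override` change exists for the same reason: `DType[...]` is an enum lookup by member name, which requires a string; passing the config's enum member raises a `KeyError`. A hedged sketch, assuming both sides are plain `Enum` classes with hypothetical names:

```python
from enum import Enum

class DType(Enum):          # simplified stand-in for the exporter's DType enum
    fp32 = "fp32"
    fp16 = "fp16"

class DtypeOverride(Enum):  # hypothetical config-side enum
    fp32 = "fp32"

override = DtypeOverride.fp32
# DType[override] raises KeyError: Enum.__getitem__ looks up members by name string.
dtype_override = DType[override.value]
print(dtype_override)  # DType.fp32
```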
@@ -702,7 +702,7 @@ def _prepare_for_llama_export(llm_config: LlmConfig) -> LLMEdgeManager:
            checkpoint=llm_config.base.checkpoint,
            checkpoint_dtype=DType.from_torch_dtype(checkpoint_dtype),  # type: ignore
            tokenizer_path=llm_config.base.tokenizer_path,
-            use_spin_quant=llm_config.quantization.use_spin_quant,
+            use_spin_quant=llm_config.quantization.use_spin_quant.value if llm_config.quantization.use_spin_quant else None,
            embedding_quantize=llm_config.quantization.embedding_quantize,
            use_shared_embedding=llm_config.model.use_shared_embedding,
            quantization_mode=llm_config.quantization.qmode,
@@ -726,7 +726,7 @@ def _prepare_for_llama_export(llm_config: LlmConfig) -> LLMEdgeManager:
            vulkan=llm_config.backend.vulkan.enabled,
            use_qat=llm_config.quantization.use_qat,
            use_lora=llm_config.base.use_lora,
-            preq_mode=llm_config.base.preq_mode,
+            preq_mode=llm_config.base.preq_mode.value if llm_config.base.preq_mode else None,
            preq_group_size=llm_config.base.preq_group_size,
            preq_embedding_quantize=llm_config.base.preq_embedding_quantize,
            local_global_attention=llm_config.model.local_global_attention,
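`use_spin_quant` and `preq_mode` are optional fields, so a bare `.value` would raise `AttributeError` when they are unset; the `x.value if x else None` guard unwraps the enum only when a member is present. The same idiom expressed as a small helper (hypothetical, not part of this PR):

```python
from enum import Enum
from typing import Optional

def enum_value_or_none(field: Optional[Enum]) -> Optional[str]:
    """Unwrap an optional enum field to its underlying value, passing None through."""
    return field.value if field else None
```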
@@ -738,25 +738,25 @@ def _prepare_for_llama_export(llm_config: LlmConfig) -> LLMEdgeManager:
 
def get_quantizer_and_quant_params(llm_config):
    pt2e_quant_params = get_pt2e_quantization_params(
-        llm_config.quantization.pt2e_quantize, llm_config.quantization.qmode
+        llm_config.quantization.pt2e_quantize.value if llm_config.quantization.pt2e_quantize else None, llm_config.quantization.qmode
    )
    quantizers = get_pt2e_quantizers(pt2e_quant_params, llm_config.export.so_library)
    quant_dtype = None
    if llm_config.backend.qnn.enabled and llm_config.quantization.pt2e_quantize:
        assert len(quantizers) == 0, "Should not enable both xnnpack and qnn"
        qnn_quantizer, quant_dtype = get_qnn_quantizer(
-            llm_config.quantization.pt2e_quantize, llm_config.quantization.qmode
+            llm_config.quantization.pt2e_quantize.value, llm_config.quantization.qmode
        )
        quantizers.append(qnn_quantizer)
    if llm_config.backend.coreml.enabled and llm_config.quantization.pt2e_quantize:
        assert len(quantizers) == 0, "Should not enable both xnnpack / qnn and coreml"
-        coreml_quantizer = get_coreml_quantizer(llm_config.quantization.pt2e_quantize)
+        coreml_quantizer = get_coreml_quantizer(llm_config.quantization.pt2e_quantize.value)
        quantizers.append(coreml_quantizer)
    if llm_config.backend.vulkan.enabled and llm_config.quantization.pt2e_quantize:
        assert (
            len(quantizers) == 0
        ), "Should not enable both vulkan and other quantizers"
-        vulkan_quantizer = get_vulkan_quantizer(llm_config.quantization.pt2e_quantize)
+        vulkan_quantizer = get_vulkan_quantizer(llm_config.quantization.pt2e_quantize.value)
        quantizers.append(vulkan_quantizer)
    logging.info(f"Applying quantizers: {quantizers}")
    return pt2e_quant_params, quantizers, quant_dtype
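Note the asymmetry in this hunk: the first `pt2e_quantize` access is guarded with `... if ... else None` because the field may be unset at that point, while the later accesses sit inside `if llm_config.quantization.pt2e_quantize:` branches, where the field is already known to be truthy, so a bare `.value` is safe there.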
@@ -1033,7 +1033,7 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager:  # noqa: C901
    )
 
    additional_passes = []
-    if llm_config.base.model_class in TORCHTUNE_DEFINED_MODELS:
+    if llm_config.base.model_class.value in TORCHTUNE_DEFINED_MODELS:
        additional_passes = [InitializedMutableBufferPass(["kv_cache_pos"])]
 
    # export_to_edge
@@ -1072,14 +1072,14 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager:  # noqa: C901
            mps=llm_config.backend.mps.enabled,
            coreml=llm_config.backend.coreml.enabled,
            qnn=llm_config.backend.qnn.enabled,
-            dtype_override=llm_config.model.dtype_override,
+            dtype_override=llm_config.model.dtype_override.value,
            enable_dynamic_shape=llm_config.model.enable_dynamic_shape,
            use_kv_cache=llm_config.model.use_kv_cache,
            embedding_quantize=llm_config.quantization.embedding_quantize,
-            pt2e_quantize=llm_config.quantization.pt2e_quantize,
+            pt2e_quantize=llm_config.quantization.pt2e_quantize.value if llm_config.quantization.pt2e_quantize else None,
            coreml_ios=llm_config.backend.coreml.ios,
-            coreml_quantize=llm_config.backend.coreml.quantize,
-            coreml_compute_units=llm_config.backend.coreml.compute_units,
+            coreml_quantize=llm_config.backend.coreml.quantize.value if llm_config.backend.coreml.quantize else None,
+            coreml_compute_units=llm_config.backend.coreml.compute_units.value,
            use_qnn_sha=llm_config.backend.qnn.use_sha,
            num_sharding=llm_config.backend.qnn.num_sharding,
            soc_model=llm_config.backend.qnn.soc_model,
@@ -1152,7 +1152,7 @@ def _load_llama_model(llm_config: LlmConfig) -> "LLMEdgeManager":
        An instance of LLMEdgeManager which contains the eager mode model.
    """
 
-    modelname = llm_config.base.model_class
+    modelname = llm_config.base.model_class.value
    if modelname in EXECUTORCH_DEFINED_MODELS:
        module_name = "llama"
        model_class_name = "Llama2Model"  # TODO: Change to "LlamaModel" in examples/models/llama/model.py.
@@ -1173,7 +1173,7 @@ def _load_llama_model(llm_config: LlmConfig) -> "LLMEdgeManager":
        )
    )
    # Convert dtype override string to actual type.
-    dtype_override = DType[llm_config.model.dtype_override]
+    dtype_override = DType[llm_config.model.dtype_override.value]
 
    return LLMEdgeManager(
        model=model,