@@ -590,7 +590,7 @@ def export_llama(
 
     # If a checkpoint isn't provided for an HF OSS model, download and convert the
     # weights first.
-    model_name = llm_config.base.model_class
+    model_name = llm_config.base.model_class.value
     if not llm_config.base.checkpoint and model_name in HUGGING_FACE_REPO_IDS:
         repo_id = HUGGING_FACE_REPO_IDS[model_name]
         if model_name == "qwen2_5":
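The change above reflects model_class now being an enum field rather than a plain string. A minimal sketch of why the .value unwrap matters, using hypothetical stand-ins (the ModelClass enum and the repo-id entry below are illustrative, not the real definitions): the Hugging Face repo table is keyed by plain strings, so an enum member has to be converted to its string value before the lookup.

from enum import Enum

# Hypothetical stand-ins; the real ModelClass enum and repo table live in the
# ExecuTorch sources and may differ.
class ModelClass(Enum):
    llama3 = "llama3"
    qwen2_5 = "qwen2_5"

HUGGING_FACE_REPO_IDS = {"qwen2_5": "Qwen/Qwen2.5-1.5B"}  # string-keyed table

model_class = ModelClass.qwen2_5
assert model_class not in HUGGING_FACE_REPO_IDS    # the enum member is not a string
assert model_class.value in HUGGING_FACE_REPO_IDS  # .value unwraps to the string key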
@@ -668,7 +668,7 @@ def _prepare_for_llama_export(llm_config: LlmConfig) -> LLMEdgeManager:
     llm_config.export.output_dir = output_dir_path
 
     # Convert dtype override string to actual type.
-    dtype_override = DType[llm_config.model.dtype_override]
+    dtype_override = DType[llm_config.model.dtype_override.value]
 
     edge_manager = _load_llama_model(llm_config)
 
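Same idea for the dtype override: DType[...] is an enum name lookup and expects the member name as a string, so the (now enum-valued) dtype_override field is unwrapped with .value first. A rough sketch under assumed definitions (both enums below are illustrative stand-ins):

from enum import Enum

class DType(Enum):          # stand-in for the real DType enum
    fp32 = "fp32"
    fp16 = "fp16"

class DtypeOverride(Enum):  # hypothetical type of llm_config.model.dtype_override
    fp32 = "fp32"
    fp16 = "fp16"

override = DtypeOverride.fp16
# Enum subscripting looks a member up by its *name* string; passing another enum
# member raises a KeyError, so the override is unwrapped first.
dtype = DType[override.value]
assert dtype is DType.fp16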
@@ -702,7 +702,11 @@ def _prepare_for_llama_export(llm_config: LlmConfig) -> LLMEdgeManager:
             checkpoint=llm_config.base.checkpoint,
             checkpoint_dtype=DType.from_torch_dtype(checkpoint_dtype),  # type: ignore
             tokenizer_path=llm_config.base.tokenizer_path,
-            use_spin_quant=llm_config.quantization.use_spin_quant,
+            use_spin_quant=(
+                llm_config.quantization.use_spin_quant.value
+                if llm_config.quantization.use_spin_quant
+                else None
+            ),
             embedding_quantize=llm_config.quantization.embedding_quantize,
             use_shared_embedding=llm_config.model.use_shared_embedding,
             quantization_mode=llm_config.quantization.qmode,
@@ -726,7 +730,9 @@ def _prepare_for_llama_export(llm_config: LlmConfig) -> LLMEdgeManager:
             vulkan=llm_config.backend.vulkan.enabled,
             use_qat=llm_config.quantization.use_qat,
             use_lora=llm_config.base.use_lora,
-            preq_mode=llm_config.base.preq_mode,
+            preq_mode=(
+                llm_config.base.preq_mode.value if llm_config.base.preq_mode else None
+            ),
             preq_group_size=llm_config.base.preq_group_size,
             preq_embedding_quantize=llm_config.base.preq_embedding_quantize,
             local_global_attention=llm_config.model.local_global_attention,
@@ -738,25 +744,34 @@ def _prepare_for_llama_export(llm_config: LlmConfig) -> LLMEdgeManager:
 
 def get_quantizer_and_quant_params(llm_config):
     pt2e_quant_params = get_pt2e_quantization_params(
-        llm_config.quantization.pt2e_quantize, llm_config.quantization.qmode
+        (
+            llm_config.quantization.pt2e_quantize.value
+            if llm_config.quantization.pt2e_quantize
+            else None
+        ),
+        llm_config.quantization.qmode,
     )
     quantizers = get_pt2e_quantizers(pt2e_quant_params, llm_config.export.so_library)
     quant_dtype = None
     if llm_config.backend.qnn.enabled and llm_config.quantization.pt2e_quantize:
         assert len(quantizers) == 0, "Should not enable both xnnpack and qnn"
         qnn_quantizer, quant_dtype = get_qnn_quantizer(
-            llm_config.quantization.pt2e_quantize, llm_config.quantization.qmode
+            llm_config.quantization.pt2e_quantize.value, llm_config.quantization.qmode
         )
         quantizers.append(qnn_quantizer)
     if llm_config.backend.coreml.enabled and llm_config.quantization.pt2e_quantize:
         assert len(quantizers) == 0, "Should not enable both xnnpack / qnn and coreml"
-        coreml_quantizer = get_coreml_quantizer(llm_config.quantization.pt2e_quantize)
+        coreml_quantizer = get_coreml_quantizer(
+            llm_config.quantization.pt2e_quantize.value
+        )
         quantizers.append(coreml_quantizer)
     if llm_config.backend.vulkan.enabled and llm_config.quantization.pt2e_quantize:
         assert (
             len(quantizers) == 0
         ), "Should not enable both vulkan and other quantizers"
-        vulkan_quantizer = get_vulkan_quantizer(llm_config.quantization.pt2e_quantize)
+        vulkan_quantizer = get_vulkan_quantizer(
+            llm_config.quantization.pt2e_quantize.value
+        )
         quantizers.append(vulkan_quantizer)
     logging.info(f"Applying quantizers: {quantizers}")
     return pt2e_quant_params, quantizers, quant_dtype
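The guards added throughout this hunk follow one pattern: pt2e_quantize (like the other optional fields touched above) is now an optional enum, so it is unwrapped to its string value only when set and passed through as None otherwise. A small illustrative sketch, with a hypothetical Pt2eQuantize enum standing in for the real config type:

from enum import Enum
from typing import Optional

class Pt2eQuantize(Enum):  # hypothetical stand-in for the real pt2e_quantize enum
    xnnpack_dynamic = "xnnpack_dynamic"
    qnn_8a8w = "qnn_8a8w"

def unwrap(field: Optional[Pt2eQuantize]) -> Optional[str]:
    # The same guard the diff introduces: enum members are always truthy, so the
    # conditional only falls through to None when the field is actually unset.
    return field.value if field else None

assert unwrap(Pt2eQuantize.qnn_8a8w) == "qnn_8a8w"
assert unwrap(None) is None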
@@ -1035,7 +1050,7 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager:  # noqa: C901
     )
 
     additional_passes = []
-    if llm_config.base.model_class in TORCHTUNE_DEFINED_MODELS:
+    if llm_config.base.model_class.value in TORCHTUNE_DEFINED_MODELS:
         additional_passes = [InitializedMutableBufferPass(["kv_cache_pos"])]
 
     # export_to_edge
@@ -1074,14 +1089,22 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager:  # noqa: C901
             mps=llm_config.backend.mps.enabled,
             coreml=llm_config.backend.coreml.enabled,
             qnn=llm_config.backend.qnn.enabled,
-            dtype_override=llm_config.model.dtype_override,
+            dtype_override=llm_config.model.dtype_override.value,
             enable_dynamic_shape=llm_config.model.enable_dynamic_shape,
             use_kv_cache=llm_config.model.use_kv_cache,
             embedding_quantize=llm_config.quantization.embedding_quantize,
-            pt2e_quantize=llm_config.quantization.pt2e_quantize,
+            pt2e_quantize=(
+                llm_config.quantization.pt2e_quantize.value
+                if llm_config.quantization.pt2e_quantize
+                else None
+            ),
             coreml_ios=llm_config.backend.coreml.ios,
-            coreml_quantize=llm_config.backend.coreml.quantize,
-            coreml_compute_units=llm_config.backend.coreml.compute_units,
+            coreml_quantize=(
+                llm_config.backend.coreml.quantize.value
+                if llm_config.backend.coreml.quantize
+                else None
+            ),
+            coreml_compute_units=llm_config.backend.coreml.compute_units.value,
             use_qnn_sha=llm_config.backend.qnn.use_sha,
             num_sharding=llm_config.backend.qnn.num_sharding,
             soc_model=llm_config.backend.qnn.soc_model,
@@ -1154,7 +1177,7 @@ def _load_llama_model(llm_config: LlmConfig) -> "LLMEdgeManager":
         An instance of LLMEdgeManager which contains the eager mode model.
     """
 
-    modelname = llm_config.base.model_class
+    modelname = llm_config.base.model_class.value
     if modelname in EXECUTORCH_DEFINED_MODELS:
         module_name = "llama"
         model_class_name = "Llama2Model"  # TODO: Change to "LlamaModel" in examples/models/llama/model.py.
@@ -1175,7 +1198,7 @@ def _load_llama_model(llm_config: LlmConfig) -> "LLMEdgeManager":
         )
     )
     # Convert dtype override string to actual type.
-    dtype_override = DType[llm_config.model.dtype_override]
+    dtype_override = DType[llm_config.model.dtype_override.value]
 
     return LLMEdgeManager(
         model=model,