@@ -661,36 +661,16 @@ def _prepare_for_llama_export(llm_config: LlmConfig) -> LLMEdgeManager:
         canonical_path(llm_config.base.params) if llm_config.base.params else None
     )
     output_dir_path = canonical_path(llm_config.export.output_dir, dir=True)
-    weight_type = WeightType.FAIRSEQ2 if llm_config.base.fairseq2 else WeightType.LLAMA

-    # Convert dtype override string to actual type
+    llm_config.base.checkpoint = checkpoint_path
+    llm_config.base.checkpoint_dir = checkpoint_dir
+    llm_config.base.params = params_path
+    llm_config.export.output_dir = output_dir_path
+
+    # Convert dtype override string to actual type.
     dtype_override = DType[llm_config.model.dtype_override]

-    edge_manager = _load_llama_model(
-        llm_config,
-        checkpoint=checkpoint_path,
-        checkpoint_dir=checkpoint_dir,
-        params_path=params_path,
-        use_kv_cache=llm_config.model.use_kv_cache,
-        use_sdpa_with_kv_cache=llm_config.model.use_sdpa_with_kv_cache,
-        generate_full_logits=llm_config.debug.generate_full_logits,
-        weight_type=weight_type,
-        enable_dynamic_shape=llm_config.model.enable_dynamic_shape,
-        calibration_tasks=llm_config.quantization.calibration_tasks,
-        calibration_limit=llm_config.quantization.calibration_limit,
-        calibration_seq_length=llm_config.quantization.calibration_seq_length,
-        calibration_data=llm_config.quantization.calibration_data,
-        tokenizer_path=llm_config.base.tokenizer_path,
-        verbose=llm_config.debug.verbose,
-        max_seq_len=llm_config.export.max_seq_length,
-        max_context_len=llm_config.export.max_context_length,
-        input_prune_map_path=llm_config.model.input_prune_map,
-        output_prune_map_path=llm_config.model.output_prune_map,
-        metadata_str=llm_config.base.metadata,
-        dtype_override=dtype_override,
-        use_qnn=llm_config.backend.qnn.enabled,
-        export_only=llm_config.export.export_only,
-    )
+    edge_manager = _load_llama_model(llm_config)

     # At this point, the model is loaded in the default fp32.

@@ -1167,32 +1147,7 @@ def _load_llama_model_metadata(
     return metadata


-def _load_llama_model(
-    llm_config: LlmConfig,
-    *,
-    checkpoint: Optional[str] = None,
-    checkpoint_dir: Optional[str] = None,
-    params_path: Optional[str] = None,
-    use_kv_cache: bool = False,
-    use_sdpa_with_kv_cache: bool = False,
-    generate_full_logits: bool = False,
-    weight_type: WeightType = WeightType.LLAMA,
-    enable_dynamic_shape: bool = False,
-    calibration_tasks: Optional[List[str]] = None,
-    calibration_limit: Optional[int] = None,
-    calibration_seq_length: Optional[int] = None,
-    calibration_data: Optional[str] = None,
-    tokenizer_path: Optional[str] = None,
-    verbose: bool = False,
-    max_seq_len: int = 128,
-    max_context_len: int = 128,
-    input_prune_map_path: Optional[str] = None,
-    output_prune_map_path: Optional[str] = None,
-    metadata_str: Optional[str] = None,
-    dtype_override: Optional[DType] = None,
-    use_qnn: bool = False,
-    export_only: bool = False,
-) -> "LLMEdgeManager":
+def _load_llama_model(llm_config: LlmConfig) -> "LLMEdgeManager":
     """
     A helper util that builds a Llama2 model. It returns a LLMEdgeManager that
     can help further lower the model to ExecuTorch.
@@ -1220,31 +1175,33 @@ def _load_llama_model(
             llm_config=llm_config,
         )
     )
+    # Convert dtype override string to actual type.
+    dtype_override = DType[llm_config.model.dtype_override]

     return LLMEdgeManager(
         model=model,
         modelname=modelname,
         max_seq_len=model.max_seq_len,  # type: ignore
         dtype=dtype_override,
-        use_kv_cache=use_kv_cache,
-        generate_full_logits=generate_full_logits,
+        use_kv_cache=llm_config.model.use_kv_cache,
+        generate_full_logits=llm_config.debug.generate_full_logits,
         example_inputs=example_inputs,
         example_kwarg_inputs=example_kwarg_inputs,
         dynamic_shapes=dynamic_shapes,
-        enable_dynamic_shape=enable_dynamic_shape,
-        calibration_tasks=calibration_tasks,
-        calibration_limit=calibration_limit,
-        calibration_seq_length=calibration_seq_length,
-        calibration_data=calibration_data,
-        tokenizer_path=tokenizer_path,
-        use_legacy_export=use_qnn,
-        save_exported_program=export_only,
-        verbose=verbose,
+        enable_dynamic_shape=llm_config.model.enable_dynamic_shape,
+        calibration_tasks=llm_config.quantization.calibration_tasks,
+        calibration_limit=llm_config.quantization.calibration_limit,
+        calibration_seq_length=llm_config.quantization.calibration_seq_length,
+        calibration_data=llm_config.quantization.calibration_data,
+        tokenizer_path=llm_config.base.tokenizer_path,
+        use_legacy_export=llm_config.backend.qnn.enabled,
+        save_exported_program=llm_config.export.export_only,
+        verbose=llm_config.debug.verbose,
         metadata=_load_llama_model_metadata(
-            weight_type,
-            use_kv_cache,
-            use_sdpa_with_kv_cache,
-            enable_dynamic_shape,
+            WeightType.FAIRSEQ2 if llm_config.base.fairseq2 else WeightType.LLAMA,
+            llm_config.model.use_kv_cache,
+            llm_config.model.use_sdpa_with_kv_cache,
+            llm_config.model.enable_dynamic_shape,
             # pyre-fixme[6]: For 5th argument expected `ModelArgs` but got
             # `Union[Tensor, Module]`.
             model.max_seq_len,
@@ -1257,7 +1214,7 @@ def _load_llama_model(
             # pyre-fixme[6]: For 8th argument expected `int` but got `Union[Tensor,
             # Module]`.
             model.vocab_size,
-            metadata_str,
+            llm_config.base.metadata,
         ),
     )

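For orientation, a minimal sketch of driving the export entry point after this change. The field names mirror the accesses in the diff above; default-constructing LlmConfig and the checkpoint path are illustrative assumptions, not part of the commit:

    # Minimal usage sketch (assumptions: LlmConfig() is default-constructible;
    # the checkpoint path is illustrative). All per-run options now live on a
    # single LlmConfig object, which _prepare_for_llama_export reads directly
    # instead of receiving ~20 keyword arguments.
    llm_config = LlmConfig()
    llm_config.base.checkpoint = "/path/to/checkpoint.pth"  # assumed path
    llm_config.model.use_kv_cache = True
    llm_config.model.dtype_override = "fp32"  # must name a DType member
    llm_config.export.max_seq_length = 128

    edge_manager = _prepare_for_llama_export(llm_config)  # returns an LLMEdgeManager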