 
 from executorch.devtools.etrecord import generate_etrecord as generate_etrecord_func
 
+from executorch.examples.models.llama.config.llm_config import LlmConfig
 from executorch.examples.models.llama.config.llm_config_utils import (
     convert_args_to_llm_config,
 )
@@ -156,7 +157,8 @@ def build_model(
     argString = f"--model {model} --checkpoint {checkpoint} --params {params} {extra_opts} --output-dir {output_dir}"
     parser = build_args_parser()
     args = parser.parse_args(shlex.split(argString))
-    return export_llama(args)
+    llm_config = convert_args_to_llm_config(args)
+    return export_llama(llm_config)
 
 
 def parse_list_of_ints(s):
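A note for callers: the legacy argparse flow now passes through `convert_args_to_llm_config` before reaching `export_llama`. A minimal sketch of the equivalent call sequence (flag values are placeholders, not a working configuration):

```python
import shlex

# Placeholders stand in for real checkpoint/params paths.
parser = build_args_parser()
args = parser.parse_args(
    shlex.split("--model llama3 --checkpoint ckpt.pt --params params.json")
)
llm_config = convert_args_to_llm_config(args)  # argparse.Namespace -> LlmConfig
pte_path = export_llama(llm_config)  # returns the saved .pte filename
```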
@@ -578,15 +580,10 @@ def export_llama(
 ) -> str:
     if isinstance(export_options, argparse.Namespace):
         # Legacy CLI.
-        args = export_options
         llm_config = convert_args_to_llm_config(export_options)
     elif isinstance(export_options, DictConfig):
         # Hydra CLI.
         llm_config = export_options
-        # Create an args object for backward compatibility during transition
-        args = argparse.Namespace()
-        for key, value in llm_config.items():
-            setattr(args, key, value)
     else:
         raise ValueError(
             "Input to export_llama must be either of type argparse.Namespace or LlmConfig"
@@ -625,7 +622,7 @@ def export_llama(
         from executorch.util.python_profiler import CProfilerFlameGraph
 
         with CProfilerFlameGraph(llm_config.debug.profile_path):
-            builder = _export_llama(llm_config, args)
+            builder = _export_llama(llm_config)
             assert (
                 filename := builder.get_saved_pte_filename()
             ) is not None, "Fail to get file name from builder"
@@ -636,14 +633,14 @@ def export_llama(
         )
         return ""
     else:
-        builder = _export_llama(llm_config, args)
+        builder = _export_llama(llm_config)
         assert (
             filename := builder.get_saved_pte_filename()
         ) is not None, "Fail to get file name from builder"
         return filename
 
 
-def _prepare_for_llama_export(llm_config, args) -> LLMEdgeManager:
+def _prepare_for_llama_export(llm_config: LlmConfig) -> LLMEdgeManager:
     """
     Helper function for export_llama. Loads the model from checkpoint and params,
     and sets up a LLMEdgeManager with initial transforms and dtype conversion.
@@ -671,7 +668,7 @@ def _prepare_for_llama_export(llm_config, args) -> LLMEdgeManager:
     dtype_override = DType[llm_config.model.dtype_override]
 
     edge_manager = _load_llama_model(
-        llm_config.base.model_class,
+        llm_config,
         checkpoint=checkpoint_path,
         checkpoint_dir=checkpoint_dir,
         params_path=params_path,
@@ -694,7 +691,6 @@ def _prepare_for_llama_export(llm_config, args) -> LLMEdgeManager:
         dtype_override=dtype_override,
         use_qnn=llm_config.backend.qnn.enabled,
         export_only=llm_config.export.export_only,
-        args=args,
     )
 
     # At this point, the model is loaded in the default fp32.
@@ -805,10 +801,6 @@ def _qmode_type(value):
 
 
 def _validate_args(llm_config):
-    """
-    TODO: Combine all the backends under --backend args
-    """
-
     if llm_config.export.max_context_length < llm_config.export.max_seq_length:
         raise ValueError(
             f"max_context_length {llm_config.export.max_context_length} must be >= max_seq_len {llm_config.export.max_seq_length}. max_context_length impacts kv cache size that is used to remember history, while max_seq_length refers to user prompt length. Please use --max_context_length to specify context length."
@@ -1057,7 +1049,7 @@ def _to_edge_and_lower_llama( # noqa: C901
     return builder
 
 
-def _export_llama(llm_config, args) -> LLMEdgeManager:  # noqa: C901
+def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager:  # noqa: C901
     _validate_args(llm_config)
 
     pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(
@@ -1069,7 +1061,7 @@ def _export_llama(llm_config, args) -> LLMEdgeManager: # noqa: C901
         additional_passes = [InitializedMutableBufferPass(["kv_cache_pos"])]
 
     # export_to_edge
-    builder_exported = _prepare_for_llama_export(llm_config, args).export()
+    builder_exported = _prepare_for_llama_export(llm_config).export()
     builder_exported.run_canonical_optimizations()
     modelname = builder_exported.modelname
 
@@ -1177,7 +1169,7 @@ def _load_llama_model_metadata(
 
 
 def _load_llama_model(
-    modelname: str = "llama3",
+    llm_config: LlmConfig,
     *,
     checkpoint: Optional[str] = None,
     checkpoint_dir: Optional[str] = None,
@@ -1201,7 +1193,6 @@ def _load_llama_model(
     dtype_override: Optional[DType] = None,
     use_qnn: bool = False,
     export_only: bool = False,
-    args,
 ) -> "LLMEdgeManager":
     """
     A helper util that builds a Llama2 model. It returns a LLMEdgeManager that
@@ -1210,6 +1201,7 @@ def _load_llama_model(
         An instance of LLMEdgeManager which contains the eager mode model.
     """
 
+    modelname = llm_config.base.model_class
     if modelname in EXECUTORCH_DEFINED_MODELS:
         module_name = "llama"
         model_class_name = "Llama2Model"  # TODO: Change to "LlamaModel" in examples/models/llama/model.py.
@@ -1222,26 +1214,11 @@ def _load_llama_model(
     else:
         raise ValueError(f"{modelname} is not a valid Llama model.")
 
-    torch_dtype = dtype_override.to_torch_dtype() if dtype_override else None
-
     model, example_inputs, example_kwarg_inputs, dynamic_shapes = (
         EagerModelFactory.create_model(
             module_name,
             model_class_name,
-            checkpoint=checkpoint,
-            checkpoint_dir=checkpoint_dir,
-            params=params_path,
-            use_kv_cache=use_kv_cache,
-            use_sdpa_with_kv_cache=use_sdpa_with_kv_cache,
-            generate_full_logits=generate_full_logits,
-            fairseq2=weight_type == WeightType.FAIRSEQ2,
-            max_seq_len=max_seq_len,
-            max_context_len=max_context_len,
-            enable_dynamic_shape=enable_dynamic_shape,
-            input_prune_map_path=input_prune_map_path,
-            output_prune_map_path=output_prune_map_path,
-            dtype=torch_dtype,
-            args=args,
+            llm_config=llm_config,
         )
     )
 
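The factory call collapses fourteen loose kwargs into a single `llm_config` (the eager-side `torch_dtype` plumbing goes away with them). A sketch of how a receiving model class might read the bundle; the class is hypothetical, but the field paths match ones used elsewhere in this diff:

```python
class IllustrativeModel:
    """Hypothetical consumer of the bundled config, not the real Llama2Model."""

    def __init__(self, llm_config: LlmConfig):
        # Former loose kwargs become attribute reads on one object.
        self.model_class = llm_config.base.model_class
        self.max_seq_len = llm_config.export.max_seq_length
        self.export_only = llm_config.export.export_only
```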
@@ -1498,9 +1475,9 @@ def _get_source_transforms( # noqa
     return transforms
 
 
-def get_llama_model(args):
-    _validate_args(args)
-    e_mgr = _prepare_for_llama_export(args)
+def get_llama_model(llm_config: LlmConfig):
+    _validate_args(llm_config)
+    e_mgr = _prepare_for_llama_export(llm_config)
     model = (
         e_mgr.model.eval().to(device="cuda")
         if torch.cuda.is_available()
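Callers of `get_llama_model` need the same migration as `build_model` above; a minimal sketch with placeholder flags (the function's return shape is truncated in this hunk, so it is left unnamed here):

```python
import shlex

args = build_args_parser().parse_args(
    shlex.split("--model llama3 --params params.json")
)
llm_config = convert_args_to_llm_config(args)
result = get_llama_model(llm_config)  # eager model, placed on CUDA when available
```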