 verbosity_setting = None
 
 
+EXECUTORCH_DEFINED_MODELS = ["stories110m", "llama2", "llama3", "llama3_1", "llama3_2"]
+TORCHTUNE_DEFINED_MODELS = []
+
+
 class WeightType(Enum):
     LLAMA = "LLAMA"
     FAIRSEQ2 = "FAIRSEQ2"
@@ -105,7 +109,7 @@ def verbose_export():
 
 
 def build_model(
-    modelname: str = "model",
+    modelname: str = "llama3",
     extra_opts: str = "",
     *,
     par_local_output: bool = False,
@@ -116,11 +120,11 @@ def build_model(
     else:
         output_dir_path = "."
 
-    argString = f"--checkpoint par:{modelname}_ckpt.pt --params par:{modelname}_params.json {extra_opts} --output-dir {output_dir_path}"
+    argString = f"--model {modelname} --checkpoint par:model_ckpt.pt --params par:model_params.json {extra_opts} --output-dir {output_dir_path}"
     parser = build_args_parser()
     args = parser.parse_args(shlex.split(argString))
     # pkg_name = resource_pkg_name
-    return export_llama(modelname, args)
+    return export_llama(args)
 
 
 def build_args_parser() -> argparse.ArgumentParser:
@@ -130,6 +134,12 @@ def build_args_parser() -> argparse.ArgumentParser:
     # parser.add_argument(
     #     "-q", "--quantized_ckpt", default=None, help="quantized checkpoint file"
     # )
+    parser.add_argument(
+        "--model",
+        default="llama3",
+        choices=EXECUTORCH_DEFINED_MODELS + TORCHTUNE_DEFINED_MODELS,
+        help="The Llama model architecture to use. stories110M, llama2, llama3, llama3_1, and llama3_2 use the same underlying LlamaTransformer architecture defined in ExecuTorch. All other models use TorchTune model definitions.",
+    )
     parser.add_argument(
         "-E",
         "--embedding-quantize",
@@ -480,13 +490,13 @@ def canonical_path(path: Union[str, Path], *, dir: bool = False) -> str:
     return return_val
 
 
-def export_llama(modelname, args) -> str:
+def export_llama(args) -> str:
     if args.profile_path is not None:
         try:
             from executorch.util.python_profiler import CProfilerFlameGraph
 
             with CProfilerFlameGraph(args.profile_path):
-                builder = _export_llama(modelname, args)
+                builder = _export_llama(args)
                 assert (
                     filename := builder.get_saved_pte_filename()
                 ) is not None, "Fail to get file name from builder"
@@ -497,14 +507,14 @@ def export_llama(modelname, args) -> str:
             )
             return ""
     else:
-        builder = _export_llama(modelname, args)
+        builder = _export_llama(args)
         assert (
             filename := builder.get_saved_pte_filename()
         ) is not None, "Fail to get file name from builder"
         return filename
 
 
-def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
+def _prepare_for_llama_export(args) -> LLMEdgeManager:
     """
     Helper function for export_llama. Loads the model from checkpoint and params,
     and sets up a LLMEdgeManager with initial transforms and dtype conversion.
@@ -530,7 +540,7 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
 
     return (
         _load_llama_model(
-            modelname=modelname,
+            args.model,
             checkpoint=checkpoint_path,
             checkpoint_dir=checkpoint_dir,
             params_path=params_path,
@@ -553,7 +563,7 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
             args=args,
         )
         .set_output_dir(output_dir_path)
-        .source_transform(_get_source_transforms(modelname, dtype_override, args))
+        .source_transform(_get_source_transforms(args.model, dtype_override, args))
     )
 
 
@@ -627,12 +637,12 @@ def _validate_args(args):
         )
 
 
-def _export_llama(modelname, args) -> LLMEdgeManager:  # noqa: C901
+def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
     _validate_args(args)
     pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)
 
     # export_to_edge
-    builder_exported = _prepare_for_llama_export(modelname, args).export()
+    builder_exported = _prepare_for_llama_export(args).export()
 
     if args.export_only:
         exit()
@@ -830,8 +840,8 @@ def _load_llama_model_metadata(
 
 
 def _load_llama_model(
+    modelname: str = "llama3",
     *,
-    modelname: str = "llama2",
     checkpoint: Optional[str] = None,
     checkpoint_dir: Optional[str] = None,
     params_path: str,
@@ -859,15 +869,27 @@ def _load_llama_model(
     Returns:
         An instance of LLMEdgeManager which contains the eager mode model.
     """
+
     assert (
         checkpoint or checkpoint_dir
     ) and params_path, "Both checkpoint/checkpoint_dir and params can't be empty"
     logging.info(
         f"Loading model with checkpoint={checkpoint}, params={params_path}, use_kv_cache={use_kv_cache}, weight_type={weight_type}"
     )
+
+    if modelname in EXECUTORCH_DEFINED_MODELS:
+        module_name = "llama"
+        model_class_name = "Llama2Model"  # TODO: Change to "LlamaModel" in examples/models/llama/model.py.
+    elif modelname in TORCHTUNE_DEFINED_MODELS:
+        raise NotImplementedError(
+            "Torchtune Llama models are not yet supported in ExecuTorch export."
+        )
+    else:
+        raise ValueError(f"{modelname} is not a valid Llama model.")
+
     model, example_inputs, example_kwarg_inputs, _ = EagerModelFactory.create_model(
-        module_name="llama",
-        model_class_name="Llama2Model",
+        module_name,
+        model_class_name,
         checkpoint=checkpoint,
         checkpoint_dir=checkpoint_dir,
         params=params_path,
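
Taken together, the name-based dispatch added in _load_llama_model is the piece that turns the --model string into a concrete model source. As a loose, standalone illustration only (a restatement of the branch added above, not code from this commit):

    # Hypothetical sketch mirroring the added branch in _load_llama_model.
    def resolve_model_source(modelname: str) -> tuple:
        if modelname in EXECUTORCH_DEFINED_MODELS:
            # Every ExecuTorch-defined name maps to the shared LlamaTransformer definition.
            return ("llama", "Llama2Model")
        if modelname in TORCHTUNE_DEFINED_MODELS:
            raise NotImplementedError("Torchtune Llama models are not yet supported in ExecuTorch export.")
        raise ValueError(f"{modelname} is not a valid Llama model.")

    resolve_model_source("stories110m")  # -> ("llama", "Llama2Model"); unknown names raise ValueError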