 verbosity_setting = None
 
 
-EXECUTORCH_LLAMA = "et_llama"
+EXECUTORCH_DEFINED_MODELS = ["stories110m", "llama2", "llama3", "llama3.1", "llama3.2"]
 TORCHTUNE_DEFINED_MODELS = []
 
 
@@ -107,23 +107,18 @@ def verbose_export():
 
 
 def build_model(
-    modelname: str = "model",
+    modelname: str,
     extra_opts: str = "",
     *,
     par_local_output: bool = False,
     resource_pkg_name: str = __name__,
-    modelclass: str = EXECUTORCH_LLAMA,
 ) -> str:
-    """
-    Build the model, used for tests. `modelname` arg just specifies
-    where to find the model resource files.
-    """
     if False:  # par_local_output:
         output_dir_path = "par:."
     else:
         output_dir_path = "."
 
-    argString = f"--modelclass {modelclass} --checkpoint par:{modelname}_ckpt.pt --params par:{modelname}_params.json {extra_opts} --output-dir {output_dir_path}"
+    argString = f"--modelname {modelname} --checkpoint par:model_ckpt.pt --params par:model_params.json {extra_opts} --output-dir {output_dir_path}"
     parser = build_args_parser()
     args = parser.parse_args(shlex.split(argString))
     # pkg_name = resource_pkg_name
@@ -138,10 +133,10 @@ def build_args_parser() -> argparse.ArgumentParser:
     # "-q", "--quantized_ckpt", default=None, help="quantized checkpoint file"
     # )
     parser.add_argument(
-        "--modelclass",
-        default=EXECUTORCH_LLAMA,
-        choices=[EXECUTORCH_LLAMA] + TORCHTUNE_DEFINED_MODELS,
-        help='The Lllama model architecture to use. "et_llama" is a custom Llama architecture defined in ExecuTorch that supports llama2, llama3, llama3_1, llama3_2. All other modelclasses are from TorchTune.',
+        "--model",
+        default="llama3",
+        choices=EXECUTORCH_DEFINED_MODELS + TORCHTUNE_DEFINED_MODELS,
+        help="The Llama model architecture to use. stories110m, llama2, llama3, llama3.1, and llama3.2 use the same underlying LlamaTransformer architecture defined in ExecuTorch. All other models use TorchTune model definitions.",
     )
     parser.add_argument(
         "-E",
@@ -530,7 +525,7 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
 
     return (
         _load_llama_model(
-            args.modelclass,
+            args.model,
             checkpoint=checkpoint_path,
             checkpoint_dir=checkpoint_dir,
             params_path=params_path,
@@ -553,7 +548,7 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
             args=args,
         )
         .set_output_dir(output_dir_path)
-        .source_transform(_get_source_transforms(dtype_override, args))
+        .source_transform(_get_source_transforms(args.model, dtype_override, args))
     )
 
 
@@ -771,7 +766,7 @@ def _load_llama_model_metadata(
 
 
 def _load_llama_model(
-    modelclass: str = EXECUTORCH_LLAMA,
+    modelname: str = "llama3",
     *,
     checkpoint: Optional[str] = None,
     checkpoint_dir: Optional[str] = None,
@@ -808,15 +803,15 @@ def _load_llama_model(
         f"Loading model with checkpoint={checkpoint}, params={params_path}, use_kv_cache={use_kv_cache}, weight_type={weight_type}"
     )
 
-    if modelclass == EXECUTORCH_LLAMA:
+    if modelname in EXECUTORCH_DEFINED_MODELS:
         module_name = "llama"
         model_class_name = "Llama2Model"  # TODO: Change to "LlamaModel" in examples/models/llama/model.py.
-    elif modelclass in TORCHTUNE_DEFINED_MODELS:
+    elif modelname in TORCHTUNE_DEFINED_MODELS:
         raise NotImplementedError(
             "Torchtune Llama models are not yet supported in ExecuTorch export."
         )
     else:
-        raise ValueError(f"{modelclass} is not a valid Llama model.")
+        raise ValueError(f"{modelname} is not a valid Llama model.")
 
     model, example_inputs, example_kwarg_inputs, _ = EagerModelFactory.create_model(
         module_name,
@@ -863,7 +858,7 @@ def _load_llama_model(
 
     return LLMEdgeManager(
         model=model,
-        modelname=modelclass,
+        modelname=modelname,
         max_seq_len=model.params.max_seq_len,
         dtype=dtype,
         use_kv_cache=use_kv_cache,
@@ -890,7 +885,7 @@ def _load_llama_model(
 
 
 def _get_source_transforms(  # noqa
-    dtype_override: Optional[DType], args
+    modelname: str, dtype_override: Optional[DType], args
 ) -> List[Callable[[torch.nn.Module], torch.nn.Module]]:
     transforms = []
 
@@ -920,8 +915,9 @@ def _get_source_transforms(  # noqa
         ops that is not quantized.
 
         There are cases where this may be a no-op, namely, if all linears are
         quantized in the checkpoint.
         """
+        modelname = f"{modelname}_q"
         transforms.append(
             get_quant_weight_transform(args, dtype_override, verbose_export())
         )
@@ -936,6 +932,7 @@ def _get_source_transforms(  # noqa
         transformations based on the given checkpoint first. In those cases,
         this will be a no-op.
         """
+        modelname = f"{modelname}_e"
         transforms.append(get_quant_embedding_transform(args))
 
     if args.expand_rope_table:
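
For context: after this change, callers select the architecture by name through the new `--model` flag, validated against `EXECUTORCH_DEFINED_MODELS + TORCHTUNE_DEFINED_MODELS`. A minimal sketch of exercising the renamed flag, reusing `build_args_parser` from this file; the checkpoint and params paths are placeholders, not real files:

```python
import shlex

# Parse a hypothetical command line against the updated parser.
parser = build_args_parser()
args = parser.parse_args(
    shlex.split("--model llama3.1 --checkpoint model_ckpt.pt --params model_params.json")
)
assert args.model == "llama3.1"  # accepted: listed in EXECUTORCH_DEFINED_MODELS
```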
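The dispatch in `_load_llama_model` now branches on list membership rather than comparing against a single sentinel. A condensed sketch of that logic, with `resolve_model_class` as a hypothetical helper name (the real code inlines these branches in `_load_llama_model`):

```python
from typing import List, Tuple

EXECUTORCH_DEFINED_MODELS: List[str] = ["stories110m", "llama2", "llama3", "llama3.1", "llama3.2"]
TORCHTUNE_DEFINED_MODELS: List[str] = []

def resolve_model_class(modelname: str) -> Tuple[str, str]:
    """Map a model name to (module_name, model_class_name)."""
    if modelname in EXECUTORCH_DEFINED_MODELS:
        # Every ExecuTorch-defined variant shares the same Llama implementation.
        return "llama", "Llama2Model"
    if modelname in TORCHTUNE_DEFINED_MODELS:
        raise NotImplementedError(
            "Torchtune Llama models are not yet supported in ExecuTorch export."
        )
    raise ValueError(f"{modelname} is not a valid Llama model.")
```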