@@ -123,26 +123,19 @@ def verbose_export():
123123
124124
125125def build_model (
126- modelname : str = "llama3" ,
127- extra_opts : str = "" ,
128- * ,
129- par_local_output : bool = False ,
130- resource_pkg_name : str = __name__ ,
126+ model : str ,
127+ checkpoint : str ,
128+ params : str ,
129+ output_dir : Optional [ str ] = "." ,
130+ extra_opts : Optional [ str ] = "" ,
131131) -> str :
132- if False : # par_local_output:
133- output_dir_path = "par:."
134- else :
135- output_dir_path = "."
136-
137- argString = f"--model { modelname } --checkpoint par:model_ckpt.pt --params par:model_params.json { extra_opts } --output-dir { output_dir_path } "
132+ argString = f"--model { model } --checkpoint { checkpoint } --params { params } { extra_opts } --output-dir { output_dir } "
138133 parser = build_args_parser ()
139134 args = parser .parse_args (shlex .split (argString ))
140- # pkg_name = resource_pkg_name
141135 return export_llama (args )
142136
143137
144138def build_args_parser () -> argparse .ArgumentParser :
145- ckpt_dir = f"{ Path (__file__ ).absolute ().parent .as_posix ()} "
146139 parser = argparse .ArgumentParser ()
147140 parser .add_argument ("-o" , "--output-dir" , default = "." , help = "output directory" )
148141 # parser.add_argument(
@@ -191,8 +184,8 @@ def build_args_parser() -> argparse.ArgumentParser:
191184 parser .add_argument (
192185 "-c" ,
193186 "--checkpoint" ,
194- default = f" { ckpt_dir } /params/demo_rand_params.pth" ,
195- help = "checkpoint path " ,
187+ required = False ,
188+ help = "Path to the checkpoint .pth file. When not provided, the model will be initialized with random weights. " ,
196189 )
197190
198191 parser .add_argument (
@@ -273,8 +266,8 @@ def build_args_parser() -> argparse.ArgumentParser:
273266 parser .add_argument (
274267 "-p" ,
275268 "--params" ,
276- default = f" { ckpt_dir } /params/demo_config.json" ,
277- help = "config.json " ,
269+ required = False ,
270+ help = "Config file for model parameters. When not provided, the model will fallback on default values defined in examples/models/llama/model_args.py. " ,
278271 )
279272 parser .add_argument (
280273 "--optimized_rotation_path" ,
@@ -561,7 +554,7 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
561554 checkpoint_dir = (
562555 canonical_path (args .checkpoint_dir ) if args .checkpoint_dir else None
563556 )
564- params_path = canonical_path (args .params )
557+ params_path = canonical_path (args .params ) if args . params else None
565558 output_dir_path = canonical_path (args .output_dir , dir = True )
566559 weight_type = WeightType .FAIRSEQ2 if args .fairseq2 else WeightType .LLAMA
567560
@@ -960,7 +953,7 @@ def _load_llama_model(
960953 * ,
961954 checkpoint : Optional [str ] = None ,
962955 checkpoint_dir : Optional [str ] = None ,
963- params_path : str ,
956+ params_path : Optional [ str ] = None ,
964957 use_kv_cache : bool = False ,
965958 use_sdpa_with_kv_cache : bool = False ,
966959 generate_full_logits : bool = False ,
@@ -987,13 +980,6 @@ def _load_llama_model(
987980 An instance of LLMEdgeManager which contains the eager mode model.
988981 """
989982
990- assert (
991- checkpoint or checkpoint_dir
992- ) and params_path , "Both checkpoint/checkpoint_dir and params can't be empty"
993- logging .info (
994- f"Loading model with checkpoint={ checkpoint } , params={ params_path } , use_kv_cache={ use_kv_cache } , weight_type={ weight_type } "
995- )
996-
997983 if modelname in EXECUTORCH_DEFINED_MODELS :
998984 module_name = "llama"
999985 model_class_name = "Llama2Model" # TODO: Change to "LlamaModel" in examples/models/llama/model.py.
0 commit comments