
Commit 1431167

Back to model name
1 parent 623019e commit 1431167

1 file changed

examples/models/llama/export_llama_lib.py

Lines changed: 18 additions & 21 deletions
@@ -79,7 +79,7 @@
 verbosity_setting = None


-EXECUTORCH_LLAMA = "et_llama"
+EXECUTORCH_DEFINED_MODELS = ["stories110m", "llama2", "llama3", "llama3.1", "llama3.2"]
 TORCHTUNE_DEFINED_MODELS = []


@@ -107,23 +107,18 @@ def verbose_export():


 def build_model(
-    modelname: str = "model",
+    modelname: str,
     extra_opts: str = "",
     *,
     par_local_output: bool = False,
     resource_pkg_name: str = __name__,
-    modelclass: str = EXECUTORCH_LLAMA,
 ) -> str:
-    """
-    Build the model, used for tests. `modelname` arg just specifies
-    where to find the model resource files.
-    """
     if False:  # par_local_output:
         output_dir_path = "par:."
     else:
         output_dir_path = "."

-    argString = f"--modelclass {modelclass} --checkpoint par:{modelname}_ckpt.pt --params par:{modelname}_params.json {extra_opts} --output-dir {output_dir_path}"
+    argString = f"--model {modelname} --checkpoint par:model_ckpt.pt --params par:model_params.json {extra_opts} --output-dir {output_dir_path}"
     parser = build_args_parser()
     args = parser.parse_args(shlex.split(argString))
     # pkg_name = resource_pkg_name
@@ -138,10 +133,10 @@ def build_args_parser() -> argparse.ArgumentParser:
     #     "-q", "--quantized_ckpt", default=None, help="quantized checkpoint file"
     # )
     parser.add_argument(
-        "--modelclass",
-        default=EXECUTORCH_LLAMA,
-        choices=[EXECUTORCH_LLAMA] + TORCHTUNE_DEFINED_MODELS,
-        help='The Llama model architecture to use. "et_llama" is a custom Llama architecture defined in ExecuTorch that supports llama2, llama3, llama3_1, llama3_2. All other modelclasses are from TorchTune.',
+        "--model",
+        default="llama3",
+        choices=EXECUTORCH_DEFINED_MODELS + TORCHTUNE_DEFINED_MODELS,
+        help="The Llama model architecture to use. stories110M, llama2, llama3, llama3_1, and llama3_2 use the same underlying LlamaTransformer architecture defined in ExecuTorch. All other models use TorchTune model definitions.",
     )
     parser.add_argument(
         "-E",
@@ -530,7 +525,7 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:

     return (
         _load_llama_model(
-            args.modelclass,
+            args.model,
             checkpoint=checkpoint_path,
             checkpoint_dir=checkpoint_dir,
             params_path=params_path,
@@ -553,7 +548,7 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
             args=args,
         )
         .set_output_dir(output_dir_path)
-        .source_transform(_get_source_transforms(dtype_override, args))
+        .source_transform(_get_source_transforms(args.model, dtype_override, args))
     )


@@ -771,7 +766,7 @@ def _load_llama_model_metadata(


 def _load_llama_model(
-    modelclass: str = EXECUTORCH_LLAMA,
+    modelname: str = "llama3",
     *,
     checkpoint: Optional[str] = None,
     checkpoint_dir: Optional[str] = None,
@@ -808,15 +803,15 @@ def _load_llama_model(
         f"Loading model with checkpoint={checkpoint}, params={params_path}, use_kv_cache={use_kv_cache}, weight_type={weight_type}"
     )

-    if modelclass == EXECUTORCH_LLAMA:
+    if modelname in EXECUTORCH_DEFINED_MODELS:
         module_name = "llama"
         model_class_name = "Llama2Model"  # TODO: Change to "LlamaModel" in examples/models/llama/model.py.
-    elif modelclass in TORCHTUNE_DEFINED_MODELS:
+    elif modelname in TORCHTUNE_DEFINED_MODELS:
         raise NotImplementedError(
             "Torchtune Llama models are not yet supported in ExecuTorch export."
         )
     else:
-        raise ValueError(f"{modelclass} is not a valid Llama model.")
+        raise ValueError(f"{modelname} is not a valid Llama model.")

     model, example_inputs, example_kwarg_inputs, _ = EagerModelFactory.create_model(
         module_name,
@@ -863,7 +858,7 @@ def _load_llama_model(

     return LLMEdgeManager(
         model=model,
-        modelname=modelclass,
+        modelname=modelname,
         max_seq_len=model.params.max_seq_len,
         dtype=dtype,
         use_kv_cache=use_kv_cache,
@@ -890,7 +885,7 @@ def _load_llama_model(


 def _get_source_transforms(  # noqa
-    dtype_override: Optional[DType], args
+    modelname: str, dtype_override: Optional[DType], args
 ) -> List[Callable[[torch.nn.Module], torch.nn.Module]]:
     transforms = []

@@ -920,8 +915,9 @@ def _get_source_transforms(  # noqa
         ops that is not quantized.

         There are cases where this may be a no-op, namely, if all linears are
         quantized in the checkpoint.
         """
+        modelname = f"{modelname}_q"
         transforms.append(
             get_quant_weight_transform(args, dtype_override, verbose_export())
         )
@@ -936,6 +932,7 @@ def _get_source_transforms(  # noqa
         transformations based on the given checkpoint first. In those cases,
         this will be a no-op.
         """
+        modelname = f"{modelname}_e"
         transforms.append(get_quant_embedding_transform(args))

     if args.expand_rope_table:
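
For orientation, here is a minimal usage sketch of the renamed API after this commit. It is an illustration, not part of the commit: the import path simply mirrors the file's location, and the par:model_ckpt.pt / par:model_params.json resources that build_model() resolves are assumed to exist.

# Hypothetical sketch: exercising build_model() after the rename from
# --modelclass back to --model. modelname is now required and must name
# a concrete model rather than an architecture class ("et_llama" is gone).
from examples.models.llama.export_llama_lib import (
    EXECUTORCH_DEFINED_MODELS,
    build_model,
)

# Valid names are EXECUTORCH_DEFINED_MODELS + TORCHTUNE_DEFINED_MODELS.
assert "llama3" in EXECUTORCH_DEFINED_MODELS

# Internally this assembles "--model llama3 --checkpoint par:model_ckpt.pt ..."
# and routes it through build_args_parser(), so args.model is what
# _load_llama_model() and _get_source_transforms() ultimately receive.
output_file = build_model(modelname="llama3")
print(output_file)

As the last two hunks show, _get_source_transforms() also receives the model name and appends a _q or _e suffix to it when the weight or embedding quantization transforms are applied.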
