
Commit 209fd7f

Update on "Use llm_config instead of args in export_llama functions"
Differential Revision: [D75484927](https://our.internmc.facebook.com/intern/diff/D75484927) [ghstack-poisoned]
2 parents: d9c70c2 + c88c5de


3 files changed: +69 additions, -81 deletions

backends/arm/test/models/test_llama.py

Lines changed: 5 additions & 1 deletion
@@ -22,6 +22,9 @@
     TosaPipelineMI,
 )
 
+from executorch.examples.models.llama.config.llm_config_utils import (
+    convert_args_to_llm_config,
+)
 from executorch.examples.models.llama.export_llama_lib import (
     build_args_parser,
     get_llama_model,

@@ -89,8 +92,9 @@ def prepare_model(self):
         ]
         parser = build_args_parser()
         args = parser.parse_args(args)
+        llm_config = convert_args_to_llm_config(args)
 
-        llama_model, llama_inputs, llama_meta = get_llama_model(args)
+        llama_model, llama_inputs, llama_meta = get_llama_model(llm_config)
 
         return llama_model, llama_inputs, llama_meta
 
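For reference, a minimal sketch of the calling convention this test now follows: parse the legacy CLI flags, bridge them to an LlmConfig, and pass that to get_llama_model. The flag values below are illustrative placeholders, not the test's actual arguments.

```python
from executorch.examples.models.llama.config.llm_config_utils import (
    convert_args_to_llm_config,
)
from executorch.examples.models.llama.export_llama_lib import (
    build_args_parser,
    get_llama_model,
)

# Parse legacy CLI flags, then convert the Namespace to the new LlmConfig.
args = build_args_parser().parse_args(
    ["--model", "llama3", "--params", "params.json"]  # placeholder flags
)
llm_config = convert_args_to_llm_config(args)

# get_llama_model now takes the LlmConfig rather than the raw argparse.Namespace.
llama_model, llama_inputs, llama_meta = get_llama_model(llm_config)
```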

examples/models/llama/export_llama_lib.py

Lines changed: 15 additions & 38 deletions
@@ -29,6 +29,7 @@
 
 from executorch.devtools.etrecord import generate_etrecord as generate_etrecord_func
 
+from executorch.examples.models.llama.config.llm_config import LlmConfig
 from executorch.examples.models.llama.config.llm_config_utils import (
     convert_args_to_llm_config,
 )
@@ -156,7 +157,8 @@ def build_model(
     argString = f"--model {model} --checkpoint {checkpoint} --params {params} {extra_opts} --output-dir {output_dir}"
     parser = build_args_parser()
     args = parser.parse_args(shlex.split(argString))
-    return export_llama(args)
+    llm_config = convert_args_to_llm_config(args)
+    return export_llama(llm_config)
 
 
 def parse_list_of_ints(s):
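A hedged usage sketch of build_model after this change, assuming its keyword parameters match the names interpolated into argString above; the paths and options are placeholders.

```python
from executorch.examples.models.llama.export_llama_lib import build_model

# build_model assembles a legacy-style flag string, parses it with
# build_args_parser, converts the argparse.Namespace to an LlmConfig via
# convert_args_to_llm_config, and passes that config to export_llama.
pte_path = build_model(
    model="llama3",
    checkpoint="/path/to/consolidated.00.pth",  # placeholder path
    params="/path/to/params.json",              # placeholder path
    extra_opts="",
    output_dir="/tmp",
)
print(pte_path)  # export_llama returns the saved .pte filename
```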
@@ -578,15 +580,10 @@ def export_llama(
 ) -> str:
     if isinstance(export_options, argparse.Namespace):
         # Legacy CLI.
-        args = export_options
         llm_config = convert_args_to_llm_config(export_options)
     elif isinstance(export_options, DictConfig):
         # Hydra CLI.
         llm_config = export_options
-        # Create an args object for backward compatibility during transition
-        args = argparse.Namespace()
-        for key, value in llm_config.items():
-            setattr(args, key, value)
     else:
         raise ValueError(
             "Input to export_llama must be either of type argparse.Namespace or LlmConfig"
@@ -625,7 +622,7 @@ def export_llama(
         from executorch.util.python_profiler import CProfilerFlameGraph
 
         with CProfilerFlameGraph(llm_config.debug.profile_path):
-            builder = _export_llama(llm_config, args)
+            builder = _export_llama(llm_config)
             assert (
                 filename := builder.get_saved_pte_filename()
             ) is not None, "Fail to get file name from builder"
@@ -636,14 +633,14 @@ def export_llama(
             )
             return ""
     else:
-        builder = _export_llama(llm_config, args)
+        builder = _export_llama(llm_config)
         assert (
             filename := builder.get_saved_pte_filename()
         ) is not None, "Fail to get file name from builder"
         return filename
 
 
-def _prepare_for_llama_export(llm_config, args) -> LLMEdgeManager:
+def _prepare_for_llama_export(llm_config: LlmConfig) -> LLMEdgeManager:
     """
     Helper function for export_llama. Loads the model from checkpoint and params,
     and sets up a LLMEdgeManager with initial transforms and dtype conversion.
@@ -671,7 +668,7 @@ def _prepare_for_llama_export(llm_config, args) -> LLMEdgeManager:
     dtype_override = DType[llm_config.model.dtype_override]
 
     edge_manager = _load_llama_model(
-        llm_config.base.model_class,
+        llm_config,
         checkpoint=checkpoint_path,
         checkpoint_dir=checkpoint_dir,
         params_path=params_path,
@@ -694,7 +691,6 @@ def _prepare_for_llama_export(llm_config, args) -> LLMEdgeManager:
         dtype_override=dtype_override,
         use_qnn=llm_config.backend.qnn.enabled,
         export_only=llm_config.export.export_only,
-        args=args,
     )
 
     # At this point, the model is loaded in the default fp32.
@@ -805,10 +801,6 @@ def _qmode_type(value):
 
 
 def _validate_args(llm_config):
-    """
-    TODO: Combine all the backends under --backend args
-    """
-
     if llm_config.export.max_context_length < llm_config.export.max_seq_length:
         raise ValueError(
             f"max_context_length {llm_config.export.max_context_length} must be >= max_seq_len {llm_config.export.max_seq_length}. max_context_length impacts kv cache size that is used to remember history, while max_seq_length refers to user prompt length. Please use --max_context_length to specify context length."
@@ -1057,7 +1049,7 @@ def _to_edge_and_lower_llama(  # noqa: C901
     return builder
 
 
-def _export_llama(llm_config, args) -> LLMEdgeManager:  # noqa: C901
+def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager:  # noqa: C901
     _validate_args(llm_config)
 
     pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(
@@ -1069,7 +1061,7 @@ def _export_llama(llm_config, args) -> LLMEdgeManager:  # noqa: C901
     additional_passes = [InitializedMutableBufferPass(["kv_cache_pos"])]
 
     # export_to_edge
-    builder_exported = _prepare_for_llama_export(llm_config, args).export()
+    builder_exported = _prepare_for_llama_export(llm_config).export()
     builder_exported.run_canonical_optimizations()
     modelname = builder_exported.modelname
 
@@ -1177,7 +1169,7 @@ def _load_llama_model_metadata(
 
 
 def _load_llama_model(
-    modelname: str = "llama3",
+    llm_config: LlmConfig,
     *,
     checkpoint: Optional[str] = None,
     checkpoint_dir: Optional[str] = None,
@@ -1201,7 +1193,6 @@ def _load_llama_model(
     dtype_override: Optional[DType] = None,
     use_qnn: bool = False,
     export_only: bool = False,
-    args,
 ) -> "LLMEdgeManager":
     """
     A helper util that builds a Llama2 model. It returns a LLMEdgeManager that
@@ -1210,6 +1201,7 @@ def _load_llama_model(
         An instance of LLMEdgeManager which contains the eager mode model.
     """
 
+    modelname = llm_config.base.model_class
     if modelname in EXECUTORCH_DEFINED_MODELS:
         module_name = "llama"
         model_class_name = "Llama2Model"  # TODO: Change to "LlamaModel" in examples/models/llama/model.py.
@@ -1222,26 +1214,11 @@ def _load_llama_model(
     else:
         raise ValueError(f"{modelname} is not a valid Llama model.")
 
-    torch_dtype = dtype_override.to_torch_dtype() if dtype_override else None
-
     model, example_inputs, example_kwarg_inputs, dynamic_shapes = (
         EagerModelFactory.create_model(
             module_name,
             model_class_name,
-            checkpoint=checkpoint,
-            checkpoint_dir=checkpoint_dir,
-            params=params_path,
-            use_kv_cache=use_kv_cache,
-            use_sdpa_with_kv_cache=use_sdpa_with_kv_cache,
-            generate_full_logits=generate_full_logits,
-            fairseq2=weight_type == WeightType.FAIRSEQ2,
-            max_seq_len=max_seq_len,
-            max_context_len=max_context_len,
-            enable_dynamic_shape=enable_dynamic_shape,
-            input_prune_map_path=input_prune_map_path,
-            output_prune_map_path=output_prune_map_path,
-            dtype=torch_dtype,
-            args=args,
+            llm_config=llm_config,
         )
     )
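With the whole config passed through EagerModelFactory.create_model, downstream code reads nested LlmConfig fields instead of flat kwargs. A hypothetical helper sketch, using only field paths that appear elsewhere in this diff; the rest of the LlmConfig schema is not assumed here:

```python
from executorch.examples.models.llama.config.llm_config import LlmConfig


def summarize_export_options(llm_config: LlmConfig) -> str:
    # Illustrative only: every field path below is taken from this diff.
    return (
        f"model_class={llm_config.base.model_class}, "
        f"dtype_override={llm_config.model.dtype_override}, "
        f"export_only={llm_config.export.export_only}, "
        f"qnn_enabled={llm_config.backend.qnn.enabled}"
    )
```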

@@ -1498,9 +1475,9 @@ def _get_source_transforms(  # noqa
     return transforms
 
 
-def get_llama_model(args):
-    _validate_args(args)
-    e_mgr = _prepare_for_llama_export(args)
+def get_llama_model(llm_config: LlmConfig):
+    _validate_args(llm_config)
+    e_mgr = _prepare_for_llama_export(llm_config)
     model = (
         e_mgr.model.eval().to(device="cuda")
         if torch.cuda.is_available()
