diff --git a/backends/arm/test/models/test_llama.py b/backends/arm/test/models/test_llama.py
index d0a18d88b9d..1031768ec0b 100644
--- a/backends/arm/test/models/test_llama.py
+++ b/backends/arm/test/models/test_llama.py
@@ -22,6 +22,7 @@
     TosaPipelineMI,
 )
 
+from executorch.examples.models.llama.config.llm_config_utils import convert_args_to_llm_config
 from executorch.examples.models.llama.export_llama_lib import (
     build_args_parser,
     get_llama_model,
@@ -89,8 +90,9 @@ def prepare_model(self):
         ]
         parser = build_args_parser()
         args = parser.parse_args(args)
+        llm_config = convert_args_to_llm_config(args)
 
-        llama_model, llama_inputs, llama_meta = get_llama_model(args)
+        llama_model, llama_inputs, llama_meta = get_llama_model(llm_config)
 
         return llama_model, llama_inputs, llama_meta
 
diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py
index 427fc8e8a74..4f5631b6159 100644
--- a/examples/models/llama/export_llama_lib.py
+++ b/examples/models/llama/export_llama_lib.py
@@ -805,10 +805,6 @@ def _qmode_type(value):
 
 
 def _validate_args(llm_config):
-    """
-    TODO: Combine all the backends under --backend args
-    """
-
     if llm_config.export.max_context_length < llm_config.export.max_seq_length:
         raise ValueError(
             f"max_context_length {llm_config.export.max_context_length} must be >= max_seq_len {llm_config.export.max_seq_length}. max_context_length impacts kv cache size that is used to remember history, while max_seq_length refers to user prompt length. Please use --max_context_length to specify context length."
@@ -1498,9 +1494,9 @@ def _get_source_transforms(  # noqa
     return transforms
 
 
-def get_llama_model(args):
-    _validate_args(args)
-    e_mgr = _prepare_for_llama_export(args)
+def get_llama_model(llm_config: LlmConfig):
+    _validate_args(llm_config)
+    e_mgr = _prepare_for_llama_export(llm_config)
     model = (
         e_mgr.model.eval().to(device="cuda")
         if torch.cuda.is_available()
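
For context, a minimal sketch (not part of the diff) of the call pattern this change introduces: callers now convert the parsed argparse namespace into an LlmConfig via convert_args_to_llm_config before calling get_llama_model, which takes the config instead of the raw args. The load_llama helper below is hypothetical; the imports and the three-element return value mirror what the updated test_llama.py above uses.

from executorch.examples.models.llama.config.llm_config_utils import (
    convert_args_to_llm_config,
)
from executorch.examples.models.llama.export_llama_lib import (
    build_args_parser,
    get_llama_model,
)


def load_llama(cli_args):
    # Hypothetical helper: parse CLI-style args exactly as the test does.
    parser = build_args_parser()
    args = parser.parse_args(cli_args)
    # New in this change: get_llama_model() expects an LlmConfig, so the
    # argparse namespace is converted first instead of being passed through.
    llm_config = convert_args_to_llm_config(args)
    # Returns (model, example_inputs, metadata), as in prepare_model() above.
    return get_llama_model(llm_config)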