Merged

Changes from 4 commits
4 changes: 3 additions & 1 deletion backends/arm/test/models/test_llama.py

@@ -22,6 +22,7 @@
     TosaPipelineMI,
 )
 
+from executorch.examples.models.llama.config.llm_config import LlmConfig
 from executorch.examples.models.llama.export_llama_lib import (
     build_args_parser,
     get_llama_model,
@@ -89,8 +90,9 @@ def prepare_model(self):
         ]
         parser = build_args_parser()
         args = parser.parse_args(args)
+        llm_config = LlmConfig.from_args(args)
 
-        llama_model, llama_inputs, llama_meta = get_llama_model(args)
+        llama_model, llama_inputs, llama_meta = get_llama_model(llm_config)
 
         return llama_model, llama_inputs, llama_meta
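The test now mirrors the exporter's new entry point: CLI-style arguments are parsed as before, then folded once into a typed LlmConfig, which is what get_llama_model accepts after this change. A minimal sketch of the flow, using only names confirmed by this diff (cli_args below is a placeholder; the real test assembles its argument list from model fixtures):

    from executorch.examples.models.llama.config.llm_config import LlmConfig
    from executorch.examples.models.llama.export_llama_lib import (
        build_args_parser,
        get_llama_model,
    )

    cli_args = []  # placeholder: the test fills in exporter flags here

    # Parse the familiar argparse namespace, then convert it once into the
    # typed config object the export library now consumes.
    parser = build_args_parser()
    args = parser.parse_args(cli_args)
    llm_config = LlmConfig.from_args(args)

    llama_model, llama_inputs, llama_meta = get_llama_model(llm_config)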
31 changes: 14 additions & 17 deletions examples/apple/mps/scripts/mps_example.py

@@ -20,6 +20,7 @@
     serialize_from_bundled_program_to_flatbuffer,
 )
 
+from executorch.examples.models.llama.config.llm_config import LlmConfig
 from executorch.exir import (
     EdgeCompileConfig,
     EdgeProgramManager,
@@ -131,28 +132,24 @@ def parse_args():
     return args
 
 
-def get_model_config(args):
-    model_config = {}
-    model_config["module_name"] = MODEL_NAME_TO_MODEL[args.model_name][0]
-    model_config["model_class_name"] = MODEL_NAME_TO_MODEL[args.model_name][1]
-
-    if args.model_name == "llama2":
-        if args.checkpoint:
-            model_config["checkpoint"] = args.checkpoint
-        if args.params:
-            model_config["params"] = args.params
-        model_config["use_kv_cache"] = True
-    return model_config
-
-
-if __name__ == "__main__":
+if __name__ == "__main__":  # noqa: C901
     args = parse_args()
 
     if args.model_name not in MODEL_NAME_TO_MODEL:
         raise RuntimeError(f"Available models are {list(MODEL_NAME_TO_MODEL.keys())}.")
 
-    model_config = get_model_config(args)
-    model, example_inputs, _, _ = EagerModelFactory.create_model(**model_config)
+    llm_config = LlmConfig()
+    if args.model_name == "llama2":
+        if args.checkpoint:
+            llm_config.base.checkpoint = args.checkpoint
+        if args.params:
+            llm_config.base.params = args.params
+        llm_config.model.use_kv_cache = True
+    model, example_inputs, _, _ = EagerModelFactory.create_model(
+        module_name=MODEL_NAME_TO_MODEL[args.model_name][0],
+        model_class_name=MODEL_NAME_TO_MODEL[args.model_name][1],
+        llm_config=llm_config,
+    )
 
     model = model.eval()
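The net effect in this file is swapping an untyped kwargs dict (the deleted get_model_config helper) for explicit fields on LlmConfig. A sketch of the new call shape, using only attribute paths visible in this diff (the checkpoint and params paths are hypothetical placeholders):

    llm_config = LlmConfig()
    llm_config.base.checkpoint = "/path/to/llama2/checkpoint.pth"  # hypothetical path
    llm_config.base.params = "/path/to/llama2/params.json"  # hypothetical path
    llm_config.model.use_kv_cache = True

    # The factory now receives module/class names and the typed config
    # explicitly, instead of a **model_config dict unpacked blindly.
    model, example_inputs, _, _ = EagerModelFactory.create_model(
        module_name=MODEL_NAME_TO_MODEL["llama2"][0],
        model_class_name=MODEL_NAME_TO_MODEL["llama2"][1],
        llm_config=llm_config,
    )

Compared with the old string-keyed dict, the typed config gives one discoverable schema for the exporter options rather than keys scattered across call sites.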
5 changes: 5 additions & 0 deletions examples/models/llama/TARGETS

@@ -67,6 +67,7 @@ runtime.python_library(
         "//caffe2:torch",
         "//executorch/examples/models:model_base",
         "//executorch/examples/models/llama:llama_transformer",
+        "//executorch/examples/models/llama/config:llm_config",
         "//executorch/examples/models:checkpoint",
     ],
 )
@@ -132,6 +133,8 @@ runtime.python_library(
     name = "export_library",
     srcs = [
         "export_llama.py",
+        "export_llama_args.py",
+        "export_llama_hydra.py",
         "export_llama_lib.py",
         "model.py",
     ],
@@ -148,6 +151,8 @@
         ":source_transformation",
         "//ai_codesign/gen_ai/fast_hadamard_transform:fast_hadamard_transform",
         "//caffe2:torch",
+        "//executorch/examples/models/llama/config:llm_config",
+        "//executorch/backends/vulkan/_passes:vulkan_passes",
         "//executorch/exir/passes:init_mutable_pass",
         "//executorch/examples/models:model_base",
         "//executorch/examples/models:models",
9 changes: 9 additions & 0 deletions examples/models/llama/config/TARGETS

@@ -0,0 +1,9 @@
+# Any targets that should be shared between fbcode and xplat must be defined in
+# targets.bzl. This file can contain fbcode-only targets.
+
+load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
+load(":targets.bzl", "define_common_targets")
+
+oncall("executorch")
+
+define_common_targets()
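The new TARGETS file is pure boilerplate that delegates to a targets.bzl beside it, which is not shown in this diff. For orientation only, a plausible sketch of what that targets.bzl must provide, inferred from the ":llm_config" dep added in examples/models/llama/TARGETS and the import path executorch.examples.models.llama.config.llm_config (the visibility value is a guess):

    # targets.bzl -- hypothetical sketch; the real file is not part of this diff.
    load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")

    def define_common_targets():
        # The target name must match the
        # "//executorch/examples/models/llama/config:llm_config" dep above,
        # and the source must match the llm_config Python import path.
        runtime.python_library(
            name = "llm_config",
            srcs = ["llm_config.py"],
            visibility = ["@EXECUTORCH_CLIENTS"],  # guess; actual visibility unknown
        )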