5 changes: 4 additions & 1 deletion examples/models/checkpoint.py
@@ -31,15 +31,18 @@ def get_default_model_resource_dir(model_file_path: str) -> Path:
     """

     try:
+        import importlib
         import pkg_resources

         # 1st way: If we can import this path, we are running with buck2 and all resources can be accessed with pkg_resources.
         # pyre-ignore
+        model_name = Path(model_file_path).parent.name
-        from executorch.examples.models.llama import params # noqa
+        module = importlib.import_module(f"executorch.examples.models.{model_name}.params")
+        # params = module.params

         # Get the model name from the cwd, assuming that this module is called from a path such as
         # examples/models/<model_name>/model.py.
         model_name = Path(model_file_path).parent.name
         resource_dir = Path(
             pkg_resources.resource_filename(
                 f"executorch.examples.models.{model_name}", "params"
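For readers skimming the hunk above, here is a minimal, self-contained sketch of the pattern it moves toward: derive the model name from the caller's path and import the matching params package dynamically instead of hard-coding the llama one. The helper name resolve_params_dir is hypothetical; only the executorch.examples.models.<model_name> package layout and the pkg_resources call come from the diff itself.

import importlib
from pathlib import Path

import pkg_resources


def resolve_params_dir(model_file_path: str) -> Path:
    # Assumes the caller lives at examples/models/<model_name>/model.py,
    # so the parent directory name is the model name.
    model_name = Path(model_file_path).parent.name

    # Probe the model-specific params package; when running under buck2,
    # this import succeeding means resources are reachable via pkg_resources.
    importlib.import_module(f"executorch.examples.models.{model_name}.params")

    # Map the package's "params" resource to an on-disk path.
    return Path(
        pkg_resources.resource_filename(
            f"executorch.examples.models.{model_name}", "params"
        )
    )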
7 changes: 3 additions & 4 deletions examples/models/llama/export_llama_lib.py
@@ -175,7 +175,6 @@ def build_args_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "-c",
         "--checkpoint",
-        default=f"{ckpt_dir}/params/demo_rand_params.pth",
         help="checkpoint path",
     )

@@ -874,9 +873,9 @@ def _load_llama_model(
         An instance of LLMEdgeManager which contains the eager mode model.
     """

-    assert (
-        checkpoint or checkpoint_dir
-    ) and params_path, "Both checkpoint/checkpoint_dir and params can't be empty"
+    # assert (
+    #     checkpoint or checkpoint_dir
+    # ) and params_path, "Both checkpoint/checkpoint_dir and params can't be empty"
     logging.info(
         f"Loading model with checkpoint={checkpoint}, params={params_path}, use_kv_cache={use_kv_cache}, weight_type={weight_type}"
     )
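A small sketch (not the real export_llama_lib parser) of what removing the default and the assertion means for callers: --checkpoint now comes back as None unless it is passed explicitly, so code downstream of _load_llama_model has to tolerate a missing checkpoint rather than rely on the removed assert.

import argparse

# Hypothetical stripped-down parser mirroring only the -c/--checkpoint flag.
parser = argparse.ArgumentParser()
parser.add_argument("-c", "--checkpoint", help="checkpoint path")

args = parser.parse_args([])                  # no flags given
print(args.checkpoint)                        # None, since the default was removed

args = parser.parse_args(["-c", "model.pth"])
print(args.checkpoint)                        # "model.pth"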
2 changes: 1 addition & 1 deletion examples/models/llama3_2_vision/text_decoder/model.py
@@ -58,7 +58,7 @@ def __init__(self, **kwargs):

         ckpt_dir = get_default_model_resource_dir(__file__)
         # Single checkpoint file.
-        checkpoint_path = kwargs.get("checkpoint", ckpt_dir / "demo_rand_params.pth")
+        checkpoint_path = kwargs.get("checkpoint") if kwargs.get("checkpoint") else ckpt_dir / "demo_rand_params.pth"
         if os.path.isfile(checkpoint_path):
             self.use_checkpoint = True

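The longer expression in this hunk is doing real work: dict.get(key, default) only falls back when the key is absent, whereas the conditional form also falls back when the caller passes checkpoint=None, which is exactly what happens once the argparse default above is gone. A minimal sketch with placeholder paths:

from pathlib import Path

ckpt_dir = Path("/tmp/resources")            # placeholder for get_default_model_resource_dir(__file__)
default_ckpt = ckpt_dir / "demo_rand_params.pth"

kwargs = {"checkpoint": None}                # key present, but no value supplied

old_style = kwargs.get("checkpoint", default_ckpt)
new_style = kwargs.get("checkpoint") if kwargs.get("checkpoint") else default_ckpt

print(old_style)  # None -> os.path.isfile(None) would raise a TypeError
print(new_style)  # /tmp/resources/demo_rand_params.pth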
@@ -6,7 +6,7 @@
     "multiple_of": 1024,
     "n_heads": 32,
     "n_kv_heads": 8,
-    "n_layers": 32,
+    "n_layers": 1,
     "n_special_tokens": 8,
     "norm_eps": 1e-05,
     "rope_theta": 500000.0,
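On the last hunk: with n_layers cut from 32 to 1, the JSON now describes a single-layer model, presumably so checkpoint-free runs stay small and fast. A tiny sketch of reading such a file, assuming it is saved locally as params.json:

import json
from pathlib import Path

params = json.loads(Path("params.json").read_text())
assert params["n_layers"] == 1               # trimmed from 32 in this change
print(params["n_heads"], params["n_kv_heads"], params["norm_eps"])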