backends/xnnpack/operators/node_visitor.py (4 additions, 3 deletions)

@@ -595,9 +595,10 @@ def get_serialized_buffer_index(
         )

         external_tag = tensor.meta.get("delegate_constant_tag", None)
-        logging.info(
-            f"Adding constant data with name {tensor.name}, key {named_key} and external_tag {external_tag} to named_data_store"
-        )
+        if external_tag is not None:
+            logging.info(
+                f"Adding constant data with name {tensor.name}, key {named_key} and external_tag {external_tag} to named_data_store"
+            )
         self._named_data_store.add_named_data(
             named_key,
             bytes(array),

examples/models/llama/export_llama_lib.py (29 additions, 0 deletions)

@@ -235,6 +235,18 @@ def build_args_parser() -> argparse.ArgumentParser:
help="checkpoint directory. Use with a sharded checkpoint, not for the standard llama2 model. Note, checkpoint_dir takes precedence over checkpoint if both are set.",
)

parser.add_argument(
"--adapter_checkpoint",
required=False,
help="Path to the adapter.pt file from torchtune. Used if the model has trained LoRA adapters. Must provide adapter_config.json",
)

parser.add_argument(
"--adapter_config",
required=False,
help="Path to the adapter_config.json file. Used if the model has trained LoRA adapters. Must provide adapter_checkpoint.",
)

parser.add_argument(
"--use_qnn_sha",
action="store_true",
@@ -631,6 +643,17 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
         canonical_path(args.checkpoint_dir) if args.checkpoint_dir else None
     )
     params_path = canonical_path(args.params) if args.params else None
+
+    assert (args.adapter_checkpoint is None and args.adapter_config is None) or (
+        args.adapter_checkpoint is not None and args.adapter_config is not None
+    ), "Must provide both adapter_checkpoint and adapter_config, or neither"
+    adapter_checkpoint_path = (
+        canonical_path(args.adapter_checkpoint) if args.adapter_checkpoint else None
+    )
+    adapter_config_path = (
+        canonical_path(args.adapter_config) if args.adapter_config else None
+    )
+
     output_dir_path = canonical_path(args.output_dir, dir=True)
     weight_type = WeightType.FAIRSEQ2 if args.fairseq2 else WeightType.LLAMA

@@ -642,6 +665,8 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
             checkpoint=checkpoint_path,
             checkpoint_dir=checkpoint_dir,
             params_path=params_path,
+            adapter_checkpoint=adapter_checkpoint_path,
+            adapter_config=adapter_config_path,
             use_kv_cache=args.use_kv_cache,
             use_sdpa_with_kv_cache=args.use_sdpa_with_kv_cache,
             generate_full_logits=args.generate_full_logits,
@@ -1141,6 +1166,8 @@ def _load_llama_model(
     checkpoint: Optional[str] = None,
     checkpoint_dir: Optional[str] = None,
     params_path: Optional[str] = None,
+    adapter_checkpoint: Optional[str] = None,
+    adapter_config: Optional[str] = None,
     use_kv_cache: bool = False,
     use_sdpa_with_kv_cache: bool = False,
     generate_full_logits: bool = False,
@@ -1188,6 +1215,8 @@ def _load_llama_model(
             checkpoint=checkpoint,
             checkpoint_dir=checkpoint_dir,
             params=params_path,
+            adapter_checkpoint=adapter_checkpoint,
+            adapter_config=adapter_config,
             use_kv_cache=use_kv_cache,
             use_sdpa_with_kv_cache=use_sdpa_with_kv_cache,
             generate_full_logits=generate_full_logits,
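
The new flags plug into the existing build_args_parser() / _prepare_for_llama_export() flow shown above. Below is a minimal sketch (not part of the PR) of how they are intended to be used together; the file names are hypothetical placeholders, and the --checkpoint and --params spellings are assumed from the surrounding code.

# Sketch only: exercise the new adapter flags against build_args_parser().
from executorch.examples.models.llama.export_llama_lib import build_args_parser

parser = build_args_parser()
args = parser.parse_args(
    [
        "--checkpoint", "consolidated.00.pth",      # base model weights (hypothetical path)
        "--params", "params.json",                  # model params file (hypothetical path)
        "--adapter_checkpoint", "adapter.pt",       # trained LoRA weights from torchtune
        "--adapter_config", "adapter_config.json",  # matching LoRA config from torchtune
    ]
)

# The assert added to _prepare_for_llama_export enforces a both-or-neither
# contract on the two new flags; the same invariant, stated directly:
assert (args.adapter_checkpoint is None) == (args.adapter_config is None)
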

examples/models/llama/model.py (20 additions, 1 deletion)

@@ -8,7 +8,6 @@

 import json
 import os
-from typing import Dict, Tuple

 import torch
 from executorch.examples.models.checkpoint import (
@@ -47,6 +46,10 @@ def __init__(self, **kwargs):
         # Params file.
         params_path = kwargs.get("params", None)

+        # Adapter
+        adapter_checkpoint = kwargs.get("adapter_checkpoint", None)
+        adapter_config = kwargs.get("adapter_config", None)
+
         self.use_kv_cache = kwargs.get("use_kv_cache", False)
         self.use_sdpa_with_kv_cache_op = kwargs.get("use_sdpa_with_kv_cache", False)
         self.generate_full_logits = kwargs.get("generate_full_logits", False)
@@ -130,6 +133,21 @@ def __init__(self, **kwargs):
         with open(params_path, "r") as f:
             params = json.loads(f.read())

+        # Get adapter checkpoint and config.
+        adapter_checkpoint = {}
+        adapter_config = {}
+        adapter_checkpoint_path = kwargs.get("adapter_checkpoint", None)
+        if adapter_checkpoint_path:

Review comment (Contributor): nit: use os.exists or similar from Path?

+            adapter_checkpoint = torch.load(
+                adapter_checkpoint_path, map_location=device, mmap=True
+            )
+            from torchtune.models import convert_weights
+            adapter_checkpoint = convert_weights.tune_to_meta(adapter_checkpoint)
+            adapter_config = kwargs.get("adapter_config", None)
+            with open(adapter_config, "r") as f:
+                adapter_config = json.loads(f.read())
+            checkpoint.update(adapter_checkpoint)
+
         output_prune_map = None
         if self.output_prune_map_path is not None:
             with open(self.output_prune_map_path, "r") as f:
@@ -154,6 +172,7 @@ def __init__(self, **kwargs):
             output_prune_map=output_prune_map,
             enable_dynamic_shape=self.enable_dynamic_shape,
             **params,
+            **adapter_config,
         )

         if model_args.use_scaled_rope:
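
On the reviewer's nit above about validating the path with os/Path: a minimal sketch of that suggestion, not what the PR merged. The load_adapter helper name is hypothetical; the kwargs keys, the torch.load call, and the torchtune convert_weights.tune_to_meta conversion are taken from the diff.

import json
from pathlib import Path

import torch
from torchtune.models import convert_weights


def load_adapter(kwargs, device="cpu"):
    """Hypothetical helper mirroring the adapter-loading block in model.py,
    with an explicit existence check before torch.load / open."""
    adapter_checkpoint_path = kwargs.get("adapter_checkpoint", None)
    adapter_config_path = kwargs.get("adapter_config", None)
    if not adapter_checkpoint_path:
        return {}, {}

    # Fail early with a clear error if either file is missing.
    for path in (adapter_checkpoint_path, adapter_config_path):
        if path is None or not Path(path).is_file():
            raise FileNotFoundError(f"Adapter file not found: {path}")

    adapter_checkpoint = torch.load(
        adapter_checkpoint_path, map_location=device, mmap=True
    )
    # Convert torchtune adapter weights to the Meta naming convention.
    adapter_checkpoint = convert_weights.tune_to_meta(adapter_checkpoint)
    with open(adapter_config_path, "r") as f:
        adapter_config = json.loads(f.read())
    return adapter_checkpoint, adapter_config

The returned adapter_checkpoint would then be merged into the base weights via checkpoint.update(adapter_checkpoint) and adapter_config unpacked into ModelArgs, as the diff does.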