7 changes: 4 additions & 3 deletions backends/xnnpack/operators/node_visitor.py
@@ -622,9 +622,10 @@ def get_serialized_buffer_index(
         )

         external_tag = tensor.meta.get("delegate_constant_tag", None)
-        logging.info(
-            f"Adding constant data with name {tensor.name}, key {named_key} and external_tag {external_tag} to named_data_store"
-        )
+        if external_tag is not None:
+            logging.info(
+                f"Adding constant data with name {tensor.name}, key {named_key} and external_tag {external_tag} to named_data_store"
+            )
         self._named_data_store.add_named_data(
             named_key,
             bytes(array),
12 changes: 12 additions & 0 deletions examples/models/llama/export_llama_lib.py
@@ -239,6 +239,18 @@ def build_args_parser() -> argparse.ArgumentParser:
         help="checkpoint directory. Use with a sharded checkpoint, not for the standard llama2 model. Note, checkpoint_dir takes precedence over checkpoint if both are set.",
     )

+    parser.add_argument(
+        "--adapter_checkpoint",
+        required=False,
+        help="Path to the adapter.pt file from torchtune. Used if the model has trained LoRA adapters. Must provide adapter_config.json",
+    )
+
+    parser.add_argument(
+        "--adapter_config",
+        required=False,
+        help="Path to the adapter_config.json file. Used if the model has trained LoRA adapters. Must provide adapter_checkpoint.",
+    )
+
     parser.add_argument(
         "--use_qnn_sha",
         action="store_true",
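The two new flags are plain optional string arguments. Below is a minimal, self-contained sketch of how they parse; it mirrors the additions above rather than importing the real build_args_parser, so the parser construction here is illustrative only.

import argparse

# Mirror of the two arguments added above; both default to None when omitted.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--adapter_checkpoint",
    required=False,
    help="Path to the adapter.pt file from torchtune.",
)
parser.add_argument(
    "--adapter_config",
    required=False,
    help="Path to the adapter_config.json file.",
)

args = parser.parse_args(
    ["--adapter_checkpoint", "adapter.pt", "--adapter_config", "adapter_config.json"]
)
# The two flags are meant to be provided together (model.py asserts this).
assert (args.adapter_checkpoint is None) == (args.adapter_config is None)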
22 changes: 22 additions & 0 deletions examples/models/llama/model.py
@@ -46,6 +46,13 @@ def __init__(self, llm_config: Optional[LlmConfig] = None):
         checkpoint_dir = self.llm_config.base.checkpoint_dir
         params_path = self.llm_config.base.params

+        # Adapter checkpoint and config.
+        adapter_checkpoint_path = self.llm_config.base.adapter_checkpoint
+        adapter_config_path = self.llm_config.base.adapter_config
+        assert (adapter_checkpoint_path is None and adapter_config_path is None) or (
+            adapter_checkpoint_path is not None and adapter_config_path is not None
+        ), "Both adapter_checkpoint_path and adapter_config_path must be specified or neither must be specified."
+
         self.use_kv_cache = self.llm_config.model.use_kv_cache
         self.use_sdpa_with_kv_cache_op = self.llm_config.model.use_sdpa_with_kv_cache
         self.generate_full_logits = self.llm_config.debug.generate_full_logits
@@ -129,6 +136,20 @@ def __init__(self, llm_config: Optional[LlmConfig] = None):
         with open(params_path, "r") as f:
             params = json.loads(f.read())

+        # Get adapter checkpoint and config.
+        adapter_checkpoint = {}
+        adapter_config = {}
+        if adapter_checkpoint_path:

Contributor: nit: use os.path.exists or similar from Path?

+            adapter_checkpoint = torch.load(
+                adapter_checkpoint_path, map_location=device, mmap=True
+            )
+            from torchtune.models import convert_weights
+
+            adapter_checkpoint = convert_weights.tune_to_meta(adapter_checkpoint)
+            with open(adapter_config_path, "r") as f:
+                adapter_config = json.loads(f.read())
+            checkpoint.update(adapter_checkpoint)
+
         output_prune_map = None
         if self.output_prune_map_path is not None:
             with open(self.output_prune_map_path, "r") as f:
Expand All @@ -153,6 +174,7 @@ def __init__(self, llm_config: Optional[LlmConfig] = None):
output_prune_map=output_prune_map,
enable_dynamic_shape=self.enable_dynamic_shape,
**params,
**adapter_config,
)

if model_args.use_scaled_rope:
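Taken together, the model.py additions load the torchtune adapter, convert its parameter names to the Meta format, and merge it into the base checkpoint before ModelArgs and model construction. Below is a standalone sketch of that flow (not the exact model.py code) that also folds in the Path-existence check suggested in the review comment above; the file paths are placeholders.

import json
from pathlib import Path

import torch
from torchtune.models import convert_weights

# Placeholder paths; in model.py these come from llm_config.base.
adapter_checkpoint_path = Path("adapter.pt")
adapter_config_path = Path("adapter_config.json")

checkpoint = {}  # the base model state dict would already be loaded here
adapter_config = {}
if adapter_checkpoint_path.exists():  # reviewer-suggested existence check
    # Load the LoRA adapter weights produced by torchtune.
    adapter_checkpoint = torch.load(
        adapter_checkpoint_path, map_location="cpu", mmap=True
    )
    # Rename torchtune parameter keys to the Meta/llama convention.
    adapter_checkpoint = convert_weights.tune_to_meta(adapter_checkpoint)
    with open(adapter_config_path, "r") as f:
        adapter_config = json.loads(f.read())
    # Merge adapter weights into the base checkpoint; adapter_config is later
    # splatted into ModelArgs alongside params.
    checkpoint.update(adapter_checkpoint)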
2 changes: 1 addition & 1 deletion examples/models/llama/model_args.py
@@ -59,7 +59,7 @@ class ModelArgs:
     lora_args: Optional[dict] = None

     # LoRA arguments to set up a LoRA inference model.
-    # These arguments come directly from a torchtune LoRA config.
+    # These arguments come directly from a torchtune adapter_config.json file.
     r: Optional[int] = None  # Rank.
     lora_alpha: Optional[int] = None  # Alpha.
     # Eg. q_proj, k_proj, v_proj, output_proj
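For reference, the LoRA fields above are populated by splatting the parsed adapter_config into ModelArgs (the **adapter_config line in model.py). A small illustrative sketch follows; the exact JSON keys emitted by torchtune and the ModelArgs import path are assumptions, not taken from this diff.

import json

# Import path assumed from the file location examples/models/llama/model_args.py.
from executorch.examples.models.llama.model_args import ModelArgs

# Keys assumed to match the ModelArgs field names shown above.
adapter_config = json.loads('{"r": 8, "lora_alpha": 16}')
model_args = ModelArgs(**adapter_config)
print(model_args.r, model_args.lora_alpha)  # -> 8 16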
10 changes: 9 additions & 1 deletion extension/llm/export/config/llm_config.py
@@ -73,10 +73,16 @@ class BaseConfig:
             if it is a Llama model or the weights will be downloaded from HuggingFace
             if it is a non-Llama model.
         checkpoint_dir: Path to directory containing sharded checkpoint files.
+        adapter_checkpoint: Path to the adapter.pt file from torchtune. Used if
+            the model has trained LoRA adapters. Must provide
+            adapter_config.json.
+        adapter_config: Path to the adapter_config.json file from torchtune.
+            Used if the model has trained LoRA adapters. Must provide adapter.pt.
         tokenizer_path: Path to the tokenizer file.
         metadata: Json string containing metadata information.
             e.g. '"{\"get_bos_id\":128000, \"get_eos_ids\":[128009, 128001]}"'
-        use_lora: Rank of the LoRA, if set to 0 then this means no LoRA. For use with QAT.
+        use_lora: Only for use with QAT. Rank of the LoRA adapter, disabled
+            if set to 0.

Contributor: You didn't add this, but why is this boolean-named field an int, and why does it correspond with QAT?

Contributor (Author): cc @cccclai ?

Contributor: Didn't add it myself either; I think it's likely from Lunwen. I believe it's for the Llama 3.2 1B QAT checkpoint, which includes LoRA, so make sure we don't break the Llama 3.2 QAT model if we use this flag somewhere else.

         fairseq2: For legacy internal use cases, this is safe to ignore.
         preq_mode: Legacy option to specify how prequantized weights are loaded.
             Going forward, ExecuTorch supports loading weights prequantized through
@@ -90,6 +96,8 @@ class BaseConfig:
     params: Optional[str] = None
     checkpoint: Optional[str] = None
     checkpoint_dir: Optional[str] = None
+    adapter_checkpoint: Optional[str] = None
+    adapter_config: Optional[str] = None
     tokenizer_path: Optional[str] = None
     metadata: Optional[str] = None
     use_lora: int = 0
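A minimal sketch of driving the same paths through the config object rather than the CLI flags, assuming LlmConfig default-constructs its base field as a BaseConfig (model.py reads llm_config.base.adapter_checkpoint and adapter_config); the import path is assumed from the file location.

# Assumed import path based on extension/llm/export/config/llm_config.py.
from executorch.extension.llm.export.config.llm_config import LlmConfig

llm_config = LlmConfig()
llm_config.base.adapter_checkpoint = "/path/to/adapter.pt"       # adapter.pt from torchtune
llm_config.base.adapter_config = "/path/to/adapter_config.json"  # must accompany the checkpoint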