7 changes: 4 additions & 3 deletions backends/xnnpack/operators/node_visitor.py
@@ -622,9 +622,10 @@ def get_serialized_buffer_index(
)

external_tag = tensor.meta.get("delegate_constant_tag", None)
logging.info(
f"Adding constant data with name {tensor.name}, key {named_key} and external_tag {external_tag} to named_data_store"
)
if external_tag is not None:
logging.info(
f"Adding constant data with name {tensor.name}, key {named_key} and external_tag {external_tag} to named_data_store"
)
self._named_data_store.add_named_data(
named_key,
bytes(array),
1 change: 1 addition & 0 deletions examples/models/llama/TARGETS
@@ -13,6 +13,7 @@ runtime.python_library(
name = "llama_transformer",
srcs = [
"llama_transformer.py",
"lora.py",
"rope.py",
"attention.py",
"model_args.py",
74 changes: 66 additions & 8 deletions examples/models/llama/attention.py
@@ -5,6 +5,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from executorch.examples.models.llama.lora import LoRALinear
from executorch.examples.models.llama.model_args import ModelArgs
from executorch.examples.models.llama.norm import RMSNorm
from executorch.examples.models.llama.rope import Rope
@@ -325,7 +326,20 @@ def update(

@register_attention("mha")
class AttentionMHA(Attention):
def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
def __init__(
self,
args: ModelArgs,
layer_id: int,
rope: Rope,
):
"""
Multi-head attention layer.

Args:
args (ModelArgs): Model configuration parameters.
layer_id (int): Layer index.
rope (Rope): Rotary position embedding module.
"""
super().__init__()
self.use_kv_cache = args.use_kv_cache
self.n_heads = args.n_heads
@@ -350,16 +364,60 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
self.q_norm_fn = RMSNorm(q_norm_dim, eps=args.norm_eps)
self.k_norm_fn = RMSNorm(k_norm_dim, eps=args.norm_eps)

self.wq = nn.Linear(
self.dim, self.n_heads * self.head_dim, bias=self.attention_qkv_bias
self.wq = (
LoRALinear(
in_dim=args.dim,
out_dim=args.n_heads * args.head_dim,
rank=args.r,
alpha=args.lora_alpha,
dropout=0.0,
use_bias=args.attention_qkv_bias,
)
if args.target_modules is not None and "q_proj" in args.target_modules
else nn.Linear(
self.dim, self.n_heads * self.head_dim, bias=self.attention_qkv_bias
)
)
self.wk = nn.Linear(
self.dim, self.n_kv_heads * self.head_dim, bias=self.attention_qkv_bias
self.wk = (
LoRALinear(
in_dim=args.dim,
out_dim=args.n_kv_heads * args.head_dim,
rank=args.r,
alpha=args.lora_alpha,
dropout=0.0,
use_bias=args.attention_qkv_bias,
)
if args.target_modules is not None and "k_proj" in args.target_modules
else nn.Linear(
self.dim, self.n_kv_heads * self.head_dim, bias=self.attention_qkv_bias
)
)
self.wv = nn.Linear(
self.dim, self.n_kv_heads * self.head_dim, bias=self.attention_qkv_bias
self.wv = (
LoRALinear(
in_dim=args.dim,
out_dim=args.n_kv_heads * args.head_dim,
rank=args.r,
alpha=args.lora_alpha,
dropout=0.0,
use_bias=args.attention_qkv_bias,
)
if args.target_modules is not None and "v_proj" in args.target_modules
else nn.Linear(
self.dim, self.n_kv_heads * self.head_dim, bias=self.attention_qkv_bias
)
)
self.wo = (
LoRALinear(
in_dim=args.n_kv_heads * args.head_dim,
out_dim=args.dim,
rank=args.r,
alpha=args.lora_alpha,
dropout=0.0,
use_bias=args.attention_qkv_bias,
)
if args.target_modules is not None and "output_proj" in args.target_modules
else nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False)
)
self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False)

self.layer_id = layer_id

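
The four projections above repeat the same selection logic: wrap the projection in LoRALinear only when its torchtune module name appears in target_modules, otherwise keep a plain nn.Linear. A minimal sketch of that pattern (the maybe_lora helper is hypothetical and not part of this PR):

import torch.nn as nn

from executorch.examples.models.llama.lora import LoRALinear
from executorch.examples.models.llama.model_args import ModelArgs


def maybe_lora(
    in_dim: int, out_dim: int, name: str, args: ModelArgs, bias: bool = False
) -> nn.Module:
    # Use a LoRA-wrapped projection only when the adapter config lists this module.
    if args.target_modules is not None and name in args.target_modules:
        return LoRALinear(
            in_dim=in_dim,
            out_dim=out_dim,
            rank=args.r,
            alpha=args.lora_alpha,
            dropout=0.0,
            use_bias=bias,
        )
    return nn.Linear(in_dim, out_dim, bias=bias)


# e.g. wq = maybe_lora(args.dim, args.n_heads * args.head_dim, "q_proj", args, args.attention_qkv_bias)
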
12 changes: 12 additions & 0 deletions examples/models/llama/export_llama_lib.py
@@ -239,6 +239,18 @@ def build_args_parser() -> argparse.ArgumentParser:
help="checkpoint directory. Use with a sharded checkpoint, not for the standard llama2 model. Note, checkpoint_dir takes precedence over checkpoint if both are set.",
)

parser.add_argument(
"--adapter_checkpoint",
required=False,
help="Path to the adapter.pt file from torchtune. Used if the model has trained LoRA adapters. Must provide adapter_config.json",
)

parser.add_argument(
"--adapter_config",
required=False,
help="Path to the adapter_config.json file. Used if the model has trained LoRA adapters. Must provide adapter_checkpoint.",
)

parser.add_argument(
"--use_qnn_sha",
action="store_true",
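
A quick sketch of how the two new flags surface on the shared argument parser (paths are illustrative, and this assumes the parser's remaining options all have defaults):

from executorch.examples.models.llama.export_llama_lib import build_args_parser

parser = build_args_parser()
args = parser.parse_args(
    [
        "--adapter_checkpoint", "/tmp/adapter.pt",         # illustrative path
        "--adapter_config", "/tmp/adapter_config.json",    # illustrative path
    ]
)
assert args.adapter_checkpoint == "/tmp/adapter.pt"
assert args.adapter_config == "/tmp/adapter_config.json"
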
48 changes: 48 additions & 0 deletions examples/models/llama/lora.py
@@ -0,0 +1,48 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch import nn


class LoRALinear(nn.Module):
"""LoRA linear layer as introduced in `LoRA: Low-Rank Adaptation of Large Language Models <https://arxiv.org/abs/2106.09685>`."""

def __init__(
self,
in_dim: int,
out_dim: int,
rank: int,
alpha: float,
dropout: float = 0.0,
use_bias: bool = False,
):
super().__init__()
self.in_dim = in_dim
self.out_dim = out_dim
self.rank = rank
self.alpha = alpha
self.use_bias = use_bias
self.dropout = dropout

linear = nn.Linear(in_dim, out_dim, bias=use_bias)
weight = linear.weight
bias = linear.bias if self.use_bias else None
self.register_parameter("weight", nn.Parameter(weight))
self.register_parameter(
"bias", nn.Parameter(bias) if bias is not None else None
)

self.dropout = nn.Dropout(p=dropout) if dropout > 0.0 else nn.Identity()
self.lora_a = nn.Linear(in_features=in_dim, out_features=rank, bias=False)
self.lora_b = nn.Linear(in_features=rank, out_features=out_dim, bias=False)

def forward(self, x: torch.Tensor) -> torch.Tensor:
out = torch.nn.functional.linear(x, self.weight, self.bias)
lora_out = self.lora_a(self.dropout(x))
lora_out = (self.alpha / self.rank) * self.lora_b(lora_out)

return out + lora_out
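
A minimal usage sketch (not part of the diff) of the layer defined above; the forward pass is the base projection plus the low-rank update scaled by alpha / rank:

import torch

from executorch.examples.models.llama.lora import LoRALinear

# Dimensions below are illustrative only.
layer = LoRALinear(in_dim=64, out_dim=128, rank=8, alpha=16.0)
x = torch.randn(2, 10, 64)  # (batch, seq_len, in_dim)
y = layer(x)                # F.linear(x, W, b) + (alpha / rank) * lora_b(lora_a(x))
print(y.shape)              # torch.Size([2, 10, 128])
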
22 changes: 22 additions & 0 deletions examples/models/llama/model.py
@@ -46,6 +46,13 @@ def __init__(self, llm_config: Optional[LlmConfig] = None):
checkpoint_dir = self.llm_config.base.checkpoint_dir
params_path = self.llm_config.base.params

# Adapter checkpoint and config.
adapter_checkpoint_path = self.llm_config.base.adapter_checkpoint
adapter_config_path = self.llm_config.base.adapter_config
assert (adapter_checkpoint_path is None and adapter_config_path is None) or (
adapter_checkpoint_path is not None and adapter_config_path is not None
), "Both adapter_checkpoint_path and adapter_config_path must be specified or neither must be specified."

self.use_kv_cache = self.llm_config.model.use_kv_cache
self.use_sdpa_with_kv_cache_op = self.llm_config.model.use_sdpa_with_kv_cache
self.generate_full_logits = self.llm_config.debug.generate_full_logits
@@ -129,6 +136,20 @@ def __init__(self, llm_config: Optional[LlmConfig] = None):
with open(params_path, "r") as f:
params = json.loads(f.read())

# Get adapter checkpoint and config.
adapter_checkpoint = {}
adapter_config = {}
if adapter_checkpoint_path:
adapter_checkpoint = torch.load(
adapter_checkpoint_path, map_location=device, mmap=True
)
from torchtune.models import convert_weights

adapter_checkpoint = convert_weights.tune_to_meta(adapter_checkpoint)
with open(adapter_config_path, "r") as f:
adapter_config = json.loads(f.read())
checkpoint.update(adapter_checkpoint)

output_prune_map = None
if self.output_prune_map_path is not None:
with open(self.output_prune_map_path, "r") as f:
@@ -153,6 +174,7 @@ def __init__(self, llm_config: Optional[LlmConfig] = None):
output_prune_map=output_prune_map,
enable_dynamic_shape=self.enable_dynamic_shape,
**params,
**adapter_config,
)

if model_args.use_scaled_rope:
10 changes: 10 additions & 0 deletions examples/models/llama/model_args.py
@@ -55,8 +55,18 @@ class ModelArgs:
eos_count: int = 2

quantization_args: Optional[dict] = None
# LoRA for QAT.
lora_args: Optional[dict] = None

# LoRA arguments to set up a LoRA inference model.
# These arguments come directly from a torchtune adapter_config.json file.
r: Optional[int] = None # Rank.
lora_alpha: Optional[int] = None # Alpha.
# Eg. q_proj, k_proj, v_proj, output_proj
target_modules: Optional[list] = None
peft_type: Optional[str] = None # PEFT type.
base_model_name_or_path: Optional[str] = None # Base model name or path.

def __post_init__(self):
if self.n_kv_heads is None:
self.n_kv_heads = self.n_heads
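
For reference, the torchtune adapter_config.json that feeds these fields has roughly the shape below (the values are assumptions for illustration only); model.py splats it into ModelArgs alongside params.json:

# Illustrative adapter_config.json contents, written as a Python dict.
adapter_config = {
    "r": 8,                                  # LoRA rank
    "lora_alpha": 16,                        # LoRA scaling alpha
    "target_modules": ["q_proj", "v_proj", "output_proj"],
    "peft_type": "LORA",
    "base_model_name_or_path": "meta-llama/Llama-3.2-1B-Instruct",  # illustrative
}

# In model.py this becomes: ModelArgs(**params, **adapter_config)
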
10 changes: 9 additions & 1 deletion extension/llm/export/config/llm_config.py
@@ -73,10 +73,16 @@ class BaseConfig:
if it is a Llama model or the weights will be downloaded from HuggingFace
if it is a non-Llama model.
checkpoint_dir: Path to directory containing sharded checkpoint files.
adapter_checkpoint: Path to the adapter.pt file from torchtune. Used if
the model has trained LoRA adapters. Must provide
adapter_config.json.
adapter_config: Path to the adapter_config.json file from torchtune.
Used if the model has trained LoRA adapters. Must provide adapter.pt.
tokenizer_path: Path to the tokenizer file.
metadata: Json string containing metadata information.
e.g. '"{\"get_bos_id\":128000, \"get_eos_ids\":[128009, 128001]}"'
use_lora: Rank of the LoRA, if set to 0 then this means no LoRA. For use with QAT.
use_lora: Only for use with QAT. Rank of the LoRA adapter, disabled
if set to 0.
fairseq2: For legacy internal use cases, this is safe to ignore.
preq_mode: Legacy option to specify how prequantized weights are loaded.
Going forward, ExecuTorch supports loading weights prequantized through
@@ -90,6 +96,8 @@ class BaseConfig:
params: Optional[str] = None
checkpoint: Optional[str] = None
checkpoint_dir: Optional[str] = None
adapter_checkpoint: Optional[str] = None
adapter_config: Optional[str] = None
tokenizer_path: Optional[str] = None
metadata: Optional[str] = None
use_lora: int = 0
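
The same two paths can also be set programmatically on the config object; a minimal sketch, assuming the module path below and that LlmConfig is default-constructible:

from executorch.extension.llm.export.config.llm_config import LlmConfig

llm_config = LlmConfig()
llm_config.base.adapter_checkpoint = "/tmp/adapter.pt"        # illustrative path
llm_config.base.adapter_config = "/tmp/adapter_config.json"   # illustrative path
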