1 change: 1 addition & 0 deletions examples/models/llama2/TARGETS
@@ -71,6 +71,7 @@ runtime.python_library(
"export_llama_lib.py",
"model.py",
"source_transformation/apply_spin_quant_r1_r2.py",
"source_transformation/prune_output.py",
"source_transformation/quantize.py",
"source_transformation/rms_norm.py",
"source_transformation/rope.py",
9 changes: 9 additions & 0 deletions examples/models/llama2/export_llama_lib.py
@@ -369,6 +369,12 @@ def build_args_parser() -> argparse.ArgumentParser:
choices=["cuda", "native"],
help="Use SpinQuant for better quantization performance. Only support cuda and native.",
)

parser.add_argument(
"--output_prune_map",
default=None,
help="path to the output pruning token mapping file (token_map.json)",
)
return parser


@@ -458,6 +464,7 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
tokenizer_path=args.tokenizer_path,
verbose=args.verbose,
max_seq_len=args.max_seq_length,
output_prune_map_path=args.output_prune_map,
metadata_str=args.metadata,
args=args,
)
@@ -682,6 +689,7 @@ def _load_llama_model(
tokenizer_path: Optional[str] = None,
verbose: bool = False,
max_seq_len: int = 128,
output_prune_map_path: Optional[str] = None,
metadata_str: Optional[str] = None,
args,
) -> "LLMEdgeManager":
@@ -709,6 +717,7 @@
fairseq2=weight_type == WeightType.FAIRSEQ2,
max_seq_len=max_seq_len,
enable_dynamic_shape=enable_dynamic_shape,
output_prune_map_path=output_prune_map_path,
args=args,
)
state_dict = model.state_dict()
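
For context (not part of the diff): the file passed via --output_prune_map is a JSON object whose keys are the new, contiguous pruned token ids and whose values are the original vocabulary ids to keep, matching the format described in source_transformation/prune_output.py below. A minimal sketch of producing such a file, with made-up token ids:

import json

# Hypothetical set of original-vocabulary token ids to keep (illustrative values only).
kept_token_ids = [221, 1325, 1542, 1728, 18243]

# New contiguous ids 0..N-1 map to the original ids they replace.
token_map = {new_id: old_id for new_id, old_id in enumerate(kept_token_ids)}

with open("token_map.json", "w") as f:
    # JSON serializes the int keys as strings; model.py converts them back to int on load.
    json.dump(token_map, f)
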
28 changes: 27 additions & 1 deletion examples/models/llama2/llama_transformer.py
@@ -9,7 +9,7 @@

from dataclasses import dataclass
from functools import partial
from typing import Optional, Tuple
from typing import Dict, Optional, Tuple

import torch
import torch.nn.functional as F
@@ -102,6 +102,8 @@ class ModelArgs:
# logits for all input tokens.)
generate_full_logits: bool = False
enable_dynamic_shape: bool = False # export model with dynamic shape support
# A dictionary mapping from pruned token-id to original token-id
output_prune_map: Optional[Dict[int, int]] = None
use_hf_rope: bool = False # Use HuggingFace's RoPE implementation
rope_theta: Optional[float] = (
None # The official name to override self.rope_freq_base.
@@ -449,6 +451,7 @@ def __init__(self, params: ModelArgs):
self.use_kv_cache = params.use_kv_cache
self.generate_full_logits = params.generate_full_logits
self.max_seq_len = params.max_seq_len
self.output_prune_map = params.output_prune_map
if params.use_hf_rope:
self.precompute_freqs_cis = hf_precompute_freqs_cis
else:
@@ -525,4 +528,27 @@ def forward(
h = self.norm(h)

logits = self.output(h)

if self.output_prune_map is not None:
# expand to original size so that downstream applications can use the logits as-is.
if self.generate_full_logits:
# (1, seq_len, pruned_size) -> (1, seq_len, original_size)
expanded_logits = torch.full(
[logits.shape[0], logits.shape[1], self.vocab_size],
float("-inf"),
device=logits.device,
dtype=logits.dtype,
)
expanded_logits[:, :, list(self.output_prune_map.values())] = logits
else:
# (1, pruned_size) -> (1, original_size)
expanded_logits = torch.full(
[logits.shape[0], self.vocab_size],
float("-inf"),
device=logits.device,
dtype=logits.dtype,
)
expanded_logits[:, list(self.output_prune_map.values())] = logits
logits = expanded_logits

return logits
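
As an aside (not part of the PR), here is a minimal standalone sketch of the expansion above: pruned logits are scattered back into a full-vocabulary tensor filled with -inf, so positions for pruned-away tokens can never win an argmax. The map values and vocabulary size are illustrative assumptions.

import torch

# Hypothetical map: pruned id -> original vocab id (same shape as output_prune_map).
output_prune_map = {0: 221, 1: 1325, 2: 1542}
vocab_size = 32000

# Pretend the pruned output layer produced these logits for a single token.
logits = torch.tensor([[1.0, 2.0, 3.0]])  # (1, pruned_size)

expanded = torch.full((logits.shape[0], vocab_size), float("-inf"), dtype=logits.dtype)
expanded[:, list(output_prune_map.values())] = logits

assert expanded[0, 1325].item() == 2.0   # pruned column 1 landed at original id 1325
assert torch.isinf(expanded[0, 0])       # untouched positions stay at -inf
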
14 changes: 14 additions & 0 deletions examples/models/llama2/model.py
@@ -63,6 +63,7 @@ def __init__(self, **kwargs):
self.use_sdpa_with_kv_cache_op = kwargs.get("use_sdpa_with_kv_cache", False)
self.generate_full_logits = kwargs.get("generate_full_logits", False)
self.enable_dynamic_shape = kwargs.get("enable_dynamic_shape", False)
self.output_prune_map_path = kwargs.get("output_prune_map_path", None)

self.max_seq_len = kwargs.get("max_seq_len", 128)
self.args = kwargs.get("args", None)
@@ -141,6 +142,12 @@ def __init__(self, **kwargs):
)
with open(params_path, "r") as f:
params = json.loads(f.read())
output_prune_map = None
if self.output_prune_map_path is not None:
with open(self.output_prune_map_path, "r") as f:
output_prune_map = json.load(f)
# change keys from string to int (json only supports string keys)
output_prune_map = {int(k): v for (k, v) in output_prune_map.items()}
max_seq_len = self.max_seq_len
max_batch_size = 1
model_args: ModelArgs = ModelArgs(
@@ -149,6 +156,7 @@
use_kv_cache=self.use_kv_cache,
use_sdpa_with_kv_cache_op=self.use_sdpa_with_kv_cache_op,
generate_full_logits=self.generate_full_logits,
output_prune_map=output_prune_map,
enable_dynamic_shape=self.enable_dynamic_shape,
**params,
)
@@ -230,6 +238,12 @@ def __init__(self, **kwargs):
print(unexpected)
print("============= /unexpected ================")

# prune the output layer if output_prune_map is provided
if output_prune_map is not None:
from .source_transformation.prune_output import prune_output_vocab

self.model_ = prune_output_vocab(self.model_, output_prune_map)

def get_eager_model(self):
if self.dtype:
# convert to the type of the provided checkpoint
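
One detail worth noting (illustrative, not part of the diff): JSON object keys are always strings, so a map written with json.dump comes back with string keys and must be converted to int keys before being handed to ModelArgs, which is exactly what the added lines in model.py do. A tiny sketch of that round trip:

import json

raw = json.loads('{"0": 221, "1": 1325}')             # keys come back as strings
output_prune_map = {int(k): v for k, v in raw.items()}
assert output_prune_map == {0: 221, 1: 1325}
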
71 changes: 71 additions & 0 deletions examples/models/llama2/source_transformation/prune_output.py
@@ -0,0 +1,71 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict

import numpy as np

import torch


def prune_output_vocab(
model: torch.nn.Module,
token_map: Dict[int, int],
output_layer_name: str = "output",
) -> torch.nn.Module:
"""Prune the model output linear layer while keeping the tokens in the token map.

Note: Pruning is performed in-place.

Args:
model: The model to prune.
token_map: A dictionary mapping from new token ids to the old token ids to preserve.
e.g. {0: 221, 1: 1325, 2: 1542, 3: 1728, 4: 18243}
output_layer_name: name of the output layer to prune

Returns:
The pruned model.
"""
assert hasattr(
model, output_layer_name
), f"Model does not have {output_layer_name} layer"
output_layer = getattr(model, output_layer_name)
assert isinstance(
output_layer, torch.nn.Linear
), "Output layer is not a linear layer"
original_shape = output_layer.weight.shape
input_features = original_shape[1]
num_pruned_tokens = len(token_map)
has_bias = output_layer.bias is not None
weight_dtype = output_layer.weight.dtype
pruned_layer = torch.nn.Linear(input_features, num_pruned_tokens, bias=has_bias)
pruned_layer.to(dtype=weight_dtype)
pruned_layer_weights = np.zeros(pruned_layer.weight.shape, dtype=np.float32)
pruned_layer_bias = None
if has_bias:
pruned_layer_bias = np.zeros(pruned_layer.bias.shape, dtype=np.float32)
for i, token_id in token_map.items():
# Copy the weights and biases from the original layer to the pruned layer
pruned_wt = output_layer.weight[token_id].detach()
if weight_dtype == torch.bfloat16:
pruned_wt = pruned_wt.float()
pruned_layer_weights[i] = pruned_wt.numpy()
if has_bias:
pruned_bias = output_layer.bias[token_id].detach()
if weight_dtype == torch.bfloat16:
pruned_bias = pruned_bias.float()
pruned_layer_bias[i] = pruned_bias.numpy()
with torch.no_grad():
pruned_layer.weight.copy_(
torch.tensor(pruned_layer_weights, dtype=weight_dtype)
)
if has_bias:
pruned_layer.bias.copy_(torch.tensor(pruned_layer_bias, dtype=weight_dtype))

# Replace the original layer with the pruned layer
setattr(model, output_layer_name, pruned_layer)

return model
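
A hypothetical usage sketch for prune_output_vocab (not part of the PR), applying it to a toy module whose `output` layer stands in for the transformer's LM head. The module, its dimensions, the token ids, and the import path are all assumptions for illustration.

import torch

# Adjust this import to wherever prune_output.py is importable from in your setup.
from prune_output import prune_output_vocab


class ToyModel(torch.nn.Module):
    def __init__(self, hidden_dim: int = 8, vocab_size: int = 100):
        super().__init__()
        # Mimics the transformer's output projection (hidden_dim -> vocab_size).
        self.output = torch.nn.Linear(hidden_dim, vocab_size, bias=False)


# Keep only five tokens; new ids 0..4 map to the original ids shown (made-up values).
token_map = {0: 2, 1: 17, 2: 42, 3: 63, 4: 99}

model = ToyModel()
original_rows = model.output.weight[list(token_map.values())].detach().clone()

model = prune_output_vocab(model, token_map)

assert model.output.weight.shape == (5, 8)
# Row i of the pruned layer equals row token_map[i] of the original layer.
assert torch.equal(model.output.weight.detach(), original_rows)
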