diff --git a/examples/models/llama2/TARGETS b/examples/models/llama2/TARGETS index f1c56a5bda3..7ed858a33c5 100644 --- a/examples/models/llama2/TARGETS +++ b/examples/models/llama2/TARGETS @@ -71,6 +71,7 @@ runtime.python_library( "export_llama_lib.py", "model.py", "source_transformation/apply_spin_quant_r1_r2.py", + "source_transformation/prune_output.py", "source_transformation/quantize.py", "source_transformation/rms_norm.py", "source_transformation/rope.py", diff --git a/examples/models/llama2/export_llama_lib.py b/examples/models/llama2/export_llama_lib.py index 5cef72c1e6e..ff917ab5750 100644 --- a/examples/models/llama2/export_llama_lib.py +++ b/examples/models/llama2/export_llama_lib.py @@ -369,6 +369,12 @@ def build_args_parser() -> argparse.ArgumentParser: choices=["cuda", "native"], help="Use SpinQuant for better quantization performance. Only support cuda and native.", ) + + parser.add_argument( + "--output_prune_map", + default=None, + help="path to the output pruning token mapping file (token_map.json)", + ) return parser @@ -458,6 +464,7 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager: tokenizer_path=args.tokenizer_path, verbose=args.verbose, max_seq_len=args.max_seq_length, + output_prune_map_path=args.output_prune_map, metadata_str=args.metadata, args=args, ) @@ -682,6 +689,7 @@ def _load_llama_model( tokenizer_path: Optional[str] = None, verbose: bool = False, max_seq_len: int = 128, + output_prune_map_path: Optional[str] = None, metadata_str: Optional[str] = None, args, ) -> "LLMEdgeManager": @@ -709,6 +717,7 @@ def _load_llama_model( fairseq2=weight_type == WeightType.FAIRSEQ2, max_seq_len=max_seq_len, enable_dynamic_shape=enable_dynamic_shape, + output_prune_map_path=output_prune_map_path, args=args, ) state_dict = model.state_dict() diff --git a/examples/models/llama2/llama_transformer.py b/examples/models/llama2/llama_transformer.py index 3c75b9c75f3..65090e2fe5a 100644 --- a/examples/models/llama2/llama_transformer.py +++ b/examples/models/llama2/llama_transformer.py @@ -9,7 +9,7 @@ from dataclasses import dataclass from functools import partial -from typing import Optional, Tuple +from typing import Dict, Optional, Tuple import torch import torch.nn.functional as F @@ -102,6 +102,8 @@ class ModelArgs: # logits for all input tokens.) generate_full_logits: bool = False enable_dynamic_shape: bool = False # export model with dynamic shape support + # A dictionary mapping from pruned token-id to original token-id + output_prune_map: Optional[Dict[int, int]] = None use_hf_rope: bool = False # Use HuggingFace's RoPE implementation rope_theta: Optional[float] = ( None # The official name to override self.rope_freq_base. @@ -449,6 +451,7 @@ def __init__(self, params: ModelArgs): self.use_kv_cache = params.use_kv_cache self.generate_full_logits = params.generate_full_logits self.max_seq_len = params.max_seq_len + self.output_prune_map = params.output_prune_map if params.use_hf_rope: self.precompute_freqs_cis = hf_precompute_freqs_cis else: @@ -525,4 +528,27 @@ def forward( h = self.norm(h) logits = self.output(h) + + if self.output_prune_map is not None: + # expand to original size so that downstream applications can use the logits as-is. 
+ if self.generate_full_logits: + # (1, seq_len, pruned_size) -> (1, seq_len, original_size) + expanded_logits = torch.full( + [logits.shape[0], logits.shape[1], self.vocab_size], + float("-inf"), + device=logits.device, + dtype=logits.dtype, + ) + expanded_logits[:, :, list(self.output_prune_map.values())] = logits + else: + # (1, pruned_size) -> (1, original_size) + expanded_logits = torch.full( + [logits.shape[0], self.vocab_size], + float("-inf"), + device=logits.device, + dtype=logits.dtype, + ) + expanded_logits[:, list(self.output_prune_map.values())] = logits + logits = expanded_logits + return logits diff --git a/examples/models/llama2/model.py b/examples/models/llama2/model.py index 174f562f93a..21714a9c159 100644 --- a/examples/models/llama2/model.py +++ b/examples/models/llama2/model.py @@ -63,6 +63,7 @@ def __init__(self, **kwargs): self.use_sdpa_with_kv_cache_op = kwargs.get("use_sdpa_with_kv_cache", False) self.generate_full_logits = kwargs.get("generate_full_logits", False) self.enable_dynamic_shape = kwargs.get("enable_dynamic_shape", False) + self.output_prune_map_path = kwargs.get("output_prune_map_path", None) self.max_seq_len = kwargs.get("max_seq_len", 128) self.args = kwargs.get("args", None) @@ -141,6 +142,12 @@ def __init__(self, **kwargs): ) with open(params_path, "r") as f: params = json.loads(f.read()) + output_prune_map = None + if self.output_prune_map_path is not None: + with open(self.output_prune_map_path, "r") as f: + output_prune_map = json.load(f) + # change keys from string to int (json only supports string keys) + output_prune_map = {int(k): v for (k, v) in output_prune_map.items()} max_seq_len = self.max_seq_len max_batch_size = 1 model_args: ModelArgs = ModelArgs( @@ -149,6 +156,7 @@ def __init__(self, **kwargs): use_kv_cache=self.use_kv_cache, use_sdpa_with_kv_cache_op=self.use_sdpa_with_kv_cache_op, generate_full_logits=self.generate_full_logits, + output_prune_map=output_prune_map, enable_dynamic_shape=self.enable_dynamic_shape, **params, ) @@ -230,6 +238,12 @@ def __init__(self, **kwargs): print(unexpected) print("============= /unexpected ================") + # prune the output layer if output_prune_map is provided + if output_prune_map is not None: + from .source_transformation.prune_output import prune_output_vocab + + self.model_ = prune_output_vocab(self.model_, output_prune_map) + def get_eager_model(self): if self.dtype: # convert to the type of the provided checkpoint diff --git a/examples/models/llama2/source_transformation/prune_output.py b/examples/models/llama2/source_transformation/prune_output.py new file mode 100644 index 00000000000..6d02d52fa5c --- /dev/null +++ b/examples/models/llama2/source_transformation/prune_output.py @@ -0,0 +1,71 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Dict + +import numpy as np + +import torch + + +def prune_output_vocab( + model: torch.nn.Module, + token_map: Dict[int, int], + output_layer_name: str = "output", +) -> torch.nn.Module: + """Prune the model output linear layer while keeping the tokens in the token map. + + Note: Pruning is performed in-place. + + Args: + model: The model to prune. + token_map: A dictionary mapping from new token ids to the old token ids to preserve. + e.g. 
{0: 221, 1: 1325, 2: 1542, 3: 1728, 4: 18243} + output_layer_name: name of the output layer to prune + + Returns: + The pruned model. + """ + assert hasattr( + model, output_layer_name + ), f"Model does not have {output_layer_name} layer" + output_layer = getattr(model, output_layer_name) + assert isinstance( + output_layer, torch.nn.Linear + ), "Output layer is not a linear layer" + original_shape = output_layer.weight.shape + input_features = original_shape[1] + num_pruned_tokens = len(token_map) + has_bias = output_layer.bias is not None + weight_dtype = output_layer.weight.dtype + pruned_layer = torch.nn.Linear(input_features, num_pruned_tokens, bias=has_bias) + pruned_layer.to(dtype=weight_dtype) + pruned_layer_weights = np.zeros(pruned_layer.weight.shape, dtype=np.float32) + pruned_layer_bias = None + if has_bias: + pruned_layer_bias = np.zeros(pruned_layer.bias.shape, dtype=np.float32) + for i, token_id in token_map.items(): + # Copy the weights and biases from the original layer to the pruned layer + pruned_wt = output_layer.weight[token_id].detach() + if weight_dtype == torch.bfloat16: + pruned_wt = pruned_wt.float() + pruned_layer_weights[i] = pruned_wt.numpy() + if has_bias: + pruned_bias = output_layer.bias[token_id].detach() + if weight_dtype == torch.bfloat16: + pruned_bias = pruned_bias.float() + pruned_layer_bias[i] = pruned_bias.numpy() + with torch.no_grad(): + pruned_layer.weight.copy_( + torch.tensor(pruned_layer_weights, dtype=weight_dtype) + ) + if has_bias: + pruned_layer.bias.copy_(torch.tensor(pruned_layer_bias, dtype=weight_dtype)) + + # Replace the original layer with the pruned layer + setattr(model, output_layer_name, pruned_layer) + + return model
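
Usage sketch (not part of the diff): the snippet below shows how `prune_output_vocab` and the logits re-expansion added to `Transformer.forward` fit together. `TinyModel`, its dimensions, and the `token_map` values are made up for illustration, and the import path assumes the repository root is on `PYTHONPATH`. At export time the same mapping would instead be supplied as a JSON file via the new `--output_prune_map` flag (keys are pruned token ids as strings, values are original token ids, as handled in `model.py` above).

```python
# Hypothetical sketch: a toy module standing in for the Llama transformer, used to
# show what prune_output_vocab does to the "output" layer and how the pruned logits
# are expanded back to the full vocabulary (mirroring the forward() change above).
import torch

from examples.models.llama2.source_transformation.prune_output import prune_output_vocab


class TinyModel(torch.nn.Module):
    def __init__(self, dim: int = 8, vocab_size: int = 10):
        super().__init__()
        self.vocab_size = vocab_size
        # The attribute name must match the `output_layer_name` argument ("output" by default).
        self.output = torch.nn.Linear(dim, vocab_size, bias=False)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        return self.output(h)


# Keep only 3 of the 10 vocab entries: pruned token id -> original token id.
token_map = {0: 2, 1: 5, 2: 9}

model = prune_output_vocab(TinyModel(), token_map)
assert model.output.weight.shape == (3, 8)  # output layer now projects to the pruned vocab

h = torch.randn(1, 8)
pruned_logits = model(h)  # shape: (1, 3)

# Re-expand to the original vocab size, as Transformer.forward now does when
# output_prune_map is set; pruned-away tokens receive -inf logits so downstream
# sampling code can consume the tensor unchanged.
expanded = torch.full((1, model.vocab_size), float("-inf"))
expanded[:, list(token_map.values())] = pruned_logits
assert expanded.shape == (1, 10)
```

Note that the expansion indexes `token_map.values()` positionally, so the mapping file is expected to enumerate pruned ids in order (0, 1, 2, ...), matching the row order produced by `prune_output_vocab`.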