From ca5cefa2a7fdf34cb43ac6a5813031c7073c0189 Mon Sep 17 00:00:00 2001
From: Lunwen He
Date: Mon, 2 Dec 2024 13:01:16 -0800
Subject: [PATCH 1/3] Transform model to be able to use Attention Sink

Pull Request resolved: https://github.com/pytorch/executorch/pull/6700

This PR adds the functions needed to transform the model so that it can use
Attention Sink.

ghstack-source-id: 256108077
@exported-using-ghexport

Differential Revision: [D65571289](https://our.internmc.facebook.com/intern/diff/D65571289/)
---
 examples/models/llama/export_llama_lib.py     |   7 ++
 examples/models/llama/model.py                |  19 +++
 .../source_transformation/attention_sink.py   | 118 +++++++++++++++++-
 3 files changed, 143 insertions(+), 1 deletion(-)

diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py
index 9a290968a35..ea4296cc52c 100644
--- a/examples/models/llama/export_llama_lib.py
+++ b/examples/models/llama/export_llama_lib.py
@@ -448,6 +448,13 @@ def build_args_parser() -> argparse.ArgumentParser:
         help="type of embedding quantization for pre-quantized checkpoint, '<bitwidth>,<groupsize>', e.g., '8,1024'.",
     )
 
+    parser.add_argument(
+        "--use_attention_sink",
+        default=None,
+        type=str,
+        help="Use attention sink to have fluent multi-round conversation. '<sink_size>,<window_size>,<eviction_batch_size>', e.g., '4,2044,1024'.",
+    )
+
     parser.add_argument(
         "--output_prune_map",
         default=None,
diff --git a/examples/models/llama/model.py b/examples/models/llama/model.py
index 0f83e404a3c..2385aba6d5d 100644
--- a/examples/models/llama/model.py
+++ b/examples/models/llama/model.py
@@ -201,6 +201,25 @@ def __init__(self, **kwargs):
 
             sanitize_checkpoint_from_pre_quantization(checkpoint)
 
+        if hasattr(self.args, "use_attention_sink") and self.args.use_attention_sink:
+            from .source_transformation.attention_sink import enable_attention_sink
+
+            attention_sink_params = self.args.use_attention_sink.split(",")
+            assert len(attention_sink_params) == 3
+            sink_size = int(attention_sink_params[0])
+            window_size = int(attention_sink_params[1])
+            eviction_batch_size = int(attention_sink_params[2])
+
+            assert self.args.max_seq_length == sink_size + window_size
+
+            self.model_ = enable_attention_sink(
+                module=self.model_,
+                params=model_args,
+                sink_size=sink_size,
+                window_size=window_size,
+                eviction_batch_size=eviction_batch_size,
+            )
+
         # assign=True: load params/buffers by assignment instead of performing an in-place copy.
         # Because we are using device="meta", tensors do not have memory associated with them
         # and an in-place copy is a no-op. Use assign=True in load_state_dict for this scenario.
diff --git a/examples/models/llama/source_transformation/attention_sink.py b/examples/models/llama/source_transformation/attention_sink.py
index 8450600d2b1..b534a98e078 100644
--- a/examples/models/llama/source_transformation/attention_sink.py
+++ b/examples/models/llama/source_transformation/attention_sink.py
@@ -7,15 +7,22 @@
 # Components for supporting Attention Sink. See
 # https://arxiv.org/abs/2309.17453 for more details about Attention Sink.
+import types
 from typing import Optional
 
 import torch
 
-from executorch.examples.models.llama.llama_transformer import KVCache, ModelArgs, Rope
+from executorch.examples.models.llama.llama_transformer import (
+    Attention,
+    KVCache,
+    ModelArgs,
+    Rope,
+)
 from executorch.examples.models.llama.rope import (
     apply_rotary_emb_to_k,
     hf_apply_rotary_emb_to_k,
 )
+from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
 
 
 class RopeWithAttentionSink(Rope):
@@ -206,3 +213,112 @@ def evict_tokens(self, input_pos: torch.Tensor, seq_len: int) -> int:
             )
             self.position_shift -= num_to_evict  # pyre-ignore [8]
         return self.position_shift
+
+
+def attention_sink_forward(
+    self,
+    x: torch.Tensor,
+    freqs_cos: torch.Tensor,
+    freqs_sin: torch.Tensor,
+    input_pos: Optional[torch.Tensor] = None,
+):
+    assert self.use_kv_cache
+    assert input_pos is not None
+
+    bsz, seqlen, _ = x.shape
+
+    # QKV
+    q, k, v = self.wq(x), self.wk(x), self.wv(x)
+    # We need view_copy elimination
+    q = q.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+    k = k.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
+    v = v.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
+
+    # Prepare for space in KV cache and get position shift
+    position_shift = self.kv_cache.evict_tokens(input_pos, seqlen)
+
+    # RoPE relative positional embeddings with shifted position in KV cache
+    q, k = self.rope.forward(q, k, freqs_cos, freqs_sin)
+
+    output = self.SDPA(input_pos + position_shift, q, k, v, bsz, seqlen, self.mask)
+    return self.wo(output)
+
+
+def _replace_rope(
+    module: torch.nn.Module, rope_with_attention_sink: RopeWithAttentionSink
+):
+    def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool:
+        return isinstance(child, Rope)
+
+    def replacement_fn(child: torch.nn.Module) -> torch.nn.Module:
+        return rope_with_attention_sink
+
+    _replace_with_custom_fn_if_matches_filter(module, replacement_fn, filter_fn)
+
+
+def _replace_attention(
+    module: torch.nn.Module,
+    rope_with_attention_sink: RopeWithAttentionSink,
+    sink_size: int,
+    window_size: int,
+    eviction_batch_size: int,
+):
+    for _, child_module in module._modules.items():
+        if len(list(child_module.children())) > 0:  # pyre-ignore [16]
+            _replace_attention(
+                module=child_module,  # pyre-ignore [6]
+                rope_with_attention_sink=rope_with_attention_sink,
+                sink_size=sink_size,
+                window_size=window_size,
+                eviction_batch_size=eviction_batch_size,
+            )
+
+        if isinstance(child_module, Attention):
+            kv_cache = child_module.kv_cache
+            kv_cache_with_attention_sink = KVCacheWithAttentionSink(
+                n_heads=kv_cache.n_heads,
+                head_dim=kv_cache.head_dim,
+                transpose_cache=kv_cache.transpose_cache,
+                enable_dynamic_shape=kv_cache.enable_dynamic_shape,
+                rope=rope_with_attention_sink,
+                max_batch_size=kv_cache.max_batch_size,
+                window_size=window_size,
+                sink_size=sink_size,
+                eviction_batch_size=eviction_batch_size,
+                dtype=kv_cache.k_cache.dtype,
+            )
+            child_module.kv_cache = kv_cache_with_attention_sink
+            child_module.SDPA.kv_cache = kv_cache_with_attention_sink
+            child_module.forward = types.MethodType(  # pyre-ignore
+                attention_sink_forward, child_module
+            )
+
+
+def enable_attention_sink(
+    module: torch.nn.Module,
+    params: ModelArgs,
+    sink_size: int,
+    window_size: int,
+    eviction_batch_size: int,
+) -> torch.nn.Module:
+    """
+    Transform the model to be able to run inference with Attention Sink.
+    There are mainly three steps:
+    - Replace Rope with RopeWithAttentionSink
+    - Replace Attention's KVCache with KVCacheWithAttentionSink
+    - Replace Attention's forward with attention_sink_forward
+    """
+    rope_with_attention_sink = RopeWithAttentionSink(
+        params=params,
+        window_size=window_size,
+        sink_size=sink_size,
+        eviction_batch_size=eviction_batch_size,
+    )
+    _replace_rope(module, rope_with_attention_sink)
+    _replace_attention(
+        module=module,
+        rope_with_attention_sink=rope_with_attention_sink,
+        sink_size=sink_size,
+        window_size=window_size,
+        eviction_batch_size=eviction_batch_size,
+    )
+    return module

From 773813ad5fd1a8265b2457ab847a8d17aef788b7 Mon Sep 17 00:00:00 2001
From: Lunwen He
Date: Mon, 2 Dec 2024 13:01:17 -0800
Subject: [PATCH 2/3] Update eager runner to support AttentionSink

Pull Request resolved: https://github.com/pytorch/executorch/pull/6921

This PR updates the eager runner to support AttentionSink. It also fixes
issues in the `chat_completion` function so that the position id is handled
properly.

ghstack-source-id: 256108078

Differential Revision: [D66076486](https://our.internmc.facebook.com/intern/diff/D66076486/)
---
 examples/models/llama/runner/eager.py      |  6 ++++-
 examples/models/llama/runner/generation.py | 27 ++++++++++++++--------
 2 files changed, 22 insertions(+), 11 deletions(-)

diff --git a/examples/models/llama/runner/eager.py b/examples/models/llama/runner/eager.py
index 7b4ebf36a56..559b4e04892 100644
--- a/examples/models/llama/runner/eager.py
+++ b/examples/models/llama/runner/eager.py
@@ -84,7 +84,11 @@ def execute_runner(runner_class: Type[LlamaRunner]) -> None:
     with torch.no_grad():
         runner = runner_class(args)  # pyre-ignore: Missing argument [20]
         generated_tokens = (
-            runner.chat_completion(temperature=args.temperature)
+            runner.chat_completion(
+                max_seq_len=1000000 if args.use_attention_sink else args.max_seq_length,
+                temperature=args.temperature,
+                show_progress=args.show_tokens,
+            )
             if args.chat
             else runner.text_completion(
                 prompt=args.prompt,
diff --git a/examples/models/llama/runner/generation.py b/examples/models/llama/runner/generation.py
index 13ac750305f..891ce20db3e 100644
--- a/examples/models/llama/runner/generation.py
+++ b/examples/models/llama/runner/generation.py
@@ -168,18 +168,19 @@ def text_completion(
 
     def chat_completion(
         self,
+        max_seq_len: int,
         temperature: float = 0.6,
         top_p: float = 0.9,
+        show_progress: bool = False,
    ) -> List[int]:
         """
         Perform multi-turn chat with the language model.
 
         Args:
-            prompt (str): Text prompt for completion.
+            max_seq_len (int): Maximum number of tokens to generate for each prompt.
             temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6.
             top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9.
-            echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False.
-
+            show_progress (bool, optional): Flag indicating whether to show the number of tokens generated.
         Returns:
             Generated list of tokens.
@@ -188,20 +189,26 @@ def chat_completion(
         """
         exit_prompt = "exit"
         tokens = []
+        pre_stop_token = []
         prompt = input("Me: ")
         while prompt and prompt != exit_prompt:
             print("LLM: ", end="", flush=True)
-            new_tokens = self.generate(
-                prompt_tokens=self.tokenizer.encode(
-                    self._format_prompt(prompt), bos=True, eos=False
-                ),
-                max_seq_len=self.max_seq_len,
+            prompt_tokens = self.tokenizer.encode(
+                self._format_prompt(prompt), bos=True, eos=False
+            )
+            generated_tokens = self.generate(
+                prompt_tokens=pre_stop_token + prompt_tokens,
+                max_seq_len=max_seq_len,
                 temperature=temperature,
                 top_p=top_p,
-                echo=True,
+                echo=False,
                 pos_base=len(tokens) - 1 if len(tokens) > 0 else 0,
             )
-            tokens.extend(new_tokens)
+            pre_stop_token = generated_tokens[-1:]
+            tokens.extend(prompt_tokens)
+            tokens.extend(generated_tokens)
+            if show_progress:
+                print(f"[Generated {len(tokens)} tokens]")
             prompt = input("Me: ")
         return tokens

From 5cd90a169a0974efd929b1bb9026b9aec3a8369d Mon Sep 17 00:00:00 2001
From: pytorchbot
Date: Mon, 2 Dec 2024 17:24:42 -0800
Subject: [PATCH 3/3] add eval for attention sink (#7150)

Pull Request resolved: https://github.com/pytorch/executorch/pull/7070

This PR adds a function to evaluate the model's perplexity when AttentionSink
is enabled. It is mostly copied from
https://github.com/mit-han-lab/streaming-llm/blob/main/examples/eval_long_ppl.py,
which the Attention Sink paper uses for the same purpose.

ghstack-source-id: 256108079
@exported-using-ghexport

Differential Revision: [D66474732](https://our.internmc.facebook.com/intern/diff/D66474732/)

Co-authored-by: Lunwen He
---
 examples/models/llama/TARGETS           |  2 +
 examples/models/llama/eval_llama.py     | 11 ++++-
 examples/models/llama/eval_llama_lib.py | 64 +++++++++++++++++++++++++
 3 files changed, 75 insertions(+), 2 deletions(-)

diff --git a/examples/models/llama/TARGETS b/examples/models/llama/TARGETS
index 284520d4d5e..445bcd673bf 100644
--- a/examples/models/llama/TARGETS
+++ b/examples/models/llama/TARGETS
@@ -150,6 +150,8 @@ runtime.python_library(
         "@EXECUTORCH_CLIENTS",
     ],
     deps = [
+        "fbsource//third-party/pypi/tqdm:tqdm",
+        "fbsource//third-party/pypi/datasets:datasets",
         "fbsource//third-party/pypi/lm-eval:lm-eval",
         "fbsource//third-party/pypi/tiktoken:tiktoken",
         ":export_library",
diff --git a/examples/models/llama/eval_llama.py b/examples/models/llama/eval_llama.py
index 09157789bde..7c959d08b9b 100644
--- a/examples/models/llama/eval_llama.py
+++ b/examples/models/llama/eval_llama.py
@@ -10,7 +10,11 @@
 
 import torch
 
-from .eval_llama_lib import build_args_parser, eval_llama
+from .eval_llama_lib import (
+    build_args_parser,
+    eval_llama,
+    eval_llama_with_attention_sink,
+)
 
 FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
 logging.basicConfig(level=logging.INFO, format=FORMAT)
@@ -24,7 +28,10 @@ def main() -> None:
     args = parser.parse_args()
     # Overrides this arg, because evaluation requires full logits.
     args.generate_full_logits = True
 
-    eval_llama(modelname, args)  # pyre-ignore
+    if args.use_attention_sink:
+        eval_llama_with_attention_sink(modelname, args)  # pyre-ignore
+    else:
+        eval_llama(modelname, args)  # pyre-ignore
 
 
 if __name__ == "__main__":
diff --git a/examples/models/llama/eval_llama_lib.py b/examples/models/llama/eval_llama_lib.py
index dd01365ba59..a7f0f88cd9a 100644
--- a/examples/models/llama/eval_llama_lib.py
+++ b/examples/models/llama/eval_llama_lib.py
@@ -10,6 +10,8 @@
 from typing import Optional, Union
 
 import torch
+
+from datasets import load_dataset
 from executorch.examples.models.llama.export_llama_lib import (
     get_quantizer_and_quant_params,
 )
@@ -21,6 +23,8 @@
 )
 from executorch.extension.llm.tokenizer.utils import get_tokenizer
 from lm_eval.evaluator import simple_evaluate
+from torch.nn import CrossEntropyLoss
+from tqdm import tqdm
 
 from .evaluate.eager_eval import EagerEvalWrapper
 
@@ -280,6 +284,9 @@ def build_args_parser() -> argparse.ArgumentParser:
         help="Save the checkpoint after source transformations, for other evaluation platform to run the same checkpoint.",
     )
 
+    # Set of parameters specific to AttentionSink.
+    parser.add_argument("--attention_sink_eval_tokens", type=int, default=0)
+
     return parser
 
 
@@ -309,3 +316,60 @@ def eval_llama(
 
     for task, res in eval_results["results"].items():
         print(f"{task}: {res}")
+
+
+def eval_llama_with_attention_sink(model_name: str, args: argparse.ArgumentParser):
+    """
+    Evaluate the model's perplexity when AttentionSink is enabled.
+
+    This is mostly copied from https://github.com/mit-han-lab/streaming-llm/blob/main/examples/eval_long_ppl.py
+    """
+    assert args.use_attention_sink is not None  # pyre-ignore [16]
+    assert args.attention_sink_eval_tokens > 0  # pyre-ignore [16]
+    attention_sink_params = args.use_attention_sink.split(",")
+    assert len(attention_sink_params) == 3
+    sink_size = int(attention_sink_params[0])
+    window_size = int(attention_sink_params[1])
+
+    assert args.max_seq_length == sink_size + window_size  # pyre-ignore [16]
+
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    manager: LLMEdgeManager = _prepare_for_llama_export(args)
+    model = manager.model.eval().to(device=device)
+    tokenizer = get_tokenizer(args.tokenizer_path)  # pyre-ignore [16]
+
+    eval_data = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
+
+    nlls = []
+    loss_fn = CrossEntropyLoss(reduction="none")
+    progress_bar = tqdm(total=args.attention_sink_eval_tokens)
+    input_pos = 0
+    while input_pos < args.attention_sink_eval_tokens:
+        for text in eval_data["text"]:  # pyre-ignore [16]
+            tokens = tokenizer.encode(text, bos=False, eos=False)
+            if len(tokens) <= 0:
+                continue
+            with torch.no_grad():
+                num_tokens = min(
+                    len(tokens) - 1, args.attention_sink_eval_tokens - input_pos
+                )
+                logits = model(
+                    torch.tensor(
+                        [tokens[:num_tokens]], dtype=torch.int64, device=device
+                    ),
+                    torch.tensor([input_pos], dtype=torch.int64, device=device),
+                ).squeeze(dim=0)
+                neg_log_likelihood = loss_fn(
+                    logits,
+                    torch.tensor(
+                        [tokens[1 : num_tokens + 1]], dtype=torch.int64, device=device
+                    ).view(-1),
+                )
+                nlls.append(neg_log_likelihood)
+                input_pos += num_tokens
+                progress_bar.update(num_tokens)
+            if input_pos >= args.attention_sink_eval_tokens:
+                break
+    ppl = torch.exp(torch.cat(nlls).mean())
+    print(f"Perplexity: {ppl.item()}")
+    return ppl.item()
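
For reference, below is a minimal sketch of how the transformation from PATCH 1/3 could be driven outside of model.py. The helper name `apply_attention_sink` and the assumption that an eager Llama module and its ModelArgs already exist are illustrative; only `enable_attention_sink` and the '<sink_size>,<window_size>,<eviction_batch_size>' convention come from the patch itself.

```python
# Hypothetical driver mirroring the logic model.py runs when --use_attention_sink is set.
# `model` (an eager Llama module) and `params` (its ModelArgs) are assumed to already
# exist, e.g. produced by the usual checkpoint-loading path.
from executorch.examples.models.llama.source_transformation.attention_sink import (
    enable_attention_sink,
)


def apply_attention_sink(model, params, use_attention_sink: str, max_seq_length: int):
    # Parse the "<sink_size>,<window_size>,<eviction_batch_size>" string.
    sink_size, window_size, eviction_batch_size = (
        int(v) for v in use_attention_sink.split(",")
    )
    # The KV cache must hold exactly the sink tokens plus the sliding window.
    assert max_seq_length == sink_size + window_size
    # Replaces Rope, each Attention's KVCache, and each Attention's forward in place.
    return enable_attention_sink(
        module=model,
        params=params,
        sink_size=sink_size,
        window_size=window_size,
        eviction_batch_size=eviction_batch_size,
    )


# Example: apply_attention_sink(model, params, "4,2044,1024", max_seq_length=2048)
```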
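The eviction behaviour that `KVCacheWithAttentionSink.evict_tokens` and `attention_sink_forward` rely on can be pictured with a small, cache-free model of the bookkeeping. This is only an illustration of the policy from the Attention Sink paper (keep the first `sink_size` tokens, keep the most recent window, evict older tokens in multiples of `eviction_batch_size`); it is not the tensor-level implementation in the patch.

```python
def evict(kept_positions, sink_size, window_size, eviction_batch_size, incoming):
    """Return (positions still cached, number of evicted tokens) after `incoming` new tokens."""
    capacity = sink_size + window_size
    overflow = len(kept_positions) + incoming - capacity
    if overflow <= 0:
        return kept_positions, 0  # enough free space, nothing to evict
    # Evict in multiples of eviction_batch_size, rounding the overflow up.
    num_to_evict = -(-overflow // eviction_batch_size) * eviction_batch_size
    sinks = kept_positions[:sink_size]  # always keep the attention-sink tokens
    rest = kept_positions[sink_size + num_to_evict :]  # drop the oldest non-sink tokens
    return sinks + rest, num_to_evict


# With sink_size=4 and window_size=2044 the cache holds 2048 tokens; one more incoming
# token triggers eviction of a full batch of 1024 of the oldest non-sink tokens.
positions, evicted = evict(list(range(2048)), 4, 2044, 1024, incoming=1)
assert evicted == 1024 and len(positions) == 1024
```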
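Finally, the perplexity reported by `eval_llama_with_attention_sink` in PATCH 3/3 is the exponential of the mean per-token negative log-likelihood accumulated over the streamed text. A minimal sketch of that computation, with random tensors standing in for real model logits:

```python
import torch
from torch.nn import CrossEntropyLoss

loss_fn = CrossEntropyLoss(reduction="none")

# Placeholder data: 5 next-token predictions over a 32000-token vocabulary.
logits = torch.randn(5, 32000)
targets = torch.randint(0, 32000, (5,))

nlls = [loss_fn(logits, targets)]        # one chunk of per-token losses, shape [5]
ppl = torch.exp(torch.cat(nlls).mean())  # perplexity = exp(mean NLL)
print(f"Perplexity: {ppl.item()}")
```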