From 33ce0cf9c005e584d4b9cfbb1f4eb1e2277a3155 Mon Sep 17 00:00:00 2001
From: Lunwen He
Date: Fri, 22 Nov 2024 11:09:27 -0800
Subject: [PATCH 1/3] add support to evaluate the model with attention sink

Differential Revision: [D66378384](https://our.internmc.facebook.com/intern/diff/D66378384/)

[ghstack-poisoned]
---
 examples/models/llama/eval_llama_lib.py | 63 +++++++++++++++++++++++++
 1 file changed, 63 insertions(+)

diff --git a/examples/models/llama/eval_llama_lib.py b/examples/models/llama/eval_llama_lib.py
index dd01365ba59..48a00d18449 100644
--- a/examples/models/llama/eval_llama_lib.py
+++ b/examples/models/llama/eval_llama_lib.py
@@ -155,6 +155,44 @@ def _model_call(self, inps):
         pass
 
 
+class AttentionSinkEvalWrapper(EagerEvalWrapper):
+    """
+    A wrapper class for evaluating the model with attention sink.
+    """
+
+    def __init__(
+        self,
+        model: torch.nn.Module,
+        tokenizer: Union[SentencePieceTokenizer, Tiktoken],
+        sink_size: int,
+        window_size: int,
+        max_seq_length: Optional[int] = None,
+        use_kv_cache: bool = False,
+    ):
+        super().__init__(model, tokenizer, max_seq_length, use_kv_cache)
+        self.cache_size = sink_size + window_size
+        assert self._use_kv_cache, "Attention sink only works with kv cache."
+
+    def _model_call(self, inps):
+        # Given inps (tokens), return the logits
+
+        # Example:
+        # inps: Tensor of shape (1, N)
+        # logits: Tensor of shape (1, N, vocab_size)
+        _, seq_len = inps.shape
+        result = self._model(
+            inps[:, : min(seq_len, self.cache_size)],
+            torch.tensor([0], dtype=torch.int64, device=self.device),
+        )
+        for pos in range(min(seq_len, self.cache_size), seq_len):
+            logits = self._model(
+                inps[:, pos : pos + 1],
+                torch.tensor([pos], dtype=torch.int64, device=self.device),
+            )
+            result = torch.cat(result, logits[:, -1, :], dim=1)
+        return result
+
+
 def gen_eval_wrapper(
     model_name: str,
     args: argparse.ArgumentParser,
@@ -225,6 +263,25 @@ def gen_eval_wrapper(
         if args.output_eager_checkpoint_file is not None:  # pyre-ignore
             torch.save(model, args.output_eager_checkpoint_file)
 
+        if (use_attention_sink := args.use_attention_sink) is not None and (
+            attention_sink_eval_length := args.attention_sink_eval_length
+        ) is not None:  # pyre-ignore
+            attention_sink_params = use_attention_sink.split(",")
+            assert len(attention_sink_params) == 3
+            sink_size = int(attention_sink_params[0])
+            window_size = int(attention_sink_params[1])
+
+            assert args.max_seq_length == sink_size + window_size
+
+            return AttentionSinkEvalWrapper(
+                model=model,
+                tokenizer=tokenizer,
+                sink_size=sink_size,
+                window_size=window_size,
+                max_seq_length=attention_sink_eval_length,
+                use_kv_cache=args.use_kv_cache,
+            )
+
         return EagerEvalWrapper(
             model=model,
             tokenizer=tokenizer,
@@ -279,6 +336,12 @@ def build_args_parser() -> argparse.ArgumentParser:
         default=None,
         help="Save the checkpoint after source transformations, for other evaluation platform to run the same checkpoint.",
     )
+    parser.add_argument(
+        "--attention_sink_eval_length",
+        type=int,
+        default=2048,
+        help="The maximum length of the sequence to evaluate with attention sink.",
+    )
 
     return parser
 

From 8fe394a9f4dd31ac4e51fb5f1e4d0f98f584337d Mon Sep 17 00:00:00 2001
From: Lunwen He
Date: Fri, 22 Nov 2024 15:52:15 -0800
Subject: [PATCH 2/3] Update on "add support to evaluate the model with attention sink"

Differential Revision: [D66378384](https://our.internmc.facebook.com/intern/diff/D66378384/)

[ghstack-poisoned]
---
 examples/models/llama/eval_llama_lib.py | 15 ++++++++++-----
 1 file changed, 10 insertions(+), 5 deletions(-)

diff --git a/examples/models/llama/eval_llama_lib.py b/examples/models/llama/eval_llama_lib.py
index 48a00d18449..f933b5dfb8f 100644
--- a/examples/models/llama/eval_llama_lib.py
+++ b/examples/models/llama/eval_llama_lib.py
@@ -6,6 +6,7 @@
 
 
 import argparse
+import copy
 
 from typing import Optional, Union
 
@@ -166,11 +167,13 @@ def __init__(
         tokenizer: Union[SentencePieceTokenizer, Tiktoken],
         sink_size: int,
         window_size: int,
+        eviction_batch_size: int,
         max_seq_length: Optional[int] = None,
         use_kv_cache: bool = False,
     ):
         super().__init__(model, tokenizer, max_seq_length, use_kv_cache)
         self.cache_size = sink_size + window_size
+        self.eviction_batch_size = eviction_batch_size
         assert self._use_kv_cache, "Attention sink only works with kv cache."
 
     def _model_call(self, inps):
@@ -179,17 +182,18 @@ def _model_call(self, inps):
         # Example:
         # inps: Tensor of shape (1, N)
         # logits: Tensor of shape (1, N, vocab_size)
+        model = copy.deepcopy(self._model)
         _, seq_len = inps.shape
-        result = self._model(
+        result = model(
             inps[:, : min(seq_len, self.cache_size)],
             torch.tensor([0], dtype=torch.int64, device=self.device),
         )
-        for pos in range(min(seq_len, self.cache_size), seq_len):
-            logits = self._model(
-                inps[:, pos : pos + 1],
+        for pos in range(min(seq_len, self.cache_size), seq_len, self.eviction_batch_size):
+            logits = model(
+                inps[:, pos : pos + self.eviction_batch_size],
                 torch.tensor([pos], dtype=torch.int64, device=self.device),
             )
-            result = torch.cat(result, logits[:, -1, :], dim=1)
+            result = torch.cat((result, logits[:, -self.eviction_batch_size:, :]), dim=1)
         return result
 
 
@@ -270,6 +274,7 @@ def gen_eval_wrapper(
             assert len(attention_sink_params) == 3
             sink_size = int(attention_sink_params[0])
             window_size = int(attention_sink_params[1])
+            eviction_batch_size = int(attention_sink_params[2])
 
             assert args.max_seq_length == sink_size + window_size
 

From 1e1d9c8165052f566b0bee649534d697e27665ca Mon Sep 17 00:00:00 2001
From: Lunwen He
Date: Fri, 22 Nov 2024 17:05:25 -0800
Subject: [PATCH 3/3] Update on "add support to evaluate the model with attention sink"

Differential Revision: [D66378384](https://our.internmc.facebook.com/intern/diff/D66378384/)

[ghstack-poisoned]
---
 examples/models/llama/eval_llama_lib.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/examples/models/llama/eval_llama_lib.py b/examples/models/llama/eval_llama_lib.py
index f933b5dfb8f..b0544064aa4 100644
--- a/examples/models/llama/eval_llama_lib.py
+++ b/examples/models/llama/eval_llama_lib.py
@@ -283,6 +283,7 @@ def gen_eval_wrapper(
                 tokenizer=tokenizer,
                 sink_size=sink_size,
                 window_size=window_size,
+                eviction_batch_size=eviction_batch_size,
                 max_seq_length=attention_sink_eval_length,
                 use_kv_cache=args.use_kv_cache,
             )
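
For readers following the series, here is a minimal standalone sketch of the chunked decoding that `AttentionSinkEvalWrapper._model_call` performs after patch 2: prefill everything that fits in `sink_size + window_size`, then feed the remaining tokens in `eviction_batch_size` chunks and concatenate the new logits. It assumes a `model(tokens, start_pos)` callable that returns logits of shape `(1, T, vocab_size)`, mirroring how the wrapper invokes `self._model`; the function name `attention_sink_logits` and its parameters are illustrative only, not part of the patch.

```python
import copy

import torch


def attention_sink_logits(
    model,                    # assumed callable: model(tokens, start_pos) -> (1, T, vocab_size) logits
    inps: torch.Tensor,       # (1, N) token ids for one evaluation request
    cache_size: int,          # sink_size + window_size
    eviction_batch_size: int,
    device: str = "cpu",
) -> torch.Tensor:
    # Work on a copy so the attention-sink KV cache starts fresh for this request,
    # mirroring the copy.deepcopy added in patch 2.
    model = copy.deepcopy(model)
    _, seq_len = inps.shape
    prefill_len = min(seq_len, cache_size)

    # Prefill: everything that fits in sink + window goes through in one call at start_pos 0.
    result = model(
        inps[:, :prefill_len],
        torch.tensor([0], dtype=torch.int64, device=device),
    )

    # Decode the remainder in eviction-sized chunks; each call advances start_pos,
    # and the sink cache is expected to evict old window entries as new tokens arrive.
    for pos in range(prefill_len, seq_len, eviction_batch_size):
        logits = model(
            inps[:, pos : pos + eviction_batch_size],
            torch.tensor([pos], dtype=torch.int64, device=device),
        )
        # Keep the logits for the newly fed tokens; if the final chunk is shorter than
        # eviction_batch_size, the slice simply keeps all of it.
        result = torch.cat((result, logits[:, -eviction_batch_size:, :]), dim=1)

    # result: (1, N, vocab_size), one logit row per input token, as lm-eval expects.
    return result
```

The `copy.deepcopy` in patch 2 presumably gives each evaluation request a fresh attention-sink KV cache so state does not leak between sequences, and stepping in `eviction_batch_size` chunks matches the third field of the comma-separated `args.use_attention_sink` value ("sink_size,window_size,eviction_batch_size") parsed in `gen_eval_wrapper`.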