examples/models/llama/eval_llama_lib.py: 69 additions & 0 deletions
@@ -6,6 +6,7 @@


import argparse
import copy

from typing import Optional, Union

@@ -155,6 +156,47 @@ def _model_call(self, inps):
pass


class AttentionSinkEvalWrapper(EagerEvalWrapper):
    """
    A wrapper class for evaluating the model with attention sink.

    The KV cache holds ``sink_size + window_size`` tokens; once the prompt is
    longer than that, the remaining tokens are fed in chunks of
    ``eviction_batch_size`` so the attention sink cache can evict older
    window entries between chunks.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        tokenizer: Union[SentencePieceTokenizer, Tiktoken],
        sink_size: int,
        window_size: int,
        eviction_batch_size: int,
        max_seq_length: Optional[int] = None,
        use_kv_cache: bool = False,
    ):
        super().__init__(model, tokenizer, max_seq_length, use_kv_cache)
        self.cache_size = sink_size + window_size
        self.eviction_batch_size = eviction_batch_size
        assert self._use_kv_cache, "Attention sink only works with kv cache."

    def _model_call(self, inps):
        # Given inps (tokens), return the logits.
        #
        # Example:
        # inps: Tensor of shape (1, N)
        # logits: Tensor of shape (1, N, vocab_size)
        #
        # Run on a deep copy so that the KV cache mutated during this call does
        # not leak into subsequent calls on the shared model.
        model = copy.deepcopy(self._model)
        _, seq_len = inps.shape
        # Prefill up to cache_size tokens in a single forward pass.
        result = model(
            inps[:, : min(seq_len, self.cache_size)],
            torch.tensor([0], dtype=torch.int64, device=self.device),
        )
        # Feed the rest in chunks of eviction_batch_size; the last chunk may be
        # shorter, in which case the negative slice keeps all of its logits.
        for pos in range(
            min(seq_len, self.cache_size), seq_len, self.eviction_batch_size
        ):
            logits = model(
                inps[:, pos : pos + self.eviction_batch_size],
                torch.tensor([pos], dtype=torch.int64, device=self.device),
            )
            result = torch.cat(
                (result, logits[:, -self.eviction_batch_size :, :]), dim=1
            )
        return result
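
# A minimal usage sketch (illustration only; the concrete model, tokenizer, and
# numbers below are assumptions, not part of this change). With sink_size=4 and
# window_size=2044 the cache holds 2048 tokens, so a 4096-token prompt is
# processed as one 2048-token prefill followed by four 512-token chunks:
#
#   wrapper = AttentionSinkEvalWrapper(
#       model=model,                 # eager llama model with attention sink enabled
#       tokenizer=tokenizer,         # SentencePieceTokenizer or Tiktoken
#       sink_size=4,
#       window_size=2044,
#       eviction_batch_size=512,
#       max_seq_length=4096,
#       use_kv_cache=True,
#   )
#   logits = wrapper._model_call(tokens)  # tokens: (1, 4096) -> logits: (1, 4096, vocab_size)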


def gen_eval_wrapper(
model_name: str,
args: argparse.ArgumentParser,
@@ -225,6 +267,27 @@ def gen_eval_wrapper(
    if args.output_eager_checkpoint_file is not None:  # pyre-ignore
        torch.save(model, args.output_eager_checkpoint_file)

    if (use_attention_sink := args.use_attention_sink) is not None and (
        attention_sink_eval_length := args.attention_sink_eval_length
    ) is not None:  # pyre-ignore
        # --use_attention_sink is a comma-separated string:
        # "sink_size,window_size,eviction_batch_size".
        attention_sink_params = use_attention_sink.split(",")
        assert (
            len(attention_sink_params) == 3
        ), "--use_attention_sink expects 'sink_size,window_size,eviction_batch_size'"
        sink_size = int(attention_sink_params[0])
        window_size = int(attention_sink_params[1])
        eviction_batch_size = int(attention_sink_params[2])

        assert (
            args.max_seq_length == sink_size + window_size
        ), "--max_seq_length must equal sink_size + window_size"

        return AttentionSinkEvalWrapper(
            model=model,
            tokenizer=tokenizer,
            sink_size=sink_size,
            window_size=window_size,
            eviction_batch_size=eviction_batch_size,
            max_seq_length=attention_sink_eval_length,
            use_kv_cache=args.use_kv_cache,
        )

    return EagerEvalWrapper(
        model=model,
        tokenizer=tokenizer,
@@ -279,6 +342,12 @@ def build_args_parser() -> argparse.ArgumentParser:
        default=None,
        help="Save the checkpoint after source transformations, for other evaluation platforms to run the same checkpoint.",
    )
    parser.add_argument(
        "--attention_sink_eval_length",
        type=int,
        default=2048,
        help="The maximum length of the sequence to evaluate with attention sink.",
    )

    return parser
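
# A hypothetical end-to-end invocation (illustration only). The entry point and
# the checkpoint/tokenizer/tasks flags are assumptions; --use_attention_sink,
# --use_kv_cache, and --max_seq_length are inferred from the args fields read in
# gen_eval_wrapper above, and the values must satisfy
# sink_size + window_size == max_seq_length:
#
#   python -m examples.models.llama.eval_llama \
#       --checkpoint <path/to/checkpoint.pth> \
#       --tokenizer_path <path/to/tokenizer.model> \
#       --use_kv_cache \
#       --max_seq_length 2048 \
#       --use_attention_sink 4,2044,512 \
#       --attention_sink_eval_length 4096 \
#       --tasks wikitext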
