
Commit 5cd90a1

add eval for attention sink (#7150)
Pull Request resolved: #7070

This PR adds a function for evaluating the model's perplexity when AttentionSink is enabled. It is mostly copied from https://github.com/mit-han-lab/streaming-llm/blob/main/examples/eval_long_ppl.py, the script the AttentionSink paper uses for the same measurement.

ghstack-source-id: 256108079
@exported-using-ghexport
Differential Revision: [D66474732](https://our.internmc.facebook.com/intern/diff/D66474732/)

Co-authored-by: Lunwen He <[email protected]>
1 parent 773813a commit 5cd90a1
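For orientation, a rough sketch of how this evaluation might be invoked from the command line. Only --attention_sink_eval_tokens is introduced by this change; the module path, the checkpoint/params/tokenizer flags, and the --use_attention_sink / --max_seq_length values below are placeholders following the existing eval_llama usage, not something this diff defines:

    python -m examples.models.llama.eval_llama \
        -c <checkpoint.pth> \
        -p <params.json> \
        -t <tokenizer.model> \
        --use_attention_sink 4,2044,1024 \
        --max_seq_length 2048 \
        --attention_sink_eval_tokens 65536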

3 files changed: 75 additions, 2 deletions

examples/models/llama/TARGETS (2 additions, 0 deletions)

@@ -150,6 +150,8 @@ runtime.python_library(
         "@EXECUTORCH_CLIENTS",
     ],
     deps = [
+        "fbsource//third-party/pypi/tqdm:tqdm",
+        "fbsource//third-party/pypi/datasets:datasets",
         "fbsource//third-party/pypi/lm-eval:lm-eval",
         "fbsource//third-party/pypi/tiktoken:tiktoken",
         ":export_library",

examples/models/llama/eval_llama.py (9 additions, 2 deletions)

@@ -10,7 +10,11 @@

 import torch

-from .eval_llama_lib import build_args_parser, eval_llama
+from .eval_llama_lib import (
+    build_args_parser,
+    eval_llama,
+    eval_llama_with_attention_sink,
+)

 FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
 logging.basicConfig(level=logging.INFO, format=FORMAT)
@@ -24,7 +28,10 @@ def main() -> None:
     args = parser.parse_args()
     # Overrides this arg, because evaluation requires full logits.
     args.generate_full_logits = True
-    eval_llama(modelname, args)  # pyre-ignore
+    if args.use_attention_sink:
+        eval_llama_with_attention_sink(modelname, args)  # pyre-ignore
+    else:
+        eval_llama(modelname, args)  # pyre-ignore


 if __name__ == "__main__":
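The dispatch above reuses the existing --use_attention_sink flag; the new evaluation code in the next file unpacks that value and checks it against --max_seq_length before building the model. A standalone illustration of that unpacking, with placeholder values rather than real flag defaults:

    # Placeholder values; in the eval they come from the existing
    # --use_attention_sink and --max_seq_length arguments, not from this diff.
    use_attention_sink = "4,2044,1024"  # sink_size,window_size,<third field of the flag>
    max_seq_length = 2048

    sink_size, window_size, _ = (int(v) for v in use_attention_sink.split(","))
    # The cache must hold exactly the sink tokens plus the sliding window.
    assert max_seq_length == sink_size + window_size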

examples/models/llama/eval_llama_lib.py (64 additions, 0 deletions)

@@ -10,6 +10,8 @@
 from typing import Optional, Union

 import torch
+
+from datasets import load_dataset
 from executorch.examples.models.llama.export_llama_lib import (
     get_quantizer_and_quant_params,
 )
@@ -21,6 +23,8 @@
 )
 from executorch.extension.llm.tokenizer.utils import get_tokenizer
 from lm_eval.evaluator import simple_evaluate
+from torch.nn import CrossEntropyLoss
+from tqdm import tqdm

 from .evaluate.eager_eval import EagerEvalWrapper

@@ -280,6 +284,9 @@ def build_args_parser() -> argparse.ArgumentParser:
         help="Save the checkpoint after source transformations, for other evaluation platform to run the same checkpoint.",
     )

+    # Set of parameters specific to AttentionSink.
+    parser.add_argument("--attention_sink_eval_tokens", type=int, default=0)
+
     return parser


@@ -309,3 +316,60 @@ def eval_llama(

     for task, res in eval_results["results"].items():
         print(f"{task}: {res}")
+
+
+def eval_llama_with_attention_sink(model_name: str, args: argparse.ArgumentParser):
+    """
+    Evaluate the model's perplexity when AttentionSink is enabled.
+
+    This is mostly copied from https://github.com/mit-han-lab/streaming-llm/blob/main/examples/eval_long_ppl.py
+    """
+    assert args.use_attention_sink is not None  # pyre-ignore [16]
+    assert args.attention_sink_eval_tokens > 0  # pyre-ignore [16]
+    attention_sink_params = args.use_attention_sink.split(",")
+    assert len(attention_sink_params) == 3
+    sink_size = int(attention_sink_params[0])
+    window_size = int(attention_sink_params[1])
+
+    assert args.max_seq_length == sink_size + window_size  # pyre-ignore [16]
+
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    manager: LLMEdgeManager = _prepare_for_llama_export(args)
+    model = manager.model.eval().to(device=device)
+    tokenizer = get_tokenizer(args.tokenizer_path)  # pyre-ignore [16]
+
+    eval_data = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
+
+    nlls = []
+    loss_fn = CrossEntropyLoss(reduction="none")
+    progress_bar = tqdm(total=args.attention_sink_eval_tokens)
+    input_pos = 0
+    while input_pos < args.attention_sink_eval_tokens:
+        for text in eval_data["text"]:  # pyre-ignore [16]
+            tokens = tokenizer.encode(text, bos=False, eos=False)
+            if len(tokens) <= 0:
+                continue
+            with torch.no_grad():
+                num_tokens = min(
+                    len(tokens) - 1, args.attention_sink_eval_tokens - input_pos
+                )
+                logits = model(
+                    torch.tensor(
+                        [tokens[:num_tokens]], dtype=torch.int64, device=device
+                    ),
+                    torch.tensor([input_pos], dtype=torch.int64, device=device),
+                ).squeeze(dim=0)
+                neg_log_likelihood = loss_fn(
+                    logits,
+                    torch.tensor(
+                        [tokens[1 : num_tokens + 1]], dtype=torch.int64, device=device
+                    ).view(-1),
+                )
+                nlls.append(neg_log_likelihood)
+                input_pos += num_tokens
+                progress_bar.update(num_tokens)
+            if input_pos >= args.attention_sink_eval_tokens:
+                break
+    ppl = torch.exp(torch.cat(nlls).mean())
+    print(f"Perplexity: {ppl.item()}")
+    return ppl.item()
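To make the perplexity bookkeeping in eval_llama_with_attention_sink easier to follow in isolation, here is a minimal self-contained sketch: logits at position i are scored against token i+1 with an unreduced cross-entropy, the per-token negative log-likelihoods are collected, and perplexity is the exponential of their mean. The tensors below are random stand-ins, not outputs of the exported model:

    import torch
    from torch.nn import CrossEntropyLoss

    vocab_size = 10
    # Toy 6-token sequence; in the eval above these come from the wikitext tokenizer.
    tokens = torch.tensor([1, 4, 2, 7, 3, 5])

    # Stand-in logits: one row per predicted position (position i predicts token i + 1).
    logits = torch.randn(len(tokens) - 1, vocab_size)

    # reduction="none" keeps one negative log-likelihood per predicted token.
    loss_fn = CrossEntropyLoss(reduction="none")
    nlls = [loss_fn(logits, tokens[1:])]

    # Perplexity is exp of the mean NLL over all evaluated tokens.
    ppl = torch.exp(torch.cat(nlls).mean())
    print(f"Perplexity: {ppl.item()}")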
