
Commit a2cc4aa

add eval for attention sink
This PR adds a function to evaluate the model's perplexity when AttentionSink is enabled. It is mostly copied from https://github.com/mit-han-lab/streaming-llm/blob/main/examples/eval_long_ppl.py, the script the AttentionSink paper uses for the same evaluation. Differential Revision: [D66474732](https://our.internmc.facebook.com/intern/diff/D66474732/) [ghstack-poisoned]
1 parent 64d3437
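For orientation, a minimal driver sketch of how the new evaluation path might be invoked. Only "--attention_sink_eval_tokens" and eval_llama_with_attention_sink come from this commit; every other flag name, path, and value below is an illustrative assumption, not taken from the diff:

# Hypothetical driver sketch; flag spellings outside this commit and all values are assumptions.
from executorch.examples.models.llama.eval_llama_lib import (
    build_args_parser,
    eval_llama_with_attention_sink,
)

parser = build_args_parser()
args = parser.parse_args(
    [
        "--checkpoint", "/path/to/consolidated.00.pth",  # assumed checkpoint flag/path
        "--tokenizer_path", "/path/to/tokenizer.model",  # read via args.tokenizer_path
        "--max_seq_length", "2048",                      # must equal sink_size + window_size
        "--use_attention_sink", "4,2044,1024",           # "sink_size,window_size,<third value>"
        "--attention_sink_eval_tokens", "65536",         # flag added by this commit
    ]
)
args.generate_full_logits = True  # evaluation requires full logits (mirrors eval_llama.py)
eval_llama_with_attention_sink("llama", args)  # "llama" is a placeholder model name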

File tree

3 files changed: +76 −2 lines changed

examples/models/llama/TARGETS

Lines changed: 2 additions & 0 deletions
@@ -150,6 +150,8 @@ runtime.python_library(
         "@EXECUTORCH_CLIENTS",
     ],
     deps = [
+        "fbsource//third-party/pypi/tqdm:tqdm",
+        "fbsource//third-party/pypi/datasets:datasets",
         "fbsource//third-party/pypi/lm-eval:lm-eval",
         "fbsource//third-party/pypi/tiktoken:tiktoken",
         ":export_library",

examples/models/llama/eval_llama.py

Lines changed: 9 additions & 2 deletions
@@ -10,7 +10,11 @@
 
 import torch
 
-from .eval_llama_lib import build_args_parser, eval_llama
+from .eval_llama_lib import (
+    build_args_parser,
+    eval_llama,
+    eval_llama_with_attention_sink,
+)
 
 FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
 logging.basicConfig(level=logging.INFO, format=FORMAT)

@@ -24,7 +28,10 @@ def main() -> None:
     args = parser.parse_args()
     # Overrides this arg, because evaluation requires full logits.
     args.generate_full_logits = True
-    eval_llama(modelname, args)  # pyre-ignore
+    if args.use_attention_sink:
+        eval_llama_with_attention_sink(modelname, args)  # pyre-ignore
+    else:
+        eval_llama(modelname, args)  # pyre-ignore
 
 
 if __name__ == "__main__":

examples/models/llama/eval_llama_lib.py

Lines changed: 65 additions & 0 deletions
@@ -10,6 +10,8 @@
 from typing import Optional, Union
 
 import torch
+
+from datasets import load_dataset
 from executorch.examples.models.llama.export_llama_lib import (
     get_quantizer_and_quant_params,
 )

@@ -21,6 +23,8 @@
 )
 from executorch.extension.llm.tokenizer.utils import get_tokenizer
 from lm_eval.evaluator import simple_evaluate
+from torch.nn import CrossEntropyLoss
+from tqdm import tqdm
 
 from .evaluate.eager_eval import EagerEvalWrapper
 
@@ -280,6 +284,9 @@ def build_args_parser() -> argparse.ArgumentParser:
         help="Save the checkpoint after source transformations, for other evaluation platform to run the same checkpoint.",
     )
 
+    # Set of parameters specific to AttentionSink.
+    parser.add_argument("--attention_sink_eval_tokens", type=int, default=0)
+
     return parser
 
 
@@ -309,3 +316,61 @@ def eval_llama(
 
     for task, res in eval_results["results"].items():
         print(f"{task}: {res}")
+
+
+def eval_llama_with_attention_sink(
+    model_name: str, args: argparse.ArgumentParser
+) -> None:
+    """
+    Evaluate the model's perplexity when AttentionSink is enabled.
+
+    This is mostly copied from https://github.com/mit-han-lab/streaming-llm/blob/main/examples/eval_long_ppl.py
+    """
+    assert args.use_attention_sink is not None  # pyre-ignore [16]
+    assert args.attention_sink_eval_tokens > 0  # pyre-ignore [16]
+    attention_sink_params = args.use_attention_sink.split(",")  # pyre-ignore [16]
+    assert len(attention_sink_params) == 3
+    sink_size = int(attention_sink_params[0])
+    window_size = int(attention_sink_params[1])
+
+    assert args.max_seq_length == sink_size + window_size  # pyre-ignore [16]
+
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    manager: LLMEdgeManager = _prepare_for_llama_export(args)
+    model = manager.model.eval().to(device=device)
+    tokenizer = get_tokenizer(args.tokenizer_path)  # pyre-ignore [16]
+
+    eval_data = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
+
+    nlls = []
+    loss_fn = CrossEntropyLoss(reduction="none")
+    progress_bar = tqdm(total=args.attention_sink_eval_tokens)
+    input_pos = 0
+    while input_pos < args.attention_sink_eval_tokens:
+        for text in eval_data["text"]:  # pyre-ignore [16]
+            tokens = tokenizer.encode(text, bos=False, eos=False)
+            if len(tokens) <= 0:
+                continue
+            with torch.no_grad():
+                num_tokens = min(
+                    len(tokens) - 1, args.attention_sink_eval_tokens - input_pos
+                )
+                logits = model(
+                    torch.tensor(
+                        [tokens[:num_tokens]], dtype=torch.int64, device=device
+                    ),
+                    torch.tensor([input_pos], dtype=torch.int64, device=device),
+                ).squeeze(dim=0)
+                neg_log_likelihood = loss_fn(
+                    logits,
+                    torch.tensor(
+                        [tokens[1 : num_tokens + 1]], dtype=torch.int64, device=device
+                    ).view(-1),
+                )
+                nlls.append(neg_log_likelihood)
+                input_pos += num_tokens
+                progress_bar.update(num_tokens)
+            if input_pos >= args.attention_sink_eval_tokens:
+                break
+    ppl = torch.exp(torch.cat(nlls).mean())
+    print(f"Perplexity: {ppl.item()}")
