|
10 | 10 | from typing import Optional, Union |
11 | 11 |
|
12 | 12 | import torch |
| 13 | + |
| 14 | +from datasets import load_dataset |
13 | 15 | from executorch.examples.models.llama.export_llama_lib import ( |
14 | 16 | get_quantizer_and_quant_params, |
15 | 17 | ) |
|
21 | 23 | ) |
22 | 24 | from executorch.extension.llm.tokenizer.utils import get_tokenizer |
23 | 25 | from lm_eval.evaluator import simple_evaluate |
| 26 | +from torch.nn import CrossEntropyLoss |
| 27 | +from tqdm import tqdm |
24 | 28 |
|
25 | 29 | from .evaluate.eager_eval import EagerEvalWrapper |
26 | 30 |
|
@@ -280,6 +284,9 @@ def build_args_parser() -> argparse.ArgumentParser: |
280 | 284 | help="Save the checkpoint after source transformations, for other evaluation platform to run the same checkpoint.", |
281 | 285 | ) |
282 | 286 |
|
| 287 | +    # Set of parameters specific to AttentionSink. |
| 288 | + parser.add_argument("--attention_sink_eval_tokens", type=int, default=0) |
| 289 | + |
283 | 290 | return parser |
284 | 291 |
|
285 | 292 |
|
@@ -309,3 +316,60 @@ def eval_llama( |
309 | 316 |
|
310 | 317 | for task, res in eval_results["results"].items(): |
311 | 318 | print(f"{task}: {res}") |
| 319 | + |
| 320 | + |
| 321 | +def eval_llama_with_attention_sink(model_name: str, args: argparse.ArgumentParser): |
| 322 | + """ |
| 323 | + Evaluate the model's perplexity when AttentionSink is enabled. |
| 324 | +
|
| 325 | + This is mostly copied from https://github.com/mit-han-lab/streaming-llm/blob/main/examples/eval_long_ppl.py |
| 326 | + """ |
| 327 | + assert args.use_attention_sink is not None # pyre-ignore [16] |
| 328 | + assert args.attention_sink_eval_tokens > 0 # pyre-ignore [16] |
| 329 | + attention_sink_params = args.use_attention_sink.split(",") |
| 330 | + assert len(attention_sink_params) == 3 |
| 331 | + sink_size = int(attention_sink_params[0]) |
| 332 | + window_size = int(attention_sink_params[1]) |
| 333 | + |
| 334 | + assert args.max_seq_length == sink_size + window_size # pyre-ignore [16] |
| 335 | + |
| 336 | + device = "cuda" if torch.cuda.is_available() else "cpu" |
| 337 | + manager: LLMEdgeManager = _prepare_for_llama_export(args) |
| 338 | + model = manager.model.eval().to(device=device) |
| 339 | + tokenizer = get_tokenizer(args.tokenizer_path) # pyre-ignore [16] |
| 340 | + |
| 341 | + eval_data = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") |
| 342 | + |
| 343 | + nlls = [] |
| 344 | + loss_fn = CrossEntropyLoss(reduction="none") |
| 345 | + progress_bar = tqdm(total=args.attention_sink_eval_tokens) |
| 346 | + input_pos = 0 |
| 347 | + while input_pos < args.attention_sink_eval_tokens: |
| 348 | + for text in eval_data["text"]: # pyre-ignore [16] |
| 349 | + tokens = tokenizer.encode(text, bos=False, eos=False) |
| 350 | +            if len(tokens) < 2:  # need at least one target token for next-token loss |
| 351 | + continue |
| 352 | + with torch.no_grad(): |
| 353 | + num_tokens = min( |
| 354 | + len(tokens) - 1, args.attention_sink_eval_tokens - input_pos |
| 355 | + ) |
| 356 | + logits = model( |
| 357 | + torch.tensor( |
| 358 | + [tokens[:num_tokens]], dtype=torch.int64, device=device |
| 359 | + ), |
| 360 | + torch.tensor([input_pos], dtype=torch.int64, device=device), |
| 361 | + ).squeeze(dim=0) |
| 362 | + neg_log_likelihood = loss_fn( |
| 363 | + logits, |
| 364 | + torch.tensor( |
| 365 | + [tokens[1 : num_tokens + 1]], dtype=torch.int64, device=device |
| 366 | + ).view(-1), |
| 367 | + ) |
| 368 | + nlls.append(neg_log_likelihood) |
| 369 | + input_pos += num_tokens |
| 370 | + progress_bar.update(num_tokens) |
| 371 | + if input_pos >= args.attention_sink_eval_tokens: |
| 372 | + break |
| 373 | + ppl = torch.exp(torch.cat(nlls).mean()) |
| 374 | + print(f"Perplexity: {ppl.item()}") |
| 375 | + return ppl.item() |
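
For context, a minimal usage sketch of the new evaluation path. This assumes the file above is importable as `executorch.examples.models.llama.eval_llama_lib` and that `--use_attention_sink` takes a comma-separated `sink_size,window_size,eviction_batch_size` string (the third value is parsed but unused here); the flag names mirror the `args` attributes referenced in the diff, and all paths and values are illustrative placeholders, not part of this change.

```python
# Sketch only: module path, checkpoint/tokenizer paths, and flag values are placeholders.
from executorch.examples.models.llama.eval_llama_lib import (  # assumed module path for this file
    build_args_parser,
    eval_llama_with_attention_sink,
)

parser = build_args_parser()
args = parser.parse_args(
    [
        "--checkpoint", "/path/to/llama.pth",            # placeholder checkpoint
        "--tokenizer_path", "/path/to/tokenizer.model",  # placeholder tokenizer
        "--use_attention_sink", "4,2044,1024",    # sink_size, window_size, eviction_batch_size
        "--max_seq_length", "2048",               # must equal sink_size + window_size (asserted above)
        "--attention_sink_eval_tokens", "65536",  # number of WikiText-2 tokens to score
    ]
)

ppl = eval_llama_with_attention_sink("llama2", args)  # prints and returns the perplexity
```

The reported number is the exponential of the mean per-token negative log-likelihood accumulated in `nlls`, following the streaming-llm long-perplexity script the function is adapted from.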