from typing import Optional, Union

import torch
+
+from datasets import load_dataset
from executorch.examples.models.llama.export_llama_lib import (
    get_quantizer_and_quant_params,
)

)
from executorch.extension.llm.tokenizer.utils import get_tokenizer
from lm_eval.evaluator import simple_evaluate
+from torch.nn import CrossEntropyLoss
+from tqdm import tqdm

from .evaluate.eager_eval import EagerEvalWrapper
@@ -280,6 +284,9 @@ def build_args_parser() -> argparse.ArgumentParser:
        help="Save the checkpoint after source transformations, for other evaluation platform to run the same checkpoint.",
    )

+    # Set of parameters specific to AttentionSink.
+    parser.add_argument("--attention_sink_eval_tokens", type=int, default=0)
+
    return parser
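The new flag only controls how many evaluation tokens are scored; the sink/window geometry still comes from the existing --use_attention_sink argument, which the evaluation code below parses as a comma-separated triple whose first two fields are the sink size and window size (the third field is not read by this evaluation), and it asserts that sink size + window size equals --max_seq_length. A rough sketch of how the two flags might be combined, assuming --use_attention_sink is already registered on this parser and that the remaining export arguments (checkpoint, params, etc.) are supplied or defaulted as usual; the concrete numbers and paths are only illustrative:

    # Illustrative values: 4 sink tokens + 2044 window tokens == 2048 max_seq_length,
    # as required by the assert in eval_llama_with_attention_sink below.
    parser = build_args_parser()
    args = parser.parse_args(
        [
            "--use_attention_sink", "4,2044,1024",  # third field is not used by this eval
            "--max_seq_length", "2048",
            "--attention_sink_eval_tokens", "65536",
            "--tokenizer_path", "/path/to/tokenizer.model",
        ]
    )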
@@ -309,3 +316,60 @@ def eval_llama(

    for task, res in eval_results["results"].items():
        print(f"{task}: {res}")
+
+
+def eval_llama_with_attention_sink(model_name: str, args: argparse.ArgumentParser):
+    """
+    Evaluate the model's perplexity when AttentionSink is enabled.
+
+    This is mostly copied from https://github.com/mit-han-lab/streaming-llm/blob/main/examples/eval_long_ppl.py
+    """
+    assert args.use_attention_sink is not None  # pyre-ignore [16]
+    assert args.attention_sink_eval_tokens > 0  # pyre-ignore [16]
+    attention_sink_params = args.use_attention_sink.split(",")
+    assert len(attention_sink_params) == 3
+    sink_size = int(attention_sink_params[0])
+    window_size = int(attention_sink_params[1])
+
+    assert args.max_seq_length == sink_size + window_size  # pyre-ignore [16]
+
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    manager: LLMEdgeManager = _prepare_for_llama_export(args)
+    model = manager.model.eval().to(device=device)
+    tokenizer = get_tokenizer(args.tokenizer_path)  # pyre-ignore [16]
+
+    eval_data = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
+
+    nlls = []
+    loss_fn = CrossEntropyLoss(reduction="none")
+    progress_bar = tqdm(total=args.attention_sink_eval_tokens)
+    input_pos = 0
+    while input_pos < args.attention_sink_eval_tokens:
+        for text in eval_data["text"]:  # pyre-ignore [16]
+            tokens = tokenizer.encode(text, bos=False, eos=False)
+            if len(tokens) <= 0:
+                continue
+            with torch.no_grad():
+                num_tokens = min(
+                    len(tokens) - 1, args.attention_sink_eval_tokens - input_pos
+                )
+                logits = model(
+                    torch.tensor(
+                        [tokens[:num_tokens]], dtype=torch.int64, device=device
+                    ),
+                    torch.tensor([input_pos], dtype=torch.int64, device=device),
+                ).squeeze(dim=0)
+                neg_log_likelihood = loss_fn(
+                    logits,
+                    torch.tensor(
+                        [tokens[1 : num_tokens + 1]], dtype=torch.int64, device=device
+                    ).view(-1),
+                )
+                nlls.append(neg_log_likelihood)
+                input_pos += num_tokens
+                progress_bar.update(num_tokens)
+            if input_pos >= args.attention_sink_eval_tokens:
+                break
+    ppl = torch.exp(torch.cat(nlls).mean())
+    print(f"Perplexity: {ppl.item()}")
+    return ppl.item()
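Because CrossEntropyLoss(reduction="none") yields one negative log-likelihood per predicted token, concatenating the per-chunk results and averaging before exponentiating gives the standard token-level perplexity exp(mean NLL) over all scored tokens, regardless of how the text was chunked. A minimal, self-contained sketch of that final reduction, using made-up NLL values and independent of the model code above:

    import torch

    # Hypothetical per-token negative log-likelihoods from two text chunks.
    nlls = [torch.tensor([2.1, 1.7, 3.0]), torch.tensor([2.4, 1.9])]

    all_nll = torch.cat(nlls)        # one NLL per predicted token
    ppl = torch.exp(all_nll.mean())  # perplexity = exp(mean token-level NLL)
    print(f"Perplexity: {ppl.item():.3f}")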