6 changes: 4 additions & 2 deletions examples/models/llama/eval_llama.py
@@ -11,6 +11,7 @@
 import torch
 
 from .eval_llama_lib import (
+    _convert_cli_to_config_format,
     build_args_parser,
     eval_llama,
     eval_llama_with_attention_sink,
@@ -28,10 +29,11 @@ def main() -> None:
     args = parser.parse_args()
     # Overrides this arg, because evaluation requires full logits.
     args.generate_full_logits = True
+    config = _convert_cli_to_config_format(args)
     if args.use_attention_sink:
-        eval_llama_with_attention_sink(modelname, args) # pyre-ignore
+        eval_llama_with_attention_sink(modelname, config) # pyre-ignore
     else:
-        eval_llama(modelname, args) # pyre-ignore
+        eval_llama(modelname, config) # pyre-ignore
 
 
 if __name__ == "__main__":
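A quick illustration of the new flow in eval_llama.py: main() now parses the CLI args once and converts them to a config object before dispatching to the eval entry points. The sketch below (not part of the PR) drives the same flow from Python; it assumes every parser argument has a default, and the "llama2" model name plus the tasks/limit values are placeholders, not project defaults.

# Sketch only: the attribute names mirror what _convert_cli_to_config_format() reads
# in this PR; the model name and eval settings are placeholder values.
from executorch.examples.models.llama.eval_llama_lib import (
    _convert_cli_to_config_format,
    build_args_parser,
    eval_llama,
)

parser = build_args_parser()
args = parser.parse_args([])       # assumes every argument has a default
args.generate_full_logits = True   # evaluation requires full logits, as in main()
args.tasks = ["wikitext"]
args.limit = 5

config = _convert_cli_to_config_format(args)
eval_llama("llama2", config)       # same call main() now makes with the config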
209 changes: 96 additions & 113 deletions examples/models/llama/eval_llama_lib.py
@@ -4,35 +4,27 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import argparse

from typing import Optional, Union

import torch

from datasets import load_dataset
from executorch.examples.models.llama.export_llama_lib import (
_convert_args_to_config,
_prepare_for_llama_export,
build_args_parser as _build_args_parser,
get_quantizer_and_quant_params,
)
from executorch.examples.models.llama.tokenizer.tiktoken import Tokenizer as Tiktoken

from executorch.extension.llm.export.builder import LLMEdgeManager
from executorch.extension.llm.tokenizer.tokenizer import (
Tokenizer as SentencePieceTokenizer,
)
from executorch.extension.llm.tokenizer.utils import get_tokenizer
from lm_eval.evaluator import simple_evaluate
from torch.nn import CrossEntropyLoss
from tqdm import tqdm
from omegaconf import DictConfig, OmegaConf

from .evaluate.eager_eval import EagerEvalWrapper

from .export_llama_lib import (
_prepare_for_llama_export,
build_args_parser as _build_args_parser,
)


class GraphModuleEvalWrapper(EagerEvalWrapper):
"""
@@ -165,7 +157,7 @@ def _model_call(self, inps):

def gen_eval_wrapper(
model_name: str,
args: argparse.ArgumentParser,
config: DictConfig,
):
"""
Generates a wrapper interface around the provided model and tokenizer for
@@ -174,17 +166,17 @@ def gen_eval_wrapper(
Returns:
eval_wrapper (LM): A wrapper interface for the lm-evaluation-harness library.
"""
tokenizer = get_tokenizer(args.tokenizer_path) # pyre-ignore
tokenizer = get_tokenizer(config.export.tokenizer_path)

# ExecuTorch Binary Evaluation
if (model := args.pte) is not None: # pyre-ignore
if (tokenizer_bin := args.tokenizer_bin) is not None: # pyre-ignore
if (model := config.eval.pte) is not None:
if (tokenizer_bin := config.eval.tokenizer_bin) is not None:
# ETRunnerEvalWrapper: Create a wrapper around an ExecuTorch model, evaluated at runtime
return ETRunnerEvalWrapper(
model=model,
tokenizer=tokenizer,
tokenizer_bin=tokenizer_bin,
max_seq_length=args.max_seq_length, # pyre-ignore
max_seq_length=config.sequence.max_seq_length,
)

# ETPybindEvalWrapper: Create a wrapper around an ExecuTorch model, evaluated with pybindings
@@ -193,12 +185,12 @@
tokenizer=tokenizer,
# Exported model takes at most (max_seq_length - 1) tokens.
# Note that the eager model takes at most max_seq_length tokens.
max_seq_length=args.max_seq_length - 1,
max_seq_length=config.sequence.max_seq_length - 1,
)

pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)
pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(config)
# GPTFastEvalWrapper: Create a wrapper around a pre-exported model
manager: LLMEdgeManager = _prepare_for_llama_export(args)
manager: LLMEdgeManager = _prepare_for_llama_export(config)

if len(quantizers) != 0:
manager = manager.export().pt2e_quantize(quantizers)
@@ -210,9 +202,9 @@
return GraphModuleEvalWrapper(
model=model,
tokenizer=tokenizer,
max_seq_length=args.max_seq_length,
use_kv_cache=args.use_kv_cache, # pyre-ignore
enable_dynamic_shape=args.enable_dynamic_shape, # pyre-ignore
max_seq_length=config.sequence.max_seq_length,
use_kv_cache=config.kv_cache.use_kv_cache, # pyre-ignore
enable_dynamic_shape=config.misc.enable_dynamic_shape, # pyre-ignore
)
else:
# TODO: use manager.pre_autograd_graph_module for the eval to remove the if-else branch
@@ -230,18 +222,94 @@ def gen_eval_wrapper(
# that is not available in this eval_llama. We save the checkpoint
# here for consistency with eval_llama. The accuracy results we
# get from eval_llama can be used as a reference to other evaluations.
if args.output_eager_checkpoint_file is not None: # pyre-ignore
torch.save(model, args.output_eager_checkpoint_file)
if config.eval.output_eager_checkpoint_file is not None: # pyre-ignore
torch.save(model, config.eval.output_eager_checkpoint_file)

return EagerEvalWrapper(
model=model,
tokenizer=tokenizer,
max_seq_length=args.max_seq_length,
use_kv_cache=args.use_kv_cache,
max_seq_length=config.sequence.max_seq_length,
use_kv_cache=config.kv_cache.use_kv_cache,
)


def eval_llama(
model_name: str,
config: DictConfig,
) -> None:
# Generate the eval wrapper
eval_wrapper = gen_eval_wrapper(model_name, config)

# Needed for loading mmlu dataset.
# See https://github.com/EleutherAI/lm-evaluation-harness/pull/1998/files
if config.eval.tasks and "mmlu" in config.eval.tasks:
import datasets

datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True

# Evaluate the model
tasks = (
None if config.eval.tasks is None else OmegaConf.to_container(config.eval.tasks)
)
with torch.no_grad():
eval_results = simple_evaluate(
model=eval_wrapper,
tasks=tasks,
num_fewshot=config.eval.num_fewshot,
limit=config.eval.limit,
)

for task, res in eval_results["results"].items():
print(f"{task}: {res}")


def eval_llama_with_attention_sink(
model_name: str,
config: DictConfig,
) -> None:
# Generate the eval wrapper
eval_wrapper = gen_eval_wrapper(model_name, config)

# Needed for loading mmlu dataset.
# See https://github.com/EleutherAI/lm-evaluation-harness/pull/1998/files
if config.eval.tasks and "mmlu" in config.eval.tasks:
import datasets

datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True

# Evaluate the model
with torch.no_grad():
eval_results = simple_evaluate(
model=eval_wrapper,
tasks=OmegaConf.to_container(config.eval.tasks),
num_fewshot=config.eval.num_fewshot,
limit=config.eval.limit,
)

for task, res in eval_results["results"].items():
print(f"{task}: {res}")


def _convert_cli_to_config_format(args) -> DictConfig:
"""Convert CLI arguments to config format."""
# First convert common args using the shared function
config = _convert_args_to_config(args)

# Add evaluation-specific settings
config.eval = OmegaConf.create()
config.eval.tasks = args.tasks
config.eval.limit = args.limit
config.eval.num_fewshot = args.num_fewshot
config.eval.pte = args.pte
config.eval.tokenizer_bin = args.tokenizer_bin
config.eval.output_eager_checkpoint_file = args.output_eager_checkpoint_file
Comment on lines +299 to +305

Contributor:
Sorry, is there a definition of the config?

Contributor (Author):
This PR is just the first step: convert args to config to unblock internal work, where the configs can be used standalone, without CLI args or a YAML file. More context in #9449.

config.eval.attention_sink_eval_tokens = args.attention_sink_eval_tokens

return config


def build_args_parser() -> argparse.ArgumentParser:
"""Build argument parser for evaluation, extending the export parser with eval-specific args."""
# Start with arg parser from export_llama_lib
parser = _build_args_parser()

@@ -288,92 +356,7 @@ def build_args_parser() -> argparse.ArgumentParser:
help="Save the checkpoint after source transformations, for other evaluation platform to run the same checkpoint.",
)

# Set of parameters secpific to AttentionSink.
# Set of parameters specific to AttentionSink.
parser.add_argument("--attention_sink_eval_tokens", type=int, default=0)

return parser


def eval_llama(
model_name: str,
args: argparse.ArgumentParser,
) -> None:
# Generate the eval wrapper
eval_wrapper = gen_eval_wrapper(model_name, args)

# Needed for loading mmlu dataset.
# See https://github.com/EleutherAI/lm-evaluation-harness/pull/1998/files
# pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `tasks`
if args.tasks and "mmlu" in args.tasks:
import datasets

datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True

# Evaluate the model
with torch.no_grad():
eval_results = simple_evaluate(
model=eval_wrapper,
tasks=args.tasks,
num_fewshot=args.num_fewshot, # pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `num_fewshot`
limit=args.limit, # pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `limit`
)

for task, res in eval_results["results"].items():
print(f"{task}: {res}")


def eval_llama_with_attention_sink(model_name: str, args: argparse.ArgumentParser):
"""
Evaluate the model's perplexity when AttentionSink is enabled.

This is mostly copied from https://github.com/mit-han-lab/streaming-llm/blob/main/examples/eval_long_ppl.py
"""
assert args.use_attention_sink is not None # pyre-ignore [16]
assert args.attention_sink_eval_tokens > 0 # pyre-ignore [16]
attention_sink_params = args.use_attention_sink.split(",")
assert len(attention_sink_params) == 3
sink_size = int(attention_sink_params[0])
window_size = int(attention_sink_params[1])

assert args.max_seq_length == sink_size + window_size # pyre-ignore [16]

device = "cuda" if torch.cuda.is_available() else "cpu"
manager: LLMEdgeManager = _prepare_for_llama_export(args)
model = manager.model.eval().to(device=device)
tokenizer = get_tokenizer(args.tokenizer_path) # pyre-ignore [16]

eval_data = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")

nlls = []
loss_fn = CrossEntropyLoss(reduction="none")
progress_bar = tqdm(total=args.attention_sink_eval_tokens)
input_pos = 0
while input_pos < args.attention_sink_eval_tokens:
for text in eval_data["text"]: # pyre-ignore [16]
tokens = tokenizer.encode(text, bos=False, eos=False)
if len(tokens) <= 0:
continue
with torch.no_grad():
num_tokens = min(
len(tokens) - 1, args.attention_sink_eval_tokens - input_pos
)
logits = model(
torch.tensor(
[tokens[:num_tokens]], dtype=torch.int64, device=device
),
torch.tensor([input_pos], dtype=torch.int64, device=device),
).squeeze(dim=0)
neg_log_likelihood = loss_fn(
logits,
torch.tensor(
[tokens[1 : num_tokens + 1]], dtype=torch.int64, device=device
).view(-1),
)
nlls.append(neg_log_likelihood)
input_pos += num_tokens
progress_bar.update(num_tokens)
if input_pos >= args.attention_sink_eval_tokens:
break
ppl = torch.exp(torch.cat(nlls).mean())
print(f"Perplexity: {ppl.item()}")
return ppl.item()
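To make the reviewer's question in the thread above concrete ("is there a definition of the config?"): the full schema comes from _convert_args_to_config() in export_llama_lib, which is not part of this diff, so the sketch below is only an assumption-laden illustration of a standalone config. It includes only the keys this file reads (export.tokenizer_path, sequence.max_seq_length, kv_cache.use_kv_cache, misc.enable_dynamic_shape) plus the eval section added by _convert_cli_to_config_format(); every value shown is a placeholder.

from omegaconf import OmegaConf

from executorch.examples.models.llama.eval_llama_lib import eval_llama

# Illustrative only: the real config is produced by _convert_args_to_config()
# plus _convert_cli_to_config_format(). Keys not referenced in this diff
# (checkpoint, params, quantization, etc.) are omitted here even though
# _prepare_for_llama_export() would still need them.
config = OmegaConf.create(
    {
        "export": {"tokenizer_path": "/path/to/tokenizer.model"},
        "sequence": {"max_seq_length": 2048},
        "kv_cache": {"use_kv_cache": False},
        "misc": {"enable_dynamic_shape": False},
        "eval": {
            "tasks": ["wikitext"],
            "limit": None,
            "num_fewshot": None,
            "pte": None,  # set to a .pte path to evaluate an ExecuTorch binary instead
            "tokenizer_bin": None,
            "output_eager_checkpoint_file": None,
            "attention_sink_eval_tokens": 0,
        },
    }
)

eval_llama("llama2", config)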