2 changes: 2 additions & 0 deletions examples/models/llama/TARGETS
@@ -150,6 +150,8 @@ runtime.python_library(
"@EXECUTORCH_CLIENTS",
],
deps = [
"fbsource//third-party/pypi/tqdm:tqdm",
"fbsource//third-party/pypi/datasets:datasets",
"fbsource//third-party/pypi/lm-eval:lm-eval",
"fbsource//third-party/pypi/tiktoken:tiktoken",
":export_library",
11 changes: 9 additions & 2 deletions examples/models/llama/eval_llama.py
@@ -10,7 +10,11 @@

import torch

from .eval_llama_lib import build_args_parser, eval_llama
from .eval_llama_lib import (
build_args_parser,
eval_llama,
eval_llama_with_attention_sink,
)

FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
logging.basicConfig(level=logging.INFO, format=FORMAT)
@@ -24,7 +28,10 @@ def main() -> None:
args = parser.parse_args()
# Overrides this arg, because evaluation requires full logits.
args.generate_full_logits = True
eval_llama(modelname, args) # pyre-ignore
if args.use_attention_sink:
eval_llama_with_attention_sink(modelname, args) # pyre-ignore
else:
eval_llama(modelname, args) # pyre-ignore


if __name__ == "__main__":
64 changes: 64 additions & 0 deletions examples/models/llama/eval_llama_lib.py
@@ -10,6 +10,8 @@
from typing import Optional, Union

import torch

from datasets import load_dataset
from executorch.examples.models.llama.export_llama_lib import (
get_quantizer_and_quant_params,
)
@@ -21,6 +23,8 @@
)
from executorch.extension.llm.tokenizer.utils import get_tokenizer
from lm_eval.evaluator import simple_evaluate
from torch.nn import CrossEntropyLoss
from tqdm import tqdm

from .evaluate.eager_eval import EagerEvalWrapper

@@ -280,6 +284,9 @@ def build_args_parser() -> argparse.ArgumentParser:
help="Save the checkpoint after source transformations, for other evaluation platform to run the same checkpoint.",
)

# Set of parameters specific to AttentionSink.
parser.add_argument("--attention_sink_eval_tokens", type=int, default=0)

return parser


@@ -309,3 +316,60 @@ def eval_llama(

for task, res in eval_results["results"].items():
print(f"{task}: {res}")


def eval_llama_with_attention_sink(model_name: str, args: argparse.ArgumentParser):
"""
Evaluate the model's perplexity when AttentionSink is enabled.

This is mostly copied from https://github.com/mit-han-lab/streaming-llm/blob/main/examples/eval_long_ppl.py
"""
assert args.use_attention_sink is not None # pyre-ignore [16]
assert args.attention_sink_eval_tokens > 0 # pyre-ignore [16]
attention_sink_params = args.use_attention_sink.split(",")
assert len(attention_sink_params) == 3
sink_size = int(attention_sink_params[0])
window_size = int(attention_sink_params[1])

assert args.max_seq_length == sink_size + window_size # pyre-ignore [16]

device = "cuda" if torch.cuda.is_available() else "cpu"
manager: LLMEdgeManager = _prepare_for_llama_export(args)
model = manager.model.eval().to(device=device)
tokenizer = get_tokenizer(args.tokenizer_path) # pyre-ignore [16]

eval_data = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")

nlls = []
loss_fn = CrossEntropyLoss(reduction="none")
progress_bar = tqdm(total=args.attention_sink_eval_tokens)
input_pos = 0
while input_pos < args.attention_sink_eval_tokens:
for text in eval_data["text"]: # pyre-ignore [16]
tokens = tokenizer.encode(text, bos=False, eos=False)
if len(tokens) <= 0:
continue
with torch.no_grad():
num_tokens = min(
len(tokens) - 1, args.attention_sink_eval_tokens - input_pos
)
logits = model(
torch.tensor(
[tokens[:num_tokens]], dtype=torch.int64, device=device
),
torch.tensor([input_pos], dtype=torch.int64, device=device),
).squeeze(dim=0)
neg_log_likelihood = loss_fn(
logits,
torch.tensor(
[tokens[1 : num_tokens + 1]], dtype=torch.int64, device=device
).view(-1),
)
nlls.append(neg_log_likelihood)
input_pos += num_tokens
progress_bar.update(num_tokens)
if input_pos >= args.attention_sink_eval_tokens:
break
ppl = torch.exp(torch.cat(nlls).mean())
print(f"Perplexity: {ppl.item()}")
return ppl.item()
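
Note: each entry appended to `nlls` is a 1-D tensor of per-token losses (because `CrossEntropyLoss(reduction="none")` is used), so `torch.exp(torch.cat(nlls).mean())` is the exponential of the token-weighted average negative log-likelihood over all scored tokens. A minimal sketch of that aggregation with made-up numbers (not part of the PR):

```python
import torch

# Two "chunks" of per-token NLL values, standing in for entries of `nlls`.
nlls = [torch.tensor([2.1, 1.7, 3.0]), torch.tensor([2.4, 1.9])]

avg_nll = torch.cat(nlls).mean()  # (2.1 + 1.7 + 3.0 + 2.4 + 1.9) / 5
ppl = torch.exp(avg_nll)          # perplexity = exp(mean per-token NLL), ~9.21 here
print(ppl.item())
```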
7 changes: 7 additions & 0 deletions examples/models/llama/export_llama_lib.py
@@ -448,6 +448,13 @@ def build_args_parser() -> argparse.ArgumentParser:
help="type of embedding quantization for pre-quantized checkpoint, '<bitwidth>,<groupsize>', e.g., '8,1024'.",
)

parser.add_argument(
"--use_attention_sink",
default=None,
type=str,
help="Use attention sink to have fluent multi-round conversation. '<sink_size>,<window_size>,<batch_eviction_size>', e.g., '4,2044,1024'.",
)

parser.add_argument(
"--output_prune_map",
default=None,
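
For context, a minimal sketch of driving the new flag through the eval entry points above. The paths, the model name, and the non-attention-sink flags are placeholders and assumptions about the existing parser; only `--use_attention_sink`, `--max_seq_length`, and `--attention_sink_eval_tokens` come from this diff, and `--max_seq_length` must equal `sink_size + window_size`:

```python
from executorch.examples.models.llama.eval_llama_lib import (
    build_args_parser,
    eval_llama_with_attention_sink,
)

parser = build_args_parser()
args = parser.parse_args(
    [
        "--checkpoint", "/path/to/consolidated.00.pth",  # placeholder path
        "--params", "/path/to/params.json",              # placeholder path
        "--tokenizer_path", "/path/to/tokenizer.model",  # placeholder path
        "--use_attention_sink", "4,2044,1024",           # sink_size,window_size,eviction_batch_size
        "--max_seq_length", "2048",                      # must equal sink_size + window_size
        "--attention_sink_eval_tokens", "16384",
    ]
)
args.generate_full_logits = True  # evaluation needs full logits, as in eval_llama.py
ppl = eval_llama_with_attention_sink("llama", args)  # model name is a placeholder
```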
19 changes: 19 additions & 0 deletions examples/models/llama/model.py
Expand Up @@ -201,6 +201,25 @@ def __init__(self, **kwargs):

sanitize_checkpoint_from_pre_quantization(checkpoint)

if hasattr(self.args, "use_attention_sink") and self.args.use_attention_sink:
from .source_transformation.attention_sink import enable_attention_sink

attention_sink_params = self.args.use_attention_sink.split(",")
assert len(attention_sink_params) == 3
sink_size = int(attention_sink_params[0])
window_size = int(attention_sink_params[1])
eviction_batch_size = int(attention_sink_params[2])

assert self.args.max_seq_length == sink_size + window_size

self.model_ = enable_attention_sink(
module=self.model_,
params=model_args,
sink_size=sink_size,
window_size=window_size,
eviction_batch_size=eviction_batch_size,
)

# assign=True: load params/buffers by assignment instead of performing an in-place copy.
# Because we are using device="meta", tensors do not have memory associated with them
# and an in-place copy is a no-op. Use assign=True in load_state_dict for this scenario.
6 changes: 5 additions & 1 deletion examples/models/llama/runner/eager.py
@@ -84,7 +84,11 @@ def execute_runner(runner_class: Type[LlamaRunner]) -> None:
with torch.no_grad():
runner = runner_class(args) # pyre-ignore: Missing argument [20]
generated_tokens = (
runner.chat_completion(temperature=args.temperature)
runner.chat_completion(
max_seq_len=1000000 if args.use_attention_sink else args.max_seq_length,
temperature=args.temperature,
show_progress=args.show_tokens,
)
if args.chat
else runner.text_completion(
prompt=args.prompt,
27 changes: 17 additions & 10 deletions examples/models/llama/runner/generation.py
Expand Up @@ -168,18 +168,19 @@ def text_completion(

def chat_completion(
self,
max_seq_len: int,
temperature: float = 0.6,
top_p: float = 0.9,
show_progress: bool = False,
) -> List[int]:
"""
Perform multi-turn chat with the language model.

Args:
max_seq_len (int): Maximum number of tokens to generate for each prompt.
temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6.
top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9.
show_progress (bool, optional): Flag indicating whether to show the number of tokens generated. Defaults to False.
Returns:
Generated list of tokens.

@@ -188,20 +189,26 @@ def chat_completion(
"""
exit_prompt = "exit"
tokens = []
pre_stop_token = []
prompt = input("Me: ")
while prompt and prompt != exit_prompt:
print("LLM: ", end="", flush=True)
new_tokens = self.generate(
prompt_tokens=self.tokenizer.encode(
self._format_prompt(prompt), bos=True, eos=False
),
max_seq_len=self.max_seq_len,
prompt_tokens = self.tokenizer.encode(
self._format_prompt(prompt), bos=True, eos=False
)
generated_tokens = self.generate(
prompt_tokens=pre_stop_token + prompt_tokens,
max_seq_len=max_seq_len,
temperature=temperature,
top_p=top_p,
echo=True,
echo=False,
pos_base=len(tokens) - 1 if len(tokens) > 0 else 0,
)
tokens.extend(new_tokens)
pre_stop_token = generated_tokens[-1:]
tokens.extend(prompt_tokens)
tokens.extend(generated_tokens)
if show_progress:
print(f"[Generated {len(tokens)} tokens]")
prompt = input("Me: ")
return tokens

118 changes: 117 additions & 1 deletion examples/models/llama/source_transformation/attention_sink.py
Expand Up @@ -7,15 +7,22 @@
# Components for supporting Attention Sink. See
# https://arxiv.org/abs/2309.17453 for more details about Attention Sink.

import types
from typing import Optional

import torch

from executorch.examples.models.llama.llama_transformer import KVCache, ModelArgs, Rope
from executorch.examples.models.llama.llama_transformer import (
Attention,
KVCache,
ModelArgs,
Rope,
)
from executorch.examples.models.llama.rope import (
apply_rotary_emb_to_k,
hf_apply_rotary_emb_to_k,
)
from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter


class RopeWithAttentionSink(Rope):
@@ -206,3 +213,112 @@ def evict_tokens(self, input_pos: torch.Tensor, seq_len: int) -> int:
)
self.position_shift -= num_to_evict # pyre-ignore [8]
return self.position_shift


def attention_sink_forward(
self,
x: torch.Tensor,
freqs_cos: torch.Tensor,
freqs_sin: torch.Tensor,
input_pos: Optional[torch.Tensor] = None,
):
assert self.use_kv_cache
assert input_pos is not None

bsz, seqlen, _ = x.shape

# QKV
q, k, v = self.wq(x), self.wk(x), self.wv(x)
# We need view_copy elimination
q = q.view(bsz, seqlen, self.n_local_heads, self.head_dim)
k = k.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
v = v.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

# Prepare for space in KV cache and get position shift
position_shift = self.kv_cache.evict_tokens(input_pos, seqlen)

# RoPE relative positional embeddings with shifted position in KV cache
q, k = self.rope.forward(q, k, freqs_cos, freqs_sin)

output = self.SDPA(input_pos + position_shift, q, k, v, bsz, seqlen, self.mask)
return self.wo(output)


def _replace_rope(
module: torch.nn.Module, rope_with_attention_sink: RopeWithAttentionSink
):
def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool:
return isinstance(child, Rope)

def replacement_fn(child: torch.nn.Module) -> torch.nn.Module:
return rope_with_attention_sink

_replace_with_custom_fn_if_matches_filter(module, replacement_fn, filter_fn)


def _replace_attention(
module: torch.nn.Module,
rope_with_attention_sink: RopeWithAttentionSink,
sink_size: int,
window_size: int,
eviction_batch_size: int,
):
for _, child_module in module._modules.items():
if len(list(child_module.children())) > 0: # pyre-ignore [16]
_replace_attention(
module=child_module, # pyre-ignore [6]
rope_with_attention_sink=rope_with_attention_sink,
sink_size=sink_size,
window_size=window_size,
eviction_batch_size=eviction_batch_size,
)

if isinstance(child_module, Attention):
kv_cache = child_module.kv_cache
kv_cache_with_attention_sink = KVCacheWithAttentionSink(
n_heads=kv_cache.n_heads,
head_dim=kv_cache.head_dim,
transpose_cache=kv_cache.transpose_cache,
enable_dynamic_shape=kv_cache.enable_dynamic_shape,
rope=rope_with_attention_sink,
max_batch_size=kv_cache.max_batch_size,
window_size=window_size,
sink_size=sink_size,
eviction_batch_size=eviction_batch_size,
dtype=kv_cache.k_cache.dtype,
)
child_module.kv_cache = kv_cache_with_attention_sink
child_module.SDPA.kv_cache = kv_cache_with_attention_sink
child_module.forward = types.MethodType( # pyre-ignore
attention_sink_forward, child_module
)


def enable_attention_sink(
module: torch.nn.Module,
params: ModelArgs,
sink_size: int,
window_size: int,
eviction_batch_size: int,
) -> torch.nn.Module:
"""
Transform the model to be able to run inference with Attention Sink.
There are three main steps:
- Replace Rope with RopeWithAttentionSink
- Replace Attention's KVCache with KVCacheWithAttentionSink
- Replace Attention's forward with attention_sink_forward
"""
rope_with_attention_sink = RopeWithAttentionSink(
params=params,
window_size=window_size,
sink_size=sink_size,
eviction_batch_size=eviction_batch_size,
)
_replace_rope(module, rope_with_attention_sink)
_replace_attention(
module=module,
rope_with_attention_sink=rope_with_attention_sink,
sink_size=sink_size,
window_size=window_size,
eviction_batch_size=eviction_batch_size,
)
return module
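
The hunk above only shows the tail of `KVCacheWithAttentionSink.evict_tokens`, so for readers of this diff, here is a toy model of the intended bookkeeping: the first `sink_size` tokens are kept permanently, the most recent `window_size` tokens form a sliding window, and when the cache would overflow, tokens just after the sink are evicted in `eviction_batch_size` chunks while RoPE positions are re-based by a negative position shift. The helper below is purely illustrative; its name and rounding details are assumptions, not the actual KVCacheWithAttentionSink code:

```python
def toy_evict(cache_len: int, sink_size: int, window_size: int,
              eviction_batch_size: int, incoming: int) -> tuple[int, int]:
    """Return (num_tokens_evicted, position_shift_delta) for `incoming` new tokens."""
    capacity = sink_size + window_size
    overflow = cache_len + incoming - capacity
    if overflow <= 0:
        return 0, 0
    # Evict in whole batches so positions are not shifted on every single token.
    num_to_evict = -(-overflow // eviction_batch_size) * eviction_batch_size
    num_to_evict = min(num_to_evict, cache_len - sink_size)
    return num_to_evict, -num_to_evict


# With the example config '4,2044,1024': a full 2048-token cache receiving one
# more token evicts a 1024-token batch and shifts positions back by 1024.
print(toy_evict(cache_len=2048, sink_size=4, window_size=2044,
                eviction_batch_size=1024, incoming=1))  # (1024, -1024)
```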