
Commit 52513bc

Author: Xinyu
Commit message: update
1 parent a5bd107 · commit 52513bc

File tree

2 files changed: +44 −50 lines


ape/__init__.py

Lines changed: 32 additions & 3 deletions
@@ -1,3 +1,32 @@
-from .ape_gemma import *
-from .ape_llama import *
-from .ape_mistral import *
+def enable_attention_prefill_prefix(model_name, model):
+    if "llama" in model_name:
+        from .ape_llama import enable_llama_attention_prefill_prefix
+        enable_llama_attention_prefill_prefix(model)
+    elif "mistral" in model_name:
+        from .ape_mistral import enable_mistral_attention_prefill_prefix
+        enable_mistral_attention_prefill_prefix(model)
+    elif "gemma" in model_name:
+        from .ape_gemma import enable_gemma_attention_prefill_prefix
+        enable_gemma_attention_prefill_prefix(model)
+
+def enable_attention_prefill_context(model_name, model):
+    if "llama" in model_name:
+        from .ape_llama import enable_llama_attention_prefill_context
+        enable_llama_attention_prefill_context(model)
+    elif "mistral" in model_name:
+        from .ape_mistral import enable_mistral_attention_prefill_context
+        enable_mistral_attention_prefill_context(model)
+    elif "gemma" in model_name:
+        from .ape_gemma import enable_gemma_attention_prefill_context
+        enable_gemma_attention_prefill_context(model)
+
+def enable_attention_prefill_query(model_name, model, temperature, scale):
+    if "llama" in model_name:
+        from .ape_llama import enable_llama_attention_prefill_query
+        enable_llama_attention_prefill_query(model, temperature, scale)
+    elif "mistral" in model_name:
+        from .ape_mistral import enable_mistral_attention_prefill_query
+        enable_mistral_attention_prefill_query(model, temperature, scale)
+    elif "gemma" in model_name:
+        from .ape_gemma import enable_gemma_attention_prefill_query
+        enable_gemma_attention_prefill_query(model, temperature, scale)
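
With these dispatchers exported from ape/__init__.py, callers select the model-specific patch by substring match on the model name instead of importing a model-specific module directly. A minimal usage sketch, assuming a Hugging Face transformers checkpoint; in the demo each stage is enabled immediately before the corresponding prefill pass, so the back-to-back calls below only illustrate the signatures:

    import torch
    from transformers import AutoModelForCausalLM
    from ape import (
        enable_attention_prefill_prefix,
        enable_attention_prefill_context,
        enable_attention_prefill_query,
    )

    model_name = "llama3-8b-instruct"  # dispatch key: must contain "llama", "mistral", or "gemma"
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Meta-Llama-3-8B-Instruct", torch_dtype=torch.bfloat16
    ).cuda().eval()

    enable_attention_prefill_prefix(model_name, model)           # prefix prefill
    enable_attention_prefill_context(model_name, model)          # parallel context prefill
    enable_attention_prefill_query(model_name, model, 0.9, 0.9)  # query prefill (temperature, scale)

Note that a model name matching none of the three substrings falls through silently; the dispatchers have no error branch.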

demo_ape.py

Lines changed: 12 additions & 47 deletions
@@ -4,13 +4,24 @@
 import random
 import argparse
 
+from ape import enable_attention_prefill_prefix, enable_attention_prefill_context, enable_attention_prefill_query
+
 def parse_args(args=None):
     parser = argparse.ArgumentParser()
     parser.add_argument('--model', type=str, default=None, choices=["llama3-8b-instruct", "llama3.1-8b-instruct", "mistral-7b-instruct-v0.3", "gemma2-9b-it"])
     parser.add_argument("--temperature", type=float, default=0.9)
     parser.add_argument("--scale", type=float, default=0.9)
     return parser.parse_args(args)
 
+def seed_everything(seed):
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    np.random.seed(seed)
+    random.seed(seed)
+    torch.backends.cudnn.benchmark = False
+    torch.backends.cudnn.deterministic = True
+    torch.cuda.manual_seed_all(seed)
+
 def load_model_and_tokenizer(model_name, device):
     if model_name == "llama3-8b-instruct":
         tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
@@ -26,7 +37,6 @@ def load_model_and_tokenizer(model_name, device):
         model = AutoModelForCausalLM.from_pretrained("google/gemma-2-9b-it", torch_dtype=torch.bfloat16).to(device)
     return tokenizer, model
 
-
 def build_prefix(model_name, prompt):
     if "llama" in model_name:
         prompt = f"<|begin_of_text|>\n<|start_header_id|>user<|end_header_id|>\n{prompt}"
@@ -45,48 +55,6 @@ def build_suffix(model_name, prompt):
         prompt = f"{prompt}<end_of_turn>\n<start_of_turn>model\n"
     return prompt
 
-def enable_attention_prefill_prefix(model_name, model):
-    if "llama" in args.model:
-        from ape.ape_llama import enable_llama_attention_prefill_prefix
-        enable_llama_attention_prefill_prefix(model)
-    elif "mistral" in model_name:
-        from ape.ape_mistral import enable_mistral_attention_prefill_prefix
-        enable_mistral_attention_prefill_prefix(model)
-    elif "gemma" in model_name:
-        from ape.ape_gemma import enable_gemma_attention_prefill_prefix
-        enable_gemma_attention_prefill_prefix(model)
-
-def enable_attention_prefill_context(model_name, model):
-    if "llama" in args.model:
-        from ape.ape_llama import enable_llama_attention_prefill_context
-        enable_llama_attention_prefill_context(model)
-    elif "mistral" in model_name:
-        from ape.ape_mistral import enable_mistral_attention_prefill_context
-        enable_mistral_attention_prefill_context(model)
-    elif "gemma" in model_name:
-        from ape.ape_gemma import enable_gemma_attention_prefill_context
-        enable_gemma_attention_prefill_context(model)
-
-def enable_attention_prefill_query(model_name, model, temperature, scale):
-    if "llama" in args.model:
-        from ape.ape_llama import enable_llama_attention_prefill_query
-        enable_llama_attention_prefill_query(model, temperature, scale)
-    elif "mistral" in model_name:
-        from ape.ape_mistral import enable_mistral_attention_prefill_query
-        enable_mistral_attention_prefill_query(model, temperature, scale)
-    elif "gemma" in model_name:
-        from ape.ape_gemma import enable_gemma_attention_prefill_query
-        enable_gemma_attention_prefill_query(model, temperature, scale)
-
-def seed_everything(seed):
-    torch.manual_seed(seed)
-    torch.cuda.manual_seed(seed)
-    np.random.seed(seed)
-    random.seed(seed)
-    torch.backends.cudnn.benchmark = False
-    torch.backends.cudnn.deterministic = True
-    torch.cuda.manual_seed_all(seed)
-
 def generate(args):
     prefix = ""
     contexts = [
@@ -98,8 +66,6 @@ def generate(args):
     ]
     query = "Question: what are ten ideas for a social with a large groups of friends in New York City.\nAnswer:"
 
-
-
     device = torch.device(f'cuda:0')
     tokenizer, model = load_model_and_tokenizer(args.model, device)
     model = model.eval()
@@ -111,9 +77,7 @@ def generate(args):
     query_input_ids = tokenizer(query, truncation=False, return_tensors="pt").input_ids
     len_prefix = prefix_input_ids.shape[1]
     len_query = query_input_ids.shape[1]
-
     context_input_ids = tokenizer(contexts, return_tensors='pt', truncation=True, max_length=8192-len_prefix-len_query-256, padding=True, add_special_tokens=False).input_ids
-    print(context_input_ids.shape)
     context_mask = (context_input_ids != tokenizer.pad_token_id).reshape(-1)
 
     enable_attention_prefill_prefix(args.model, model)
@@ -149,6 +113,7 @@ def generate(args):
         past_position = torch.cat([past_key_value[2][:, :len_prefix],
                                    past_key_value[2][:, len_prefix:].repeat(bsz, 1).flatten()[context_mask].unsqueeze(0)], dim=1)
         past_key_values.append((past_key, past_value, past_position, len(contexts)))
+
     context_input_ids = context_input_ids.flatten()[context_mask].unsqueeze(0)
     input_ids = torch.cat([prefix_input_ids, context_input_ids, query_input_ids], dim=-1)
     context_length = input_ids.shape[-1]
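
The final hunk splices the padded per-context token batch back into one sequence between the prefix and the query, dropping padding positions via context_mask. As a standalone illustration of that input-assembly step, here is a minimal sketch under the demo's 8192-token budget; the helper name assemble_inputs is ours, not part of the repository, and it assumes tokenizer.pad_token_id is set:

    import torch

    def assemble_inputs(tokenizer, prefix, contexts, query, max_length=8192):
        # Tokenize prefix and query individually, without truncation.
        prefix_ids = tokenizer(prefix, truncation=False, return_tensors="pt").input_ids
        query_ids = tokenizer(query, truncation=False, return_tensors="pt").input_ids
        # Remaining budget for the contexts, keeping 256 tokens of headroom as in the demo.
        budget = max_length - prefix_ids.shape[1] - query_ids.shape[1] - 256
        # Contexts are tokenized as one padded batch so they can be prefilled in parallel.
        context_ids = tokenizer(contexts, return_tensors="pt", truncation=True,
                                max_length=budget, padding=True,
                                add_special_tokens=False).input_ids
        # Drop padding positions, then flatten the batch into a single stream of real tokens.
        context_mask = (context_ids != tokenizer.pad_token_id).reshape(-1)
        context_flat = context_ids.flatten()[context_mask].unsqueeze(0)
        # Final layout matches the demo: [prefix | all contexts | query].
        return torch.cat([prefix_ids, context_flat, query_ids], dim=-1), context_mask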
