import random
import argparse

+from ape import enable_attention_prefill_prefix, enable_attention_prefill_context, enable_attention_prefill_query
+
def parse_args(args=None):
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, default=None, choices=["llama3-8b-instruct", "llama3.1-8b-instruct", "mistral-7b-instruct-v0.3", "gemma2-9b-it"])
    parser.add_argument("--temperature", type=float, default=0.9)
    parser.add_argument("--scale", type=float, default=0.9)
    return parser.parse_args(args)

+def seed_everything(seed):
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    np.random.seed(seed)
+    random.seed(seed)
+    torch.backends.cudnn.benchmark = False
+    torch.backends.cudnn.deterministic = True
+    torch.cuda.manual_seed_all(seed)
+
def load_model_and_tokenizer(model_name, device):
    if model_name == "llama3-8b-instruct":
        tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
@@ -26,7 +37,6 @@ def load_model_and_tokenizer(model_name, device):
        model = AutoModelForCausalLM.from_pretrained("google/gemma-2-9b-it", torch_dtype=torch.bfloat16).to(device)
    return tokenizer, model

-
def build_prefix(model_name, prompt):
    if "llama" in model_name:
        prompt = f"<|begin_of_text|>\n<|start_header_id|>user<|end_header_id|>\n{prompt}"
@@ -45,48 +55,6 @@ def build_suffix(model_name, prompt):
        prompt = f"{prompt}<end_of_turn>\n<start_of_turn>model\n"
    return prompt

-def enable_attention_prefill_prefix(model_name, model):
-    if "llama" in args.model:
-        from ape.ape_llama import enable_llama_attention_prefill_prefix
-        enable_llama_attention_prefill_prefix(model)
-    elif "mistral" in model_name:
-        from ape.ape_mistral import enable_mistral_attention_prefill_prefix
-        enable_mistral_attention_prefill_prefix(model)
-    elif "gemma" in model_name:
-        from ape.ape_gemma import enable_gemma_attention_prefill_prefix
-        enable_gemma_attention_prefill_prefix(model)
-
-def enable_attention_prefill_context(model_name, model):
-    if "llama" in args.model:
-        from ape.ape_llama import enable_llama_attention_prefill_context
-        enable_llama_attention_prefill_context(model)
-    elif "mistral" in model_name:
-        from ape.ape_mistral import enable_mistral_attention_prefill_context
-        enable_mistral_attention_prefill_context(model)
-    elif "gemma" in model_name:
-        from ape.ape_gemma import enable_gemma_attention_prefill_context
-        enable_gemma_attention_prefill_context(model)
-
-def enable_attention_prefill_query(model_name, model, temperature, scale):
-    if "llama" in args.model:
-        from ape.ape_llama import enable_llama_attention_prefill_query
-        enable_llama_attention_prefill_query(model, temperature, scale)
-    elif "mistral" in model_name:
-        from ape.ape_mistral import enable_mistral_attention_prefill_query
-        enable_mistral_attention_prefill_query(model, temperature, scale)
-    elif "gemma" in model_name:
-        from ape.ape_gemma import enable_gemma_attention_prefill_query
-        enable_gemma_attention_prefill_query(model, temperature, scale)
-
-def seed_everything(seed):
-    torch.manual_seed(seed)
-    torch.cuda.manual_seed(seed)
-    np.random.seed(seed)
-    random.seed(seed)
-    torch.backends.cudnn.benchmark = False
-    torch.backends.cudnn.deterministic = True
-    torch.cuda.manual_seed_all(seed)
-
def generate(args):
    prefix = ""
    contexts = [
@@ -98,8 +66,6 @@ def generate(args):
    ]
    query = "Question: what are ten ideas for a social with a large groups of friends in New York City.\nAnswer:"

-
-
    device = torch.device(f'cuda:0')
    tokenizer, model = load_model_and_tokenizer(args.model, device)
    model = model.eval()
@@ -111,9 +77,7 @@ def generate(args):
    query_input_ids = tokenizer(query, truncation=False, return_tensors="pt").input_ids
    len_prefix = prefix_input_ids.shape[1]
    len_query = query_input_ids.shape[1]
-
    context_input_ids = tokenizer(contexts, return_tensors='pt', truncation=True, max_length=8192 - len_prefix - len_query - 256, padding=True, add_special_tokens=False).input_ids
-    print(context_input_ids.shape)
    context_mask = (context_input_ids != tokenizer.pad_token_id).reshape(-1)

    enable_attention_prefill_prefix(args.model, model)
@@ -149,6 +113,7 @@ def generate(args):
        past_position = torch.cat([past_key_value[2][:, :len_prefix],
                                   past_key_value[2][:, len_prefix:].repeat(bsz, 1).flatten()[context_mask].unsqueeze(0)], dim=1)
        past_key_values.append((past_key, past_value, past_position, len(contexts)))
+
    context_input_ids = context_input_ids.flatten()[context_mask].unsqueeze(0)
    input_ids = torch.cat([prefix_input_ids, context_input_ids, query_input_ids], dim=-1)
    context_length = input_ids.shape[-1]