33"""
44Patch-Perplexity (P3L)
55
6- This is a script that produces a realistic PPL measurement
7- for the quantized KV cache system by processing a sequence of
8- non-overlapping patches of the reference text. Generation of the
6+ This is a script that produces a realistic PPL measurement
7+ for the quantized KV cache system by processing a sequence of
8+ non-overlapping patches of the reference text. Generation of the
99consecutive symbols in each patch is governed (forced)
1010by the reference text.
1111
12- The initial context size for the system is set by the parameter
12+ The initial context size for the system is set by the parameter
1313"--context-size".
1414
15- The number of output symbols to generate starting from a given
16- context is set by the parameter "--sample-size". This variable also
15+ The number of output symbols to generate starting from a given
16+ context is set by the parameter "--sample-size". This variable also
1717defines the size of the individual patch.
1818
19- For the N-token reference text that is split into M patches with the
19+ For the N-token reference text that is split into M patches with the
2020system's context size C it takes M*preload + (N-C)*generation time.
2121
2222Quick correctness validation tips:
2323
24- Running llama-2-7b model
25- (
26- ./vllm/examples/P3L.py
27- --model=meta-llama/Llama-2-7b-chat-hf
28- --context-size=1024
24+ Running llama-2-7b model
25+ (
26+ ./vllm/examples/P3L.py
27+ --model=meta-llama/Llama-2-7b-chat-hf
28+ --context-size=1024
2929 --sample-size=512
3030)
3131should result in PPL ~ 6.524227946419175
3232
33- Running llama-2-7b model
34- (
35- ./vllm/examples/P3L.py
36- --model=meta-llama/Llama-2-7b-chat-hf
37- --context-size=1024
33+ Running llama-2-7b model
34+ (
35+ ./vllm/examples/P3L.py
36+ --model=meta-llama/Llama-2-7b-chat-hf
37+ --context-size=1024
3838 --sample-size=512
3939 --patch-size=1
4040)
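
A quick sanity check of the patch arithmetic described in the docstring above. This is a minimal sketch that assumes a hypothetical reference text of 10,240 tokens; the actual tokenized length of wiki.test.raw differs:

import math

n_tokens = 10_240      # N: length of the tokenized reference text (assumed)
context_size = 1024    # C: --context-size
sample_size = 512      # --sample-size, i.e. the size of one patch

# Same expression the script uses further down for my_n_patches.
n_patches = math.ceil((n_tokens - context_size - 1) / sample_size)
print(n_patches)                # 18 patches (M)
print(n_tokens - context_size)  # 9216 forced-generation steps, the (N-C) term
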
@@ -58,17 +58,20 @@
 from vllm import LLM, SamplingParams
 from vllm.engine.arg_utils import EngineArgs
 from vllm.logger import init_logger
+from vllm.utils import FlexibleArgumentParser
 
 logger = init_logger(__name__)
 
 
 def get_wikitext2_text(tokenizer):
     with tempfile.TemporaryDirectory() as tmpdirname:
-        hf_hub_download(repo_id='alexei-v-ivanov-amd/wiki',
-                        repo_type="dataset",
-                        filename='wiki.test.raw',
-                        local_dir=tmpdirname)
-        with open(os.path.join(tmpdirname, 'wiki.test.raw')) as f:
+        hf_hub_download(
+            repo_id="alexei-v-ivanov-amd/wiki",
+            repo_type="dataset",
+            filename="wiki.test.raw",
+            local_dir=tmpdirname,
+        )
+        with open(os.path.join(tmpdirname, "wiki.test.raw")) as f:
             test_text = "\n".join(line.strip() for line in f)
         test_enc = tokenizer(test_text)
 
@@ -79,15 +82,17 @@ def vllm_init(args):
     engine_args = EngineArgs.from_cli_args(args)
     llm = LLM(**dataclasses.asdict(engine_args))
 
-    sampling_params = SamplingParams(n=1,
-                                     temperature=0.0,
-                                     top_p=1,
-                                     ignore_eos=True,
-                                     ppl_measurement=True,
-                                     future_context=[],
-                                     prompt_logprobs=1,
-                                     logprobs=1,
-                                     presence_penalty=0.0)
+    sampling_params = SamplingParams(
+        n=1,
+        temperature=0.0,
+        top_p=1,
+        ignore_eos=True,
+        ppl_measurement=True,
+        future_context=[],
+        prompt_logprobs=1,
+        logprobs=1,
+        presence_penalty=0.0,
+    )
 
     return llm, sampling_params
 
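
Note: `ppl_measurement` and `future_context` are not fields of upstream vLLM's SamplingParams; they look like fork-side extensions that force greedy decoding (temperature=0.0, top_p=1) along the reference tokens while logprobs are collected. As a rough sketch of the same idea on stock vLLM, the reference tokens can instead be scored as a prompt via `prompt_logprobs`; the class and attribute names below are assumed from upstream vLLM and may differ between versions:

from vllm import LLM, SamplingParams
from vllm.inputs import TokensPrompt


def forced_logprob_sum(llm: LLM, token_ids: list[int]) -> float:
    """Sum log p(token_i | tokens_<i) over a token sequence scored as a prompt."""
    params = SamplingParams(max_tokens=1, prompt_logprobs=1, temperature=0.0)
    out = llm.generate(TokensPrompt(prompt_token_ids=token_ids), params)[0]
    total = 0.0
    # prompt_logprobs[0] is None: there is no prediction for the first token.
    for tok, lp in zip(token_ids[1:], out.prompt_logprobs[1:]):
        total += lp[tok].logprob
    return total
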
@@ -98,7 +103,6 @@ def vllm_predict(CONT, llm, sampl_par):
 
 
 def main(args: argparse.Namespace):
-
     MESSAGE = f"Initialising @ {datetime.datetime.now()}"
     logger.info(MESSAGE)
     print(MESSAGE)
@@ -112,14 +116,17 @@ def main(args: argparse.Namespace):
 
     my_n_samples = args.sample_size
 
-    if (args.context_size + my_n_samples) > \
-        my_llm.llm_engine.model_config.max_model_len:
-        MESSAGE = ("" \
-                   "Error! The total number of tokens:\n" \
-                   f" prefix ({args.context_size}) + " \
-                   f"to be generated ({my_n_samples})" \
-                   f" can't be bigger than the model limit " \
-                   f"({my_llm.llm_engine.model_config.max_model_len}).")
+    if (
+        args.context_size + my_n_samples
+    ) > my_llm.llm_engine.model_config.max_model_len:
+        MESSAGE = (
+            ""
+            "Error! The total number of tokens:\n"
+            f" prefix ({args.context_size}) + "
+            f"to be generated ({my_n_samples})"
+            f" can't be bigger than the model limit "
+            f"({my_llm.llm_engine.model_config.max_model_len})."
+        )
         logger.info(MESSAGE)
         print(MESSAGE)
         return
@@ -128,26 +135,28 @@ def main(args: argparse.Namespace):
     logger.info("Loaded the test data.")
 
     my_n_patches = math.ceil(
-        (len(my_test_enc['input_ids']) - args.context_size - 1) / my_n_samples)
+        (len(my_test_enc["input_ids"]) - args.context_size - 1) / my_n_samples
+    )
     if args.patch_size is not None:
         my_n_patches = args.patch_size
 
     num_tokens_generated = 0
     starting_time = datetime.datetime.now()
-    MESSAGE = (f"Starting generation @ {starting_time}\n" \
-               " Have the test sample of "
-               f"{len(my_test_enc['input_ids'])} tokens" \
-               f" will try to process {my_n_patches} patche(s)," \
-               f" generating {my_n_samples} tokens in each patch" \
-               f" from the initial context of {args.context_size} tokens.")
+    MESSAGE = (
+        f"Starting generation @ {starting_time}\n"
+        " Have the test sample of "
+        f"{len(my_test_enc['input_ids'])} tokens"
+        f" will try to process {my_n_patches} patche(s),"
+        f" generating {my_n_samples} tokens in each patch"
+        f" from the initial context of {args.context_size} tokens."
+    )
 
     logger.info(MESSAGE)
     print(MESSAGE)
 
     my_batchsize = args.batch_size
 
     for c in range(0, my_n_patches, my_batchsize):
-
         CONTEXT = []
         my_sampl_par.future_context = []
         my_sampl_par.cntr = []
@@ -156,53 +165,68 @@ def main(args: argparse.Namespace):
             if (c + b) < my_n_patches:
                 upper_boundary = min(
                     (c + b + 1) * my_n_samples + args.context_size,
-                    len(my_test_enc['input_ids']))
+                    len(my_test_enc["input_ids"]),
+                )
                 CONTEXT.append(
-                    my_test_enc['input_ids'][(c + b) * my_n_samples:(c + b) *
-                                             my_n_samples + args.context_size])
+                    my_test_enc["input_ids"][
+                        (c + b) * my_n_samples : (c + b) * my_n_samples
+                        + args.context_size
+                    ]
+                )
 
                 my_sampl_par.future_context.append(
-                    my_test_enc['input_ids'][(c + b) * my_n_samples +
-                                             args.context_size:upper_boundary])
+                    my_test_enc["input_ids"][
+                        (c + b) * my_n_samples + args.context_size : upper_boundary
+                    ]
+                )
 
                 my_sampl_par.cntr.append(c + b)
 
         my_sampl_par.max_tokens = max(
-            len(my_sampl_par.future_context[b]) for b in range(len(CONTEXT)))
+            len(my_sampl_par.future_context[b]) for b in range(len(CONTEXT))
+        )
 
         LOGPROBS = vllm_predict(CONTEXT, my_llm, my_sampl_par)
         for b in range(len(CONTEXT)):
             num_tokens_generated += len(LOGPROBS[b].outputs[0].token_ids)
             my_ppl -= LOGPROBS[b].outputs[0].cumulative_logprob
 
-        if (num_tokens_generated < my_n_samples * len(CONTEXT)):
-            MESSAGE = (f"Warning: The number of generated tokens is" \
-                       f"less than requested ({num_tokens_generated}" \
-                       f" < {my_n_samples * len(CONTEXT)}).")
+        if num_tokens_generated < my_n_samples * len(CONTEXT):
+            MESSAGE = (
+                f"Warning: The number of generated tokens is"
+                f"less than requested ({num_tokens_generated}"
+                f" < {my_n_samples * len(CONTEXT)})."
+            )
             logger.info(MESSAGE)
             print(MESSAGE)
 
-        MESSAGE = (f"Iterations {c + 1} through {c + len(CONTEXT)}" \
-                   f" of {my_n_patches} Intermediate " \
-                   "Estimates:\n" \
-                   f"\tCross-entropy_intermediate={my_ppl / num_tokens_generated}\n" \
-                   f"\tPerplexity_intermediate=" \
-                   f"{math.exp(my_ppl / num_tokens_generated)}")
+        MESSAGE = (
+            f"Iterations {c + 1} through {c + len(CONTEXT)}"
+            f" of {my_n_patches} Intermediate "
+            "Estimates:\n"
+            f"\tCross-entropy_intermediate={my_ppl / num_tokens_generated}\n"
+            f"\tPerplexity_intermediate="
+            f"{math.exp(my_ppl / num_tokens_generated)}"
+        )
 
         logger.info(MESSAGE)
         print(MESSAGE)
 
     ending_time = datetime.datetime.now()
-    MESSAGE = (f"Done @ {ending_time} after processing for" \
-               f" {ending_time - starting_time}" \
-               f" generated {num_tokens_generated} tokens.")
+    MESSAGE = (
+        f"Done @ {ending_time} after processing for"
+        f" {ending_time - starting_time}"
+        f" generated {num_tokens_generated} tokens."
+    )
 
     logger.info(MESSAGE)
     print(MESSAGE)
 
-    MESSAGE = (f"\tIntegral Cross-Entropy={my_ppl}\n\tAverage Cross-Entropy=" \
-               f"{my_ppl / num_tokens_generated}" \
-               f"\n\tPPL={math.exp(my_ppl / num_tokens_generated)}")
+    MESSAGE = (
+        f"\tIntegral Cross-Entropy={my_ppl}\n\tAverage Cross-Entropy="
+        f"{my_ppl / num_tokens_generated}"
+        f"\n\tPPL={math.exp(my_ppl / num_tokens_generated)}"
+    )
 
     if args.output_json:
         results = {
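
The accumulation in the hunk above is the usual perplexity bookkeeping: `my_ppl` ends up as the negative sum of the forced tokens' logprobs, and the reported PPL is the exponential of that sum averaged per generated token. A tiny self-contained check of the arithmetic with made-up numbers:

import math

# Pretend cumulative logprobs returned for two forced patches (made-up values).
cumulative_logprobs = [-310.7, -295.3]    # sum of log p(token) per patch
tokens_per_patch = [512, 512]

integral_ce = -sum(cumulative_logprobs)       # "Integral Cross-Entropy"
avg_ce = integral_ce / sum(tokens_per_patch)  # "Average Cross-Entropy"
print(avg_ce, math.exp(avg_ce))               # 0.591796875  ~1.807 (the PPL)
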
@@ -219,17 +243,19 @@ def main(args: argparse.Namespace):
 
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(
-        description='Measure the PPPL (P3L) score of a given model.')
-    parser.add_argument('--context-size', type=int, default=4096)
-    parser.add_argument('--sample-size', type=int, default=512)
-    parser.add_argument('--batch-size', type=int, default=1)
-    parser.add_argument('--patch-size', type=int, default=None)
+    parser = FlexibleArgumentParser(
+        description="Measure the PPPL (P3L) score of a given model."
+    )
+    parser.add_argument("--context-size", type=int, default=4096)
+    parser.add_argument("--sample-size", type=int, default=512)
+    parser.add_argument("--batch-size", type=int, default=1)
+    parser.add_argument("--patch-size", type=int, default=None)
     parser.add_argument(
-        '--output-json',
+        "--output-json",
         type=str,
         default=None,
-        help='Path to save the latency results in JSON format.')
+        help="Path to save the latency results in JSON format.",
+    )
 
     parser = EngineArgs.add_cli_args(parser)
     args = parser.parse_args()
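
For reference, a small sketch of the parser pattern the `__main__` block now uses: the script's own flags and vLLM's engine flags share a single command line, and `EngineArgs.from_cli_args` (called in `vllm_init` above) recovers the engine configuration from the parsed namespace. Exact flag coverage depends on the installed vLLM version:

from vllm.engine.arg_utils import EngineArgs
from vllm.utils import FlexibleArgumentParser

parser = FlexibleArgumentParser(description="P3L-style flag parsing sketch")
parser.add_argument("--context-size", type=int, default=1024)
parser = EngineArgs.add_cli_args(parser)

# Script flags and engine flags (e.g. --model) are parsed together.
args = parser.parse_args(
    ["--model=meta-llama/Llama-2-7b-chat-hf", "--context-size=1024"]
)
engine_args = EngineArgs.from_cli_args(args)
print(args.context_size, engine_args.model)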