
Commit 7094103

Adding P3L measurement to the benchmarks collection tools. (#197)
1 parent a67b65b commit 7094103

File tree

3 files changed: +303, -2 lines changed


benchmarks/P3L.py

Lines changed: 234 additions & 0 deletions
@@ -0,0 +1,234 @@
#!/usr/bin/env python3
"""
Patch-Perplexity (P3L)

This script produces a realistic PPL measurement
for the quantized KV cache system by processing a sequence of
non-overlapping patches of the reference text. Generation of the
consecutive symbols in each patch is governed (forced)
by the reference text.

The initial context size for the system is set by the parameter
"--context-size".

The number of output symbols to generate starting from a given
context is set by the parameter "--sample-size". This variable also
defines the size of the individual patch.

For an N-token reference text split into M patches with the
system's context size C, the run takes M*preload + (N-C)*generation time.

Quick correctness validation tips:

Running llama-2-7b model
(
./vllm/examples/P3L.py
--model=meta-llama/Llama-2-7b-chat-hf
--context-size=1024
--sample-size=512
)
should result in PPL ~ 6.524227946419175

Running llama-2-7b model
(
./vllm/examples/P3L.py
--model=meta-llama/Llama-2-7b-chat-hf
--context-size=1024
--sample-size=512
--patch-size=1
)
should result in PPL ~ 3.8968611189957523

"""

import argparse
import datetime
import math
import os

from huggingface_hub import hf_hub_download

from vllm import LLM, SamplingParams
from vllm.logger import init_logger

logger = init_logger(__name__)


def get_wikitext2_text(tokenizer):
    hf_hub_download(repo_id='alexei-v-ivanov-amd/wiki',
                    repo_type="dataset",
                    filename='wiki.test.raw',
                    local_dir='./')
    with open('./wiki.test.raw') as f:
        test_text = "\n".join(line.strip() for line in f)
        test_enc = tokenizer(test_text)

    os.remove('./wiki.test.raw')

    return test_enc, test_text


def vllm_init(args):

    llm = LLM(model=args.model,
              tensor_parallel_size=args.tensor_parallel_size,
              trust_remote_code=args.trust_remote_code,
              dtype=args.dtype,
              quantization=args.quantization,
              kv_cache_dtype=args.kv_cache_dtype,
              quantization_param_path=args.kv_cache_scales_path
              if args.kv_cache_scales_path != '' else None,
              enforce_eager=args.enforce_eager)

    sampling_params = SamplingParams(n=1,
                                     temperature=0.0,
                                     top_p=1,
                                     use_beam_search=False,
                                     ignore_eos=True,
                                     ppl_measurement=True,
                                     future_context=[],
                                     prompt_logprobs=1,
                                     logprobs=1,
                                     presence_penalty=0.0)

    return llm, sampling_params


def vllm_predict(CONT, llm, sampl_par):
    result = llm.generate(prompt_token_ids=CONT, sampling_params=sampl_par)
    return result


def main(args: argparse.Namespace):

    MESSAGE = f"Initialising @ {datetime.datetime.now()}"
    logger.info(MESSAGE)
    print(MESSAGE)
    my_ppl = 0.0

    logger.info("Initializing the engine.")
    my_llm, my_sampl_par = vllm_init(args)
    my_tokenizer = my_llm.llm_engine.tokenizer.tokenizer
    logger.info(my_sampl_par)
    logger.info("Initialized the engine.")

    my_n_samples = args.sample_size

    if (args.context_size + my_n_samples) > \
            my_llm.llm_engine.model_config.max_model_len:
        MESSAGE = ("" \
            "Error! The total number of tokens:\n" \
            f" prefix ({args.context_size}) + " \
            f"to be generated ({my_n_samples})" \
            f" can't be bigger than the model limit " \
            f"({my_llm.llm_engine.model_config.max_model_len}).")
        logger.info(MESSAGE)
        print(MESSAGE)
        return

    my_test_enc, my_test_text = get_wikitext2_text(my_tokenizer)
    logger.info("Loaded the test data.")

    my_n_patches = math.ceil(
        (len(my_test_enc['input_ids']) - args.context_size - 1) / my_n_samples)
    if args.patch_size is not None:
        my_n_patches = args.patch_size

    num_tokens_generated = 0
    starting_time = datetime.datetime.now()
    MESSAGE = (f"Starting generation @ {starting_time}\n" \
        " Have the test sample of " \
        f"{len(my_test_enc['input_ids'])} tokens;" \
        f" will try to process {my_n_patches} patch(es)," \
        f" generating {my_n_samples} tokens in each patch" \
        f" from the initial context of {args.context_size} tokens.")

    logger.info(MESSAGE)
    print(MESSAGE)
    for c in range(my_n_patches):
        CONTEXT = []
        my_sampl_par.future_context = []
        CONTEXT.append(
            my_test_enc['input_ids'][c * my_n_samples:c * my_n_samples +
                                     args.context_size])
        upper_boundary = min((c + 1) * my_n_samples + args.context_size,
                             len(my_test_enc['input_ids']))
        my_sampl_par.future_context.append(
            my_test_enc['input_ids'][c * my_n_samples +
                                     args.context_size:upper_boundary])
        my_sampl_par.max_tokens = len(my_sampl_par.future_context[0])
        my_sampl_par.cntr = c
        LOGPROBS = vllm_predict(CONTEXT, my_llm, my_sampl_par)
        num_tokens_generated += len(LOGPROBS[0].outputs[0].token_ids)
        if (num_tokens_generated < my_n_samples):
            MESSAGE = (f"Warning: The number of generated tokens is " \
                f"less than requested ({num_tokens_generated}" \
                f" < {my_n_samples}).")
            logger.info(MESSAGE)
            print(MESSAGE)
        my_ppl -= LOGPROBS[0].outputs[0].cumulative_logprob
        MESSAGE = (f"Iteration {c+1} of {my_n_patches} Intermediate " \
            "Estimates:\n" \
            f"\tCross-entropy_intermediate={my_ppl/num_tokens_generated}\n" \
            f"\tPerplexity_intermediate=" \
            f"{math.exp(my_ppl/num_tokens_generated)}")

        logger.info(MESSAGE)
        print(MESSAGE)
    ending_time = datetime.datetime.now()
    MESSAGE = (f"Done @ {ending_time} after processing for" \
        f" {ending_time-starting_time}," \
        f" generated {num_tokens_generated} tokens.")

    logger.info(MESSAGE)
    print(MESSAGE)

    MESSAGE = (f"\tIntegral Cross-Entropy={my_ppl}\n\tAverage Cross-Entropy=" \
        f"{my_ppl/num_tokens_generated}" \
        f"\n\tPPL={math.exp(my_ppl/num_tokens_generated)}")

    logger.info(MESSAGE)
    print(MESSAGE)
    return


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description='Measure the Patch-Perplexity (P3L) of a model by '
        'processing non-overlapping patches of the reference text.')
    parser.add_argument('--model', type=str, default='facebook/opt-125m')
    parser.add_argument(
        '--data',
        type=str,
        default='./wikitext/wikitext-2-v1/test-00000-of-00001.parquet')
    parser.add_argument('--context-size', type=int, default=4096)
    parser.add_argument('--kv-cache-scales-path', type=str, default='')
    parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
    parser.add_argument('--quantization', type=str, default=None)
    parser.add_argument('--trust-remote-code',
                        action='store_true',
                        help='trust remote code from huggingface')
    parser.add_argument(
        '--dtype',
        type=str,
        default='auto',
        choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'],
        help='data type for model weights and activations. '
        'The "auto" option will use FP16 precision '
        'for FP32 and FP16 models, and BF16 precision '
        'for BF16 models.')
    parser.add_argument('--sample-size', type=int, default=512)
    parser.add_argument('--patch-size', type=int, default=None)
    parser.add_argument('--enforce-eager',
                        action='store_true',
                        help='enforce eager mode and disable CUDA graph')
    parser.add_argument(
        "--kv-cache-dtype",
        type=str,
        choices=['auto', 'fp8_e5m2', 'fp8'],
        default='auto',
        help=
        'Data type for kv cache storage. If "auto", will use model data type.')
    args = parser.parse_args()

    main(args)
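Note on the arithmetic in main(): each patch contributes the negative of its cumulative log-probability, and the reported PPL is the exponential of the per-token average. The following is a minimal standalone sketch of that reduction; the per-patch cumulative logprob values and patch count are invented for illustration and are not taken from this commit.

import math

# Hypothetical per-patch cumulative logprobs, standing in for
# LOGPROBS[0].outputs[0].cumulative_logprob over three patches.
patch_cumulative_logprobs = [-961.7, -975.2, -958.4]
tokens_per_patch = 512  # --sample-size

# Integral cross-entropy: the patch loop subtracts each (negative)
# cumulative logprob, i.e. sums the negative log-likelihood of all
# forced tokens.
integral_cross_entropy = -sum(patch_cumulative_logprobs)
num_tokens = tokens_per_patch * len(patch_cumulative_logprobs)

average_cross_entropy = integral_cross_entropy / num_tokens  # nats per token
ppl = math.exp(average_cross_entropy)

print(f"Average Cross-Entropy={average_cross_entropy}  PPL={ppl}")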

vllm/model_executor/layers/sampler.py

Lines changed: 54 additions & 2 deletions
@@ -60,6 +60,7 @@ class SampleResultArgsType:
     multinomial_samples: MultinomialSamplesType
     sample_results_dict: SampleResultsDictType
     sampling_metadata: SamplingMetadata
+    forced_samples: Optional[torch.Tensor]
     greedy_samples: Optional[torch.Tensor]
     beam_search_logprobs: Optional[torch.Tensor]

@@ -499,6 +500,39 @@
     return results
 
 
+def _forced_sample(
+    selected_seq_groups: List[SequenceGroupToSample],
+    samples: torch.Tensor,
+) -> List[Tuple[List[int], List[int]]]:
+    """Run forced sampling on the given samples.
+    Args:
+        selected_seq_groups: A list of sequence groups batched.
+        samples: (num_selected_samples,) A tensor of samples. The length of
+            samples could be smaller than selected_seq_groups if
+            seq_group.do_sample is False.
+    Returns:
+        Tuple of (next_token_ids, parent_ids). The length of the returned
+        list is the same as the length of selected_seq_groups. If the
+        corresponding seq_group has do_sample=False, the tuple contains
+        ([], []).
+
+        The next_token_ids are guided (forced) by the ids contained in the
+        sampling_params.future_context property.
+    """
+    samples = samples.tolist()
+    sample_idx = 0
+    results = []
+    for seq_group in selected_seq_groups:
+        seq_ids = seq_group.seq_ids
+        num_parent_seqs = len(seq_ids)
+        assert num_parent_seqs == 1, (
+            "Deterministic sampling should have only one seq.")
+        parent_ids = list(range(num_parent_seqs))
+        next_token_ids = [samples[sample_idx]]
+        results.append((next_token_ids, parent_ids))
+        sample_idx += num_parent_seqs
+    return results
+
+
 def _random_sample(
     selected_seq_groups: List[SequenceGroupToSample],
     random_samples: torch.Tensor,

@@ -697,13 +731,15 @@ def get_pythonized_sample_results(
     (
         sample_metadata,
         sampling_metadata,
+        forced_samples,
         greedy_samples,
         multinomial_samples,
         beam_search_logprobs,
         sample_results_dict,
     ) = (
         sample_result_args.sample_metadata,
         sample_result_args.sampling_metadata,
+        sample_result_args.forced_samples,
         sample_result_args.greedy_samples,
         sample_result_args.multinomial_samples,
         sample_result_args.beam_search_logprobs,

@@ -714,7 +750,9 @@
         if sampling_type not in sample_metadata:
             continue
         (seq_group_id, seq_groups) = sample_metadata[sampling_type]
-        if sampling_type == SamplingType.GREEDY:
+        if sampling_type == SamplingType.FORCED:
+            sample_results = _forced_sample(seq_groups, forced_samples)
+        elif sampling_type == SamplingType.GREEDY:
             sample_results = _greedy_sample(seq_groups, greedy_samples)
         elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
             sample_results = _random_sample(seq_groups,

@@ -762,6 +800,7 @@ def _sample_with_torch(
     sample_results_dict: SampleResultsDictType = {}
     sample_metadata: SampleMetadataType = {}
     multinomial_samples: MultinomialSamplesType = {}
+    forced_samples: Optional[torch.Tensor] = None
     greedy_samples: Optional[torch.Tensor] = None
     beam_search_logprobs: Optional[torch.Tensor] = None

@@ -786,7 +825,19 @@
         seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id]
         sample_metadata[sampling_type] = (seq_group_id, seq_groups)
         long_sample_indices = sample_indices.long()
-        if sampling_type == SamplingType.GREEDY:
+        if sampling_type == SamplingType.FORCED:
+            if (seq_groups[0].sampling_params.future_context is not None):
+                forced_samples = torch.tensor([
+                    seq_groups[0].sampling_params.future_context[0][min(
+                        len(sampling_metadata.seq_groups[0].seq_data[
+                            sampling_params.cntr].output_token_ids),
+                        len(seq_groups[0].sampling_params.future_context[0]) -
+                        1)]
+                ])
+            else:
+                forced_samples = torch.argmax(logprobs[long_sample_indices],
+                                              dim=-1)
+        elif sampling_type == SamplingType.GREEDY:
             greedy_samples = torch.argmax(logprobs[long_sample_indices],
                                           dim=-1)

@@ -843,6 +894,7 @@ def _sample_with_torch(
     maybe_deferred_args = SampleResultArgsType(
         sampling_metadata=sampling_metadata,
         sample_metadata=sample_metadata,
+        forced_samples=forced_samples,
         multinomial_samples=multinomial_samples,
         greedy_samples=greedy_samples,
         beam_search_logprobs=beam_search_logprobs,
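Note on the new FORCED sampling path: it is teacher forcing. At every step the "sampled" token is read from sampling_params.future_context at the position given by the number of tokens already generated (clamped to the end of the patch), while the model's logprob for that reference token is still recorded. Below is a simplified sketch of just that index rule, using names that mirror the diff; it is an illustration, not the actual vLLM code path.

from typing import List


def pick_forced_token(future_context: List[int],
                      output_token_ids: List[int]) -> int:
    # Index rule mirrored from the diff: take the reference token at position
    # len(output_token_ids), clamped so we never index past the patch end.
    idx = min(len(output_token_ids), len(future_context) - 1)
    return future_context[idx]


# Toy walk-through with a made-up 4-token reference continuation.
reference = [101, 202, 303, 404]
generated: List[int] = []
for _ in range(len(reference)):
    generated.append(pick_forced_token(reference, generated))

assert generated == reference  # forced generation reproduces the reference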
