+ """
+ Benchmark the efficiency of prefix caching.
+
+ This script allows you to benchmark the performance of
+ a model with and without prefix caching using either fixed prompts
+ or prompts sampled from the ShareGPT dataset.
+
+ Fixed example usage:
+     python benchmark_prefix_caching.py \
+         --model meta-llama/Llama-2-7b-chat-hf \
+         --enable-prefix-caching \
+         --num-prompts 1 \
+         --repeat-count 100
+
+ ShareGPT example usage:
+     # This command samples 20 prompts with input lengths
+     # between 128 and 256 tokens from the ShareGPT dataset,
+     # then replicates each prompt 5 times.
+     python benchmark_prefix_caching.py \
+         --model meta-llama/Llama-2-7b-chat-hf \
+         --dataset-path /path/to/ShareGPT_V3_unfiltered_cleaned_split.json \
+         --enable-prefix-caching \
+         --num-prompts 20 \
+         --repeat-count 5 \
+         --input-length-range 128:256
+ """
+
+ import json
+ import random
import time
+ from typing import List, Optional, Tuple
+
+ from transformers import PreTrainedTokenizerBase

from vllm import LLM, SamplingParams
from vllm.utils import FlexibleArgumentParser

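+ # Prefer vLLM's own get_tokenizer; fall back to the one provided by the
+ # sibling backend_request_func module when the import is unavailable.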
+ try:
+     from vllm.transformers_utils.tokenizer import get_tokenizer
+ except ImportError:
+     from backend_request_func import get_tokenizer
+

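# Fixed prompt used when --dataset-path is not given; every request shares it
# verbatim, so repeats can reuse the same cached prefix when caching is enabled.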
PROMPT = "You are a helpful assistant in recognizes the content of tables in markdown format. Here is a table as fellows. You need to answer my question about the table.\n # Table\n |Opening|Opening|Sl. No.|Film|Cast|Director|Music Director|Notes|\n |----|----|----|----|----|----|----|----|\n |J A N|9|1|Agni Pushpam|Jayabharathi, Kamalahasan|Jeassy|M. K. Arjunan||\n |J A N|16|2|Priyamvada|Mohan Sharma, Lakshmi, KPAC Lalitha|K. S. Sethumadhavan|V. Dakshinamoorthy||\n |J A N|23|3|Yakshagaanam|Madhu, Sheela|Sheela|M. S. Viswanathan||\n |J A N|30|4|Paalkkadal|Sheela, Sharada|T. K. Prasad|A. T. Ummer||\n |F E B|5|5|Amma|Madhu, Srividya|M. Krishnan Nair|M. K. Arjunan||\n |F E B|13|6|Appooppan|Thikkurissi Sukumaran Nair, Kamal Haasan|P. Bhaskaran|M. S. Baburaj||\n |F E B|20|7|Srishti|Chowalloor Krishnankutty, Ravi Alummoodu|K. T. Muhammad|M. S. Baburaj||\n |F E B|20|8|Vanadevatha|Prem Nazir, Madhubala|Yusufali Kechery|G. Devarajan||\n |F E B|27|9|Samasya|Madhu, Kamalahaasan|K. Thankappan|Shyam||\n |F E B|27|10|Yudhabhoomi|K. P. Ummer, Vidhubala|Crossbelt Mani|R. K. Shekhar||\n |M A R|5|11|Seemantha Puthran|Prem Nazir, Jayabharathi|A. B. Raj|M. K. Arjunan||\n |M A R|12|12|Swapnadanam|Rani Chandra, Dr. Mohandas|K. G. George|Bhaskar Chandavarkar||\n |M A R|19|13|Thulavarsham|Prem Nazir, sreedevi, Sudheer|N. Sankaran Nair|V. Dakshinamoorthy||\n |M A R|20|14|Aruthu|Kaviyoor Ponnamma, Kamalahasan|Ravi|G. Devarajan||\n |M A R|26|15|Swimming Pool|Kamal Haasan, M. G. Soman|J. Sasikumar|M. K. Arjunan||\n \n # Question\n What' s the content in the (1,1) cells\n " # noqa: E501


@@ -15,7 +52,83 @@ def test_prefix(llm=None, sampling_params=None, prompts=None):
    print(f"cost time {end_time - start_time}")


+ def sample_requests(
+     dataset_path: str,
+     num_requests: int,
+     tokenizer: PreTrainedTokenizerBase,
+     input_length_range: Tuple[int, int],
+     fixed_output_len: Optional[int],
+ ) -> List[Tuple[str, int, int]]:
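+     # Returns at most num_requests (prompt, prompt_len, output_len) tuples
+     # whose prompt lengths (in tokens) fall within input_length_range.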
+     if fixed_output_len is not None and fixed_output_len < 4:
+         raise ValueError("output_len too small")
+
+     # Load the dataset.
+     with open(dataset_path) as f:
+         dataset = json.load(f)
+     # Filter out the conversations with fewer than 2 turns.
+     dataset = [data for data in dataset if len(data["conversations"]) >= 2]
+     # Only keep the first two turns of each conversation.
+     dataset = [(data["conversations"][0]["value"],
+                 data["conversations"][1]["value"]) for data in dataset]
+
+     # Shuffle the dataset.
+     random.shuffle(dataset)
+
+     min_len, max_len = input_length_range
+
+     # Filter out sequences that are too long or too short.
+     filtered_dataset: List[Tuple[str, int, int]] = []
+     for i in range(len(dataset)):
+         if len(filtered_dataset) == num_requests:
+             break
+
+         # Tokenize the prompts and completions.
+         prompt = dataset[i][0]
+         prompt_token_ids = tokenizer(prompt).input_ids
+         completion = dataset[i][1]
+         completion_token_ids = tokenizer(completion).input_ids
+         prompt_len = len(prompt_token_ids)
+         output_len = len(completion_token_ids
+                          ) if fixed_output_len is None else fixed_output_len
+         if prompt_len < 4 or output_len < 4:
+             # Prune too short sequences.
+             continue
+         if min_len <= prompt_len <= max_len:
+             filtered_dataset.append((prompt, prompt_len, output_len))
+
+     return filtered_dataset
+
+
+ def repeat_and_sort_requests(requests: List[Tuple[str, int, int]],
+                              repeat_count: int,
+                              sort: bool = False) -> List[str]:
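+     # Replicate each request repeat_count times, then order the prompts by
+     # prompt length when sort is set, or shuffle them otherwise.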
+     repeated_requests = requests * repeat_count
+     if sort:
+         repeated_requests.sort(key=lambda x: x[1])
+     else:
+         random.shuffle(repeated_requests)
+     return [req[0] for req in repeated_requests]
+
+

def main(args):
+     tokenizer = get_tokenizer(args.model, trust_remote_code=True)
+     input_length_range = tuple(map(int, args.input_length_range.split(':')))
+
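+     # Sample prompts from the ShareGPT dataset when one is provided;
+     # otherwise fall back to repeating the fixed PROMPT defined above.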
+     if args.dataset_path is not None:
+         print(f"Start to sample {args.num_prompts} prompts "
+               f"from {args.dataset_path}")
+         filtered_datasets = sample_requests(
+             dataset_path=args.dataset_path,
+             num_requests=args.num_prompts,
+             tokenizer=tokenizer,
+             input_length_range=input_length_range,
+             fixed_output_len=args.output_len,
+         )
+     else:
+         prompt_len = len(tokenizer(PROMPT).input_ids)
+         filtered_datasets = [(PROMPT, prompt_len, args.output_len)
+                              ] * args.num_prompts
+
    llm = LLM(model=args.model,
              tokenizer_mode='auto',
              trust_remote_code=True,
@@ -24,10 +137,13 @@ def main(args):
              tensor_parallel_size=args.tensor_parallel_size,
              enable_prefix_caching=args.enable_prefix_caching)

-     num_prompts = 100
-     prompts = [PROMPT] * num_prompts
    sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len)
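    # temperature=0 selects greedy decoding, keeping generation deterministic
    # across the repeated runs being compared.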
+ print ("Testing filtered datasets" )
143
+ prompts = repeat_and_sort_requests (filtered_datasets ,
144
+ repeat_count = args .repeat_count ,
145
+ sort = args .sort )
146
+
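+     # Warm-up run: primes the engine (and, when enabled, the prefix cache)
+     # before the timed run that follows.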
    print("------warm up------")
    test_prefix(
        llm=llm,
@@ -45,11 +161,15 @@ def main(args):

if __name__ == "__main__":
    parser = FlexibleArgumentParser(
-         description='Benchmark the performance with or without automatic '
-         'prefix caching.')
+         description=
+         'Benchmark the performance with or without automatic prefix caching.')
    parser.add_argument('--model',
                        type=str,
                        default='baichuan-inc/Baichuan2-13B-Chat')
+     parser.add_argument("--dataset-path",
+                         type=str,
+                         default=None,
+                         help="Path to the ShareGPT dataset JSON file.")
    parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
    parser.add_argument('--output-len', type=int, default=10)
    parser.add_argument('--enable-prefix-caching',
@@ -58,5 +178,21 @@ def main(args):
    parser.add_argument('--use-v2-block-manager',
                        action='store_true',
                        help='Use BlockSpaceManagerV2')
+     parser.add_argument('--num-prompts',
+                         type=int,
+                         default=1,
+                         help="Number of prompts sampled from the dataset")
+     parser.add_argument('--repeat-count',
+                         type=int,
+                         default=100,
+                         help='Number of times to repeat each prompt')
+     parser.add_argument('--sort',
+                         action='store_true',
+                         help='Sort prompts by input length')
+     parser.add_argument('--input-length-range',
+                         type=str,
+                         default='128:256',
+                         help='Range of input lengths for sampling prompts, '
+                         'specified as "min:max" (e.g., "128:256").')
    args = parser.parse_args()
    main(args)