 import argparse
-import time, asyncio
-from openai import AsyncOpenAI, AsyncAzureOpenAI
+import time
+import threading
+from concurrent.futures import ThreadPoolExecutor
 import uuid
 import traceback
 import numpy as np
 from transformers import AutoTokenizer
+from litellm import completion

-# base_url - litellm proxy endpoint
-# api_key - litellm proxy api-key, created if the proxy is started with auth
-litellm_client = None

-
-async def litellm_completion(args, tokenizer, image_url=None):
-    # Your existing code for litellm_completion goes here
+def litellm_completion(args, tokenizer, image_url=None):
     try:
         if image_url:
             messages = [
@@ -30,16 +27,24 @@ async def litellm_completion(args, tokenizer, image_url=None):
             ]

         start = time.time()
-        response = await litellm_client.chat.completions.create(
+
+        additional_api_kwargs = {}
+        if args.api_key:
+            additional_api_kwargs["api_key"] = args.api_key
+        if args.api_base:
+            additional_api_kwargs["api_base"] = args.api_base
+
+        response = completion(
             model=args.model,
             messages=messages,
             stream=True,
+            **additional_api_kwargs,
         )
         ttft = None

         itl_list = []
         content = ""
-        async for chunk in response:
+        for chunk in response:
             if chunk.choices[0].delta.content:
                 end_time = time.time()
                 if ttft is None:
@@ -52,43 +57,48 @@ async def litellm_completion(args, tokenizer, image_url=None):
         return content, ttft, itl_list

     except Exception as e:
-        # If there's an exception, log the error message
         print(e)
         with open("error_log.txt", "a") as error_log:
             error_log.write(f"Error during completion: {str(e)}\n")
         return str(e)


-async def main(args):
+def main(args):
     n = args.num_total_responses
     batch_size = args.req_per_sec  # Requests per second
     start = time.time()

-    all_tasks = []
+    all_results = []
     tokenizer = AutoTokenizer.from_pretrained("gpt2")
-    for i in range(0, n, batch_size):
-        batch = range(i, min(i + batch_size, n))
-        for _ in batch:
-            if args.include_image:
-                # Generate a random dimension for the image
-                if args.randomize_image_dimensions:
-                    y_dimension = np.random.randint(100, 1025)
+
+    with ThreadPoolExecutor(max_workers=batch_size) as executor:
+        for i in range(0, n, batch_size):
+            batch_futures = []
+            batch = range(i, min(i + batch_size, n))
+
+            for _ in batch:
+                if args.include_image:
+                    if args.randomize_image_dimensions:
+                        y_dimension = np.random.randint(100, 1025)
+                    else:
+                        y_dimension = 512
+                    image_url = f"https://placehold.co/1024x{y_dimension}/png"
+                    future = executor.submit(
+                        litellm_completion, args, tokenizer, image_url
+                    )
                 else:
-                    y_dimension = 512
-                image_url = f"https://placehold.co/1024x{y_dimension}/png"
-                task = asyncio.create_task(
-                    litellm_completion(args, tokenizer, image_url)
-                )
-            else:
-                task = asyncio.create_task(litellm_completion(args, tokenizer))
-            all_tasks.append(task)
-        if i + batch_size < n:
-            await asyncio.sleep(1)  # Wait 1 second before the next batch
-
-    all_completions = await asyncio.gather(*all_tasks)
+                    future = executor.submit(litellm_completion, args, tokenizer)
+                batch_futures.append(future)
+
+            # Wait for batch to complete
+            for future in batch_futures:
+                all_results.append(future.result())
+
+            if i + batch_size < n:
+                time.sleep(1)  # Wait 1 second before next batch

     successful_completions = [
-        c for c in all_completions if isinstance(c, tuple) and len(c) == 3
+        c for c in all_results if isinstance(c, tuple) and len(c) == 3
     ]
     ttft_list = np.array([float(c[1]) for c in successful_completions])
     itl_list_flattened = np.array(
@@ -101,7 +111,7 @@ async def main(args):

     # Write errors to error_log.txt
     with open("load_test_errors.log", "a") as error_log:
-        for completion in all_completions:
+        for completion in all_results:
             if isinstance(completion, str):
                 error_log.write(completion + "\n")

@@ -115,15 +125,15 @@ async def main(args):
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--model", type=str, default="azure-gpt-3.5")
-    parser.add_argument("--server-address", type=str, default="http://0.0.0.0:9094")
+    parser.add_argument("--api-base", type=str, default=None)
+    parser.add_argument("--api-key", type=str, default=None)
     parser.add_argument("--num-total-responses", type=int, default=50)
     parser.add_argument("--req-per-sec", type=int, default=5)
     parser.add_argument("--include-image", action="store_true")
     parser.add_argument("--randomize-image-dimensions", action="store_true")
     args = parser.parse_args()

-    litellm_client = AsyncOpenAI(base_url=args.server_address, api_key="sk-1234")
     # Blank out contents of error_log.txt
     open("load_test_errors.log", "w").close()

-    asyncio.run(main(args))
+    main(args)
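
The change swaps the module-level asyncio client for litellm's synchronous `completion` call, so concurrency now comes from `ThreadPoolExecutor` workers rather than an event loop, and `api_base`/`api_key` are forwarded per request instead of living in a global client. For reference, a minimal standalone sketch of that calling pattern, assuming a running LiteLLM proxy; the model name, URL, and key below are illustrative placeholders, not values from this commit:

from litellm import completion

# Placeholder values for illustration only; point these at your own proxy.
api_kwargs = {"api_base": "http://0.0.0.0:4000", "api_key": "sk-1234"}

response = completion(
    model="azure-gpt-3.5",
    messages=[{"role": "user", "content": "Say hello."}],
    stream=True,
    **api_kwargs,  # forwarded per call, mirroring additional_api_kwargs in the diff
)

# litellm yields OpenAI-style streaming chunks that can be iterated synchronously.
for chunk in response:
    delta = chunk.choices[0].delta.content
    if delta:
        print(delta, end="", flush=True)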