Skip to content

Commit 1cd4b74

Browse files
authored
Update load test, use threads instead of asyncio (#3628)
1 parent 691be9d commit 1cd4b74

File tree

1 file changed

+47
-37
lines changed

1 file changed

+47
-37
lines changed

tests/load_test.py

Lines changed: 47 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,15 @@
11
import argparse
2-
import time, asyncio
3-
from openai import AsyncOpenAI, AsyncAzureOpenAI
2+
import time
3+
import threading
4+
from concurrent.futures import ThreadPoolExecutor
45
import uuid
56
import traceback
67
import numpy as np
78
from transformers import AutoTokenizer
9+
from litellm import completion
810

9-
# base_url - litellm proxy endpoint
10-
# api_key - litellm proxy api-key, is created proxy with auth
11-
litellm_client = None
1211

13-
14-
async def litellm_completion(args, tokenizer, image_url=None):
15-
# Your existing code for litellm_completion goes here
12+
def litellm_completion(args, tokenizer, image_url=None):
1613
try:
1714
if image_url:
1815
messages = [
@@ -30,16 +27,24 @@ async def litellm_completion(args, tokenizer, image_url=None):
3027
]
3128

3229
start = time.time()
33-
response = await litellm_client.chat.completions.create(
30+
31+
additional_api_kwargs = {}
32+
if args.api_key:
33+
additional_api_kwargs["api_key"] = args.api_key
34+
if args.api_base:
35+
additional_api_kwargs["api_base"] = args.api_base
36+
37+
response = completion(
3438
model=args.model,
3539
messages=messages,
3640
stream=True,
41+
**additional_api_kwargs,
3742
)
3843
ttft = None
3944

4045
itl_list = []
4146
content = ""
42-
async for chunk in response:
47+
for chunk in response:
4348
if chunk.choices[0].delta.content:
4449
end_time = time.time()
4550
if ttft is None:
@@ -52,43 +57,48 @@ async def litellm_completion(args, tokenizer, image_url=None):
5257
return content, ttft, itl_list
5358

5459
except Exception as e:
55-
# If there's an exception, log the error message
5660
print(e)
5761
with open("error_log.txt", "a") as error_log:
5862
error_log.write(f"Error during completion: {str(e)}\n")
5963
return str(e)
6064

6165

62-
async def main(args):
66+
def main(args):
6367
n = args.num_total_responses
6468
batch_size = args.req_per_sec # Requests per second
6569
start = time.time()
6670

67-
all_tasks = []
71+
all_results = []
6872
tokenizer = AutoTokenizer.from_pretrained("gpt2")
69-
for i in range(0, n, batch_size):
70-
batch = range(i, min(i + batch_size, n))
71-
for _ in batch:
72-
if args.include_image:
73-
# Generate a random dimension for the image
74-
if args.randomize_image_dimensions:
75-
y_dimension = np.random.randint(100, 1025)
73+
74+
with ThreadPoolExecutor(max_workers=batch_size) as executor:
75+
for i in range(0, n, batch_size):
76+
batch_futures = []
77+
batch = range(i, min(i + batch_size, n))
78+
79+
for _ in batch:
80+
if args.include_image:
81+
if args.randomize_image_dimensions:
82+
y_dimension = np.random.randint(100, 1025)
83+
else:
84+
y_dimension = 512
85+
image_url = f"https://placehold.co/1024x{y_dimension}/png"
86+
future = executor.submit(
87+
litellm_completion, args, tokenizer, image_url
88+
)
7689
else:
77-
y_dimension = 512
78-
image_url = f"https://placehold.co/1024x{y_dimension}/png"
79-
task = asyncio.create_task(
80-
litellm_completion(args, tokenizer, image_url)
81-
)
82-
else:
83-
task = asyncio.create_task(litellm_completion(args, tokenizer))
84-
all_tasks.append(task)
85-
if i + batch_size < n:
86-
await asyncio.sleep(1) # Wait 1 second before the next batch
87-
88-
all_completions = await asyncio.gather(*all_tasks)
90+
future = executor.submit(litellm_completion, args, tokenizer)
91+
batch_futures.append(future)
92+
93+
# Wait for batch to complete
94+
for future in batch_futures:
95+
all_results.append(future.result())
96+
97+
if i + batch_size < n:
98+
time.sleep(1) # Wait 1 second before next batch
8999

90100
successful_completions = [
91-
c for c in all_completions if isinstance(c, tuple) and len(c) == 3
101+
c for c in all_results if isinstance(c, tuple) and len(c) == 3
92102
]
93103
ttft_list = np.array([float(c[1]) for c in successful_completions])
94104
itl_list_flattened = np.array(
@@ -101,7 +111,7 @@ async def main(args):
101111

102112
# Write errors to error_log.txt
103113
with open("load_test_errors.log", "a") as error_log:
104-
for completion in all_completions:
114+
for completion in all_results:
105115
if isinstance(completion, str):
106116
error_log.write(completion + "\n")
107117

@@ -115,15 +125,15 @@ async def main(args):
115125
if __name__ == "__main__":
116126
parser = argparse.ArgumentParser()
117127
parser.add_argument("--model", type=str, default="azure-gpt-3.5")
118-
parser.add_argument("--server-address", type=str, default="http://0.0.0.0:9094")
128+
parser.add_argument("--api-base", type=str, default=None)
129+
parser.add_argument("--api-key", type=str, default=None)
119130
parser.add_argument("--num-total-responses", type=int, default=50)
120131
parser.add_argument("--req-per-sec", type=int, default=5)
121132
parser.add_argument("--include-image", action="store_true")
122133
parser.add_argument("--randomize-image-dimensions", action="store_true")
123134
args = parser.parse_args()
124135

125-
litellm_client = AsyncOpenAI(base_url=args.server_address, api_key="sk-1234")
126136
# Blank out contents of error_log.txt
127137
open("load_test_errors.log", "w").close()
128138

129-
asyncio.run(main(args))
139+
main(args)

0 commit comments

Comments
 (0)