@@ -1,9 +1,9 @@
 """MLC LLM Bench Request"""
 import json
+import os
 import time
 from typing import Any, Dict, List, Optional

-import httpx
 from openai import AsyncOpenAI
 from pydantic import BaseModel
 from typing_extensions import Self
@@ -24,9 +24,10 @@ class RequestRecords(BaseModel):
     output: str
     end_to_end_latency_s: float
     ttft: Optional[float] = None
+    server_metrics: Optional[Dict] = None


-class OpenAIRequestSender:
+class OpenAIRequestSender:  # pylint: disable=too-many-instance-attributes
     """
     Manages the sending of requests to a specified API endpoint and gathers inference statistics.

@@ -40,20 +41,27 @@ class OpenAIRequestSender:
         Specifies if streaming should be enabled, default is True.
     timeout : Optional[float]
         The maximum duration in seconds for each request, default is 180.
+    client : Optional[Any]
+        The client to use for sending requests.
+    include_server_metrics : Optional[bool]
+        Specifies if server metrics should be included, default is False.

     Attributes
     ----------
     stats : dict
         Statistics about the performance.
     """

-    def __init__(
+    def __init__(  # pylint: disable=too-many-arguments
         self,
         host: Optional[str] = "127.0.0.1",
         port: Optional[int] = 8008,
         stream: Optional[bool] = None,
         timeout: Optional[float] = None,
+        client: Optional[Any] = None,
+        include_server_metrics: Optional[bool] = False,
     ) -> None:
+        import aiohttp  # pylint: disable=import-outside-toplevel,import-error
         from transformers import (  # pylint: disable=import-outside-toplevel,import-error
             LlamaTokenizerFast,
         )
@@ -63,75 +71,102 @@ def __init__(
         self.tokenizer = LlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer")
         self.prompt_generator = PromptsGenerator()
         self.request_records: List[RequestRecords] = []
-        self.client = AsyncOpenAI(
-            base_url=f"http://{host}:{port}/v1",
-            api_key="None",
-            http_client=httpx.AsyncClient(http2=True),
-        )
+        self.client = client if client else aiohttp.ClientSession()
+        self.include_server_metrics = include_server_metrics
+        self.url = f"http://{host}:{port}/v1/chat/completions"
+        self.headers = {"Content-Type": "application/json"}
+        if os.getenv("MLC_LLM_API_KEY"):
+            self.headers["Authorization"] = f"Bearer {os.getenv('MLC_LLM_API_KEY')}"

     async def __aenter__(self) -> Self:
         return self

     async def __aexit__(self, exc_type, exc_value, traceback) -> None:
         await self.client.close()

-    async def __call__(self, params: Dict[str, Any] = None) -> None:
-        """
-        Send a request to the deployed serving endpoint and collect request records.
-
-        Parameters
-        ----------
-        params : Dict[str, Any]
-            The parameters for the request.
-
-        Returns
-        -------
-        response : Union[Dict, None]
-            The JSON response from the server or None if an error occurs.
-        """
+    async def __call__(  # pylint: disable=too-many-locals, too-many-branches, too-many-statements
+        self, params: Dict[str, Any] = None
+    ) -> None:
         if "messages" not in params:
             prompt_tokens = 128
             if "prompt_tokens" in params:
                 prompt_tokens = params["prompt_tokens"]
             else:
                 logger.warning("A random prompt with %d tokens will be generated.", prompt_tokens)
-
             prompt = self.prompt_generator.generate_prompt(prompt_tokens)
             params["messages"] = [{"role": "system", "content": prompt}]
         else:
-            prompt = params["messages"][0]["content"]
+            prompt = params["messages"][-1]["content"]
         chat_params = self._get_chat_completion_params(params)
         if "stream" not in chat_params:
             chat_params["stream"] = self.stream
         if "timeout" not in chat_params:
             chat_params["timeout"] = self.timeout
+        if self.include_server_metrics:
+            if "stream_options" not in chat_params:
+                chat_params["stream_options"] = {"include_usage": True}
+            else:
+                chat_params["stream_options"]["include_usage"] = True

         total_request_time = 0
         generated_text = ""
         ttft = None
         start_time = time.monotonic()
-        # chat_params["stream_options"] = {"include_usage": True}
-        response = await self.client.chat.completions.create(**chat_params)
-
-        if chat_params["stream"]:
-            async for chunk in response:
-                if chunk.usage:
-                    logger.info(
-                        "Server Metrics:\n%s", json.dumps(chunk.usage.extra, indent=4, default=str)
-                    )
-                elif chunk.choices[0].delta.content is not None:
-                    if not ttft:
-                        ttft = time.monotonic() - start_time  # type: ignore
-                    generated_text += chunk.choices[0].delta.content
+        server_metrics = None
+
+        # AsyncOpenAI chat completion
+        if isinstance(self.client, AsyncOpenAI):
+            response = await self.client.chat.completions.create(**chat_params)
+            if chat_params["stream"]:
+                async for chunk in response:
+                    if chunk.usage:
+                        server_metrics = chunk.usage.extra
+                    elif chunk.choices[0].delta.content is not None:
+                        if not ttft:
+                            ttft = time.monotonic() - start_time  # type: ignore
+                        generated_text += chunk.choices[0].delta.content
+            else:
+                generated_text = response.choices[0].message.content
         else:
-            generated_text = response.choices[0].message.content
+            try:
+                async with self.client.post(
+                    self.url, json=chat_params, headers=self.headers
+                ) as response:
+                    if chat_params["stream"]:
+                        async for chunk in response.content:
+                            chunk = chunk.strip()
+                            if not chunk or chunk == b"\n":
+                                continue
+                            # Get rid of the prefix "data: " and suffix "\n"
+                            raw_data = chunk[6:].strip()
+                            if raw_data == b"[DONE]":
+                                continue
+                            data = json.loads(raw_data)
+                            if data["usage"] is not None:
+                                server_metrics = data["usage"]["extra"]
+                            if not data["choices"]:
+                                continue
+                            delta = data["choices"][0]["delta"]
+                            if delta.get("content", None):
+                                if not ttft:
+                                    ttft = time.monotonic() - start_time
+
+                                generated_text += delta["content"]
+                    else:
+                        data = await response.json()
+                        generated_text = data["choices"][0]["message"]["content"]
+            except Exception as e:  # pylint: disable=broad-except
+                logger.error("Error sending request: %s", str(e))
+                raise e

         total_request_time = time.monotonic() - start_time  # type: ignore
+
         req_rec = RequestRecords(
             input=prompt,
             output=generated_text,
             end_to_end_latency_s=total_request_time,
             ttft=ttft,
             server_metrics=server_metrics,
         )
         self.request_records.append(req_rec)

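Taken together, the patch replaces the httpx-backed `AsyncOpenAI` client with an `aiohttp.ClientSession` by default and records per-request server metrics. For orientation, here is a minimal, hypothetical driver for the reworked sender. The constructor arguments, the `async with` protocol, the `params` handling, and `request_records` come from the diff above; the import path and everything else in the script are assumptions, not part of this commit.

```python
import asyncio

from mlc_llm.bench.request import OpenAIRequestSender  # assumed module path

async def main() -> None:
    # Entering the sender as an async context manager ensures the default
    # aiohttp.ClientSession created in __init__ is closed via __aexit__.
    async with OpenAIRequestSender(
        host="127.0.0.1",       # assumes a compatible server is listening here
        port=8008,
        stream=True,
        include_server_metrics=True,  # sets stream_options={"include_usage": True}
    ) as sender:
        # With no "messages" key, a random prompt of `prompt_tokens` tokens is
        # generated and wrapped in a system message before the request is sent.
        await sender({"prompt_tokens": 128})
        for rec in sender.request_records:
            print(rec.ttft, rec.end_to_end_latency_s, rec.server_metrics)

asyncio.run(main())
```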