Commit b5b40ee

[Bench] Defaults to aiohttp client, add ServerMetrics (#2527)
* [Bench] Defaults to aiohttp client
* Add ServerMetrics to summary
* Remove duplicate servermetric def
1 parent 9be4b92 commit b5b40ee
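
The commit makes the bench sender default to an `aiohttp.ClientSession` and lets each request record carry the server-reported metrics. A minimal, hypothetical driver is sketched below; it assumes the `mlc_llm` package is importable and an MLC-LLM server is already serving on the default `127.0.0.1:8008`, and the prompt text is made up.

```python
# Hypothetical driver for the updated bench sender (assumes a running server
# at the default 127.0.0.1:8008 and mlc_llm on the Python path).
import asyncio

from mlc_llm.bench.request import OpenAIRequestSender


async def main() -> None:
    # No explicit client is passed, so the sender now defaults to an
    # aiohttp.ClientSession and POSTs to /v1/chat/completions directly.
    async with OpenAIRequestSender(include_server_metrics=True) as sender:
        await sender({"messages": [{"role": "user", "content": "Hello!"}]})
        # Each RequestRecords entry may now carry a server_metrics dict,
        # which MetricsProcessor folds into the summary report.
        print(sender.request_records[-1].server_metrics)


asyncio.run(main())
```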

File tree

3 files changed: +142, -57 lines changed

python/mlc_llm/bench/metrics.py

Lines changed: 68 additions & 17 deletions

@@ -1,6 +1,6 @@
 """ MLC LLM bench Metrics"""
 import json
-from typing import Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union

 from pydantic import BaseModel

@@ -12,6 +12,19 @@
 logger = logging.getLogger(__name__)


+class ServerMetrics(BaseModel):
+    """The metrics from the server side."""
+
+    prompt_tokens: int
+    prefill_tokens: int
+    completion_tokens: int
+    decode_tokens_per_s: float
+    prefill_tokens_per_s: float
+    end_to_end_latency_s: float
+    inter_token_latency_s: float
+    ttft_s: Optional[float] = None
+
+
 class Metrics(BaseModel):
     """The list of metric keys"""

@@ -21,6 +34,7 @@ class Metrics(BaseModel):
     inter_token_latency_s: float
     decode_tokens_per_s: float
     ttft: Optional[float] = None
+    server_metrics: Optional[ServerMetrics] = None


 class MetricsProcessor:
@@ -87,13 +101,26 @@ def extract_metrics_from_request_records(
             assert prompt_tokens > 0 and completion_tokens >= 0, "Invalid prompt tokens"
             end_to_end_latency_s = metric.end_to_end_latency_s
             ttft = metric.ttft if metric.ttft is not None else 0
+            server_metric = None
+            if metric.server_metrics is not None:
+                server_metric = ServerMetrics(
+                    prompt_tokens=metric.server_metrics["prompt_tokens"],
+                    prefill_tokens=metric.server_metrics["prefill_tokens"],
+                    completion_tokens=metric.server_metrics["completion_tokens"],
+                    decode_tokens_per_s=metric.server_metrics["decode_tokens_per_s"],
+                    prefill_tokens_per_s=metric.server_metrics["prefill_tokens_per_s"],
+                    end_to_end_latency_s=metric.server_metrics["end_to_end_latency_s"],
+                    inter_token_latency_s=metric.server_metrics["inter_token_latency_s"],
+                    ttft_s=metric.server_metrics["ttft_s"],
+                )
             refined_metric = Metrics(
                 inter_token_latency_s=end_to_end_latency_s / completion_tokens,
-                decode_tokens_per_s=completion_tokens / (end_to_end_latency_s - ttft),
+                decode_tokens_per_s=(completion_tokens - 1) / (end_to_end_latency_s - ttft),
                 ttft=metric.ttft,
                 end_to_end_latency_s=end_to_end_latency_s,
                 prompt_tokens=prompt_tokens,
                 completion_tokens=completion_tokens,
+                server_metrics=server_metric,
             )
             result.append(refined_metric)
         return result
@@ -148,9 +175,7 @@ def criteria(metric: Metrics) -> bool:
         self.reset_metrics(filered_metrics)
         return filered_metrics

-    def generate_metrics_summary(
-        self, start_time: float, end_time: float
-    ) -> Dict[str, Union[int, float]]:
+    def generate_metrics_summary(self, start_time: float, end_time: float) -> Dict[str, Any]:
         """
         Computes summary statistics across all metrics collected.

@@ -170,16 +195,49 @@ def generate_metrics_summary(
         report : Dict
             A dictionary containing the summary statistics of the collected metrics.
         """
-        import pandas as pd  # pylint: disable=import-outside-toplevel,import-error
-
         if not self.all_metrics:
             return {}

-        metrics = self.all_metrics
-        df = pd.DataFrame([metric.model_dump() for metric in metrics])
+        # Generate the client metrics statistics
+        report = self._compute_metrics_statistics(self.all_metrics)
+        report["num_completed_requests"] = len(self.all_metrics)
+        total_tokens = sum(metric.completion_tokens for metric in self.all_metrics)
+        report["overall_output_throughput"] = total_tokens / (end_time - start_time)
+
+        # Generate the server metrics statistics
+        server_metrics = [
+            metric.server_metrics for metric in self.all_metrics if metric.server_metrics
+        ]
+        server_report = self._compute_metrics_statistics(server_metrics)
+        report["server_metrics"] = server_report
+
+        logger.info("Metrics Summary:\n%s", json.dumps(report, indent=4, default=str))
+        return report
+
+    def _compute_metrics_statistics(self, metrics: List[Union[Metrics, ServerMetrics]]) -> Dict:
+        """
+        Compute the statistics of the metrics.
+
+        Parameters
+        ----------
+        metrics : List[Union[Metrics, ServerMetrics]]
+            The list of metrics to get the statistics.
+
+        Returns
+        -------
+        report : Dict
+            The statistics of the metrics.
+        """
+        import pandas as pd  # pylint: disable=import-outside-toplevel,import-error

         report: Dict = {}
-        for key, _ in Metrics.model_fields.items():
+        if not metrics:
+            return report
+
+        df = pd.DataFrame([metric.model_dump() for metric in metrics])
+        for key, _ in metrics[0].model_fields.items():
+            if key == "server_metrics":
+                continue
             if key in df.columns:
                 series = df[key].dropna()
                 report[key] = {
@@ -192,11 +250,4 @@ def generate_metrics_summary(
                     "max": series.max(),
                     "stddev": series.std(),
                 }
-
-        report["num_completed_requests"] = len(metrics)
-        report["overall_output_throughput"] = df["completion_tokens"].sum() / (
-            end_time - start_time
-        )
-
-        logger.info("Metrics Summary:\n%s", json.dumps(report, indent=4, default=str))
         return report
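
For reference, the aggregation performed by the new `_compute_metrics_statistics` helper can be sketched standalone. The `ServerMetrics` fields below come from the diff above; the sample values and the particular summary statistics printed are illustrative only, not the exact report layout.

```python
# Self-contained sketch of the per-field aggregation done by the new
# _compute_metrics_statistics helper. Field names follow the diff above;
# the sample values and the chosen statistics are illustrative only.
from typing import Dict, List, Optional

import pandas as pd
from pydantic import BaseModel


class ServerMetrics(BaseModel):
    """The metrics from the server side."""

    prompt_tokens: int
    prefill_tokens: int
    completion_tokens: int
    decode_tokens_per_s: float
    prefill_tokens_per_s: float
    end_to_end_latency_s: float
    inter_token_latency_s: float
    ttft_s: Optional[float] = None


def compute_statistics(metrics: List[ServerMetrics]) -> Dict:
    """Build a per-field summary from a list of ServerMetrics records."""
    report: Dict = {}
    if not metrics:
        return report
    df = pd.DataFrame([m.model_dump() for m in metrics])
    for key in ServerMetrics.model_fields:
        series = df[key].dropna()
        report[key] = {
            "mean": series.mean(),
            "min": series.min(),
            "max": series.max(),
            "stddev": series.std(),
        }
    return report


# Made-up sample records, only to exercise the aggregation.
samples = [
    ServerMetrics(
        prompt_tokens=128, prefill_tokens=128, completion_tokens=64,
        decode_tokens_per_s=95.0, prefill_tokens_per_s=900.0,
        end_to_end_latency_s=0.82, inter_token_latency_s=0.0128, ttft_s=0.12,
    ),
    ServerMetrics(
        prompt_tokens=256, prefill_tokens=256, completion_tokens=64,
        decode_tokens_per_s=88.0, prefill_tokens_per_s=850.0,
        end_to_end_latency_s=0.95, inter_token_latency_s=0.0148, ttft_s=0.21,
    ),
]
print(compute_statistics(samples))
```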

python/mlc_llm/bench/prompts.py

Lines changed: 1 addition & 2 deletions

@@ -58,8 +58,7 @@ def __init__(
                     assert "prompt" in json_line, "The prompt field is required in the JSONL file."
                     if "prompt_tokens" not in json_line:
                         json_line["prompt_tokens"] = self._count_tokens(json_line["prompt"])
-                    self.prompts.append(json.loads(line))
-            self.prompts = [json.loads(line) for line in file]
+                    self.prompts.append(json_line)
         else:
             if not prompts_path:
                 prompts_path = Path(__file__).parent / "prompts.txt"  # type: ignore
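
The prompts.py change is a small correctness fix: the already-parsed `json_line`, with its freshly computed `prompt_tokens`, is appended, instead of re-parsing the raw line and later overwriting the whole list. Below is a standalone sketch of the corrected loading loop, using a word-count stand-in for the class's `_count_tokens` helper.

```python
# Minimal sketch of the corrected JSONL loading behavior: each line is parsed
# once, "prompt_tokens" is filled in if missing, and that same dict is kept.
# count_tokens is a hypothetical stand-in for the class's _count_tokens helper.
import json
from typing import Dict, List


def count_tokens(text: str) -> int:  # hypothetical stand-in, not the real tokenizer
    return len(text.split())


def load_prompts(jsonl_path: str) -> List[Dict]:
    prompts: List[Dict] = []
    with open(jsonl_path, "r", encoding="utf-8") as file:
        for line in file:
            json_line = json.loads(line)
            assert "prompt" in json_line, "The prompt field is required in the JSONL file."
            if "prompt_tokens" not in json_line:
                json_line["prompt_tokens"] = count_tokens(json_line["prompt"])
            prompts.append(json_line)
    return prompts
```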

python/mlc_llm/bench/request.py

Lines changed: 73 additions & 38 deletions

@@ -1,9 +1,9 @@
 """MLC LLM Bench Request"""
 import json
+import os
 import time
 from typing import Any, Dict, List, Optional

-import httpx
 from openai import AsyncOpenAI
 from pydantic import BaseModel
 from typing_extensions import Self
@@ -24,9 +24,10 @@ class RequestRecords(BaseModel):
     output: str
     end_to_end_latency_s: float
     ttft: Optional[float] = None
+    server_metrics: Optional[Dict] = None


-class OpenAIRequestSender:
+class OpenAIRequestSender:  # pylint: disable=too-many-instance-attributes
     """
     Manages the sending of requests to a specified API endpoint and gathers inference statistics.

@@ -40,20 +41,27 @@ class OpenAIRequestSender:
         Specifies if streaming should be enabled, default is True.
     timeout : Optional[float]
         The maximum duration in seconds for each request, default is 180.
+    client : Optional[Any]
+        The client to use for sending requests.
+    include_server_metrics : Optional[bool]
+        Specifies if server metrics should be included, default is False.

     Attributes
     ----------
     stats : dict
         Statistics about the performance.
     """

-    def __init__(
+    def __init__(  # pylint: disable=too-many-arguments
         self,
         host: Optional[str] = "127.0.0.1",
         port: Optional[int] = 8008,
         stream: Optional[bool] = None,
         timeout: Optional[float] = None,
+        client: Optional[Any] = None,
+        include_server_metrics: Optional[bool] = False,
     ) -> None:
+        import aiohttp  # pylint: disable=import-outside-toplevel,import-error
         from transformers import (  # pylint: disable=import-outside-toplevel,import-error
             LlamaTokenizerFast,
         )
@@ -63,75 +71,102 @@ def __init__(
         self.tokenizer = LlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer")
         self.prompt_generator = PromptsGenerator()
         self.request_records: List[RequestRecords] = []
-        self.client = AsyncOpenAI(
-            base_url=f"http://{host}:{port}/v1",
-            api_key="None",
-            http_client=httpx.AsyncClient(http2=True),
-        )
+        self.client = client if client else aiohttp.ClientSession()
+        self.include_server_metrics = include_server_metrics
+        self.url = f"http://{host}:{port}/v1/chat/completions"
+        self.headers = {"Content-Type": "application/json"}
+        if os.getenv("MLC_LLM_API_KEY"):
+            self.headers["Authorization"] = f"Bearer {os.getenv('MLC_LLM_API_KEY')}"

     async def __aenter__(self) -> Self:
         return self

     async def __aexit__(self, exc_type, exc_value, traceback) -> None:
         await self.client.close()

-    async def __call__(self, params: Dict[str, Any] = None) -> None:
-        """
-        Send a request to the deployed serving endpoint and collect request records.
-
-        Parameters
-        ----------
-        params : Dict[str, Any]
-            The parameters for the request.
-
-        Returns
-        -------
-        response : Union[Dict, None]
-            The JSON response from the server or None if an error occurs.
-        """
+    async def __call__(  # pylint: disable=too-many-locals, too-many-branches, too-many-statements
+        self, params: Dict[str, Any] = None
+    ) -> None:
         if "messages" not in params:
             prompt_tokens = 128
             if "prompt_tokens" in params:
                 prompt_tokens = params["prompt_tokens"]
             else:
                 logger.warning("A random prompt with %d tokens will be generated.", prompt_tokens)
-
             prompt = self.prompt_generator.generate_prompt(prompt_tokens)
             params["messages"] = [{"role": "system", "content": prompt}]
         else:
-            prompt = params["messages"][0]["content"]
+            prompt = params["messages"][-1]["content"]
         chat_params = self._get_chat_completion_params(params)
         if "stream" not in chat_params:
             chat_params["stream"] = self.stream
         if "timeout" not in chat_params:
             chat_params["timeout"] = self.timeout
+        if self.include_server_metrics:
+            if "stream_options" not in chat_params:
+                chat_params["stream_options"] = {"include_usage": True}
+            else:
+                chat_params["stream_options"]["include_usage"] = True

         total_request_time = 0
         generated_text = ""
         ttft = None
         start_time = time.monotonic()
-        # chat_params["stream_options"] = {"include_usage": True}
-        response = await self.client.chat.completions.create(**chat_params)
-
-        if chat_params["stream"]:
-            async for chunk in response:
-                if chunk.usage:
-                    logger.info(
-                        "Server Metrics:\n%s", json.dumps(chunk.usage.extra, indent=4, default=str)
-                    )
-                elif chunk.choices[0].delta.content is not None:
-                    if not ttft:
-                        ttft = time.monotonic() - start_time  # type: ignore
-                    generated_text += chunk.choices[0].delta.content
+        server_metrics = None
+
+        # AsyncOpenAI chat completion
+        if isinstance(self.client, AsyncOpenAI):
+            response = await self.client.chat.completions.create(**chat_params)
+            if chat_params["stream"]:
+                async for chunk in response:
+                    if chunk.usage:
+                        server_metrics = chunk.usage.extra
+                    elif chunk.choices[0].delta.content is not None:
+                        if not ttft:
+                            ttft = time.monotonic() - start_time  # type: ignore
+                        generated_text += chunk.choices[0].delta.content
+            else:
+                generated_text = response.choices[0].message.content
         else:
-            generated_text = response.choices[0].message.content
+            try:
+                async with self.client.post(
+                    self.url, json=chat_params, headers=self.headers
+                ) as response:
+                    if chat_params["stream"]:
+                        async for chunk in response.content:
+                            chunk = chunk.strip()
+                            if not chunk or chunk == b"\n":
+                                continue
+                            # Get rid of the prefix "data: " and suffix "\n"
+                            raw_data = chunk[6:].strip()
+                            if raw_data == b"[DONE]":
+                                continue
+                            data = json.loads(raw_data)
+                            if data["usage"] is not None:
+                                server_metrics = data["usage"]["extra"]
+                            if not data["choices"]:
+                                continue
+                            delta = data["choices"][0]["delta"]
+                            if delta.get("content", None):
+                                if not ttft:
+                                    ttft = time.monotonic() - start_time
+
+                                generated_text += delta["content"]
+                    else:
+                        data = await response.json()
+                        generated_text = data["choices"][0]["message"]["content"]
+            except Exception as e:  # pylint: disable=broad-except
+                logger.error("Error sending request: %s", str(e))
+                raise e

         total_request_time = time.monotonic() - start_time  # type: ignore
+
         req_rec = RequestRecords(
             input=prompt,
             output=generated_text,
             end_to_end_latency_s=total_request_time,
             ttft=ttft,
+            server_metrics=server_metrics,
         )
         self.request_records.append(req_rec)