This repository was archived by the owner on Sep 4, 2025. It is now read-only.

Commit dd53c4b

[misc] Add Torch profiler support (vllm-project#7451)
Co-authored-by: Cody Yu <[email protected]>
1 parent 970dfdc commit dd53c4b


12 files changed, +191 -2 lines changed

benchmarks/backend_request_func.py

Lines changed: 2 additions & 2 deletions
@@ -225,8 +225,8 @@ async def async_request_openai_completions(
 ) -> RequestFuncOutput:
     api_url = request_func_input.api_url
     assert api_url.endswith(
-        "completions"
-    ), "OpenAI Completions API URL must end with 'completions'."
+        ("completions", "profile")
+    ), "OpenAI Completions API URL must end with 'completions' or 'profile'."
 
     async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
         assert not request_func_input.use_beam_search
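
Note: the relaxed assertion above works because str.endswith accepts a tuple of suffixes, which is what lets the same completions request helper also be pointed at the new profiler endpoints. A quick illustration in plain Python (not part of this diff; the URL is made up):

# str.endswith accepts a tuple and returns True if any suffix matches,
# so both ".../v1/completions" and ".../stop_profile" pass the check.
api_url = "http://localhost:8000/stop_profile"  # hypothetical URL
assert api_url.endswith(("completions", "profile"))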

benchmarks/benchmark_serving.py

Lines changed: 43 additions & 0 deletions
@@ -295,13 +295,15 @@ def calculate_metrics(
 async def benchmark(
     backend: str,
     api_url: str,
+    base_url: str,
     model_id: str,
     tokenizer: PreTrainedTokenizerBase,
     input_requests: List[Tuple[str, int, int]],
     best_of: int,
     use_beam_search: bool,
     request_rate: float,
     disable_tqdm: bool,
+    profile: bool,
 ):
     if backend in ASYNC_REQUEST_FUNCS:
         request_func = ASYNC_REQUEST_FUNCS[backend]
@@ -326,6 +328,22 @@ async def benchmark(
             f"are correctly specified. Error: {test_output.error}")
     else:
         print("Initial test run completed. Starting main benchmark run...")
+
+    if profile:
+        print("Starting profiler...")
+        profile_input = RequestFuncInput(
+            model=model_id,
+            prompt=test_prompt,
+            api_url=base_url + "/start_profile",
+            prompt_len=test_prompt_len,
+            output_len=test_output_len,
+            best_of=best_of,
+            use_beam_search=use_beam_search,
+        )
+        profile_output = await request_func(request_func_input=profile_input)
+        if profile_output.success:
+            print("Profiler started")
+
     print(f"Traffic request rate: {request_rate}")
 
     pbar = None if disable_tqdm else tqdm(total=len(input_requests))
@@ -349,6 +367,21 @@ async def benchmark(
                              pbar=pbar)))
     outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks)
 
+    if profile:
+        print("Stopping profiler...")
+        profile_input = RequestFuncInput(
+            model=model_id,
+            prompt=test_prompt,
+            api_url=base_url + "/stop_profile",
+            prompt_len=test_prompt_len,
+            output_len=test_output_len,
+            best_of=best_of,
+            use_beam_search=use_beam_search,
+        )
+        profile_output = await request_func(request_func_input=profile_input)
+        if profile_output.success:
+            print("Profiler stopped")
+
     if pbar is not None:
         pbar.close()
 
@@ -433,8 +466,10 @@ def main(args: argparse.Namespace):
 
     if args.base_url is not None:
         api_url = f"{args.base_url}{args.endpoint}"
+        base_url = f"{args.base_url}"
     else:
         api_url = f"http://{args.host}:{args.port}{args.endpoint}"
+        base_url = f"http://{args.host}:{args.port}"
 
     tokenizer = get_tokenizer(tokenizer_id,
                               trust_remote_code=args.trust_remote_code)
@@ -506,13 +541,15 @@ def main(args: argparse.Namespace):
         benchmark(
             backend=backend,
             api_url=api_url,
+            base_url=base_url,
             model_id=model_id,
             tokenizer=tokenizer,
             input_requests=input_requests,
             best_of=args.best_of,
             use_beam_search=args.use_beam_search,
             request_rate=args.request_rate,
            disable_tqdm=args.disable_tqdm,
+            profile=args.profile,
         ))
 
     # Save config and results to json
@@ -693,6 +730,12 @@ def main(args: argparse.Namespace):
         action="store_true",
         help="Specify to disable tqdm progress bar.",
     )
+    parser.add_argument(
+        "--profile",
+        action="store_true",
+        help="Use Torch Profiler. The endpoint must be launched with "
+        "VLLM_TORCH_PROFILER_DIR to enable profiler.",
+    )
     parser.add_argument(
         "--save-result",
         action="store_true",

docs/source/dev/profiling/profiling_index.rst

Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
+Profiling vLLM
+=================================
+
+We support tracing vLLM workers using the ``torch.profiler`` module. You can enable tracing by setting the ``VLLM_TORCH_PROFILER_DIR`` environment variable to the directory where you want to save the traces: ``VLLM_TORCH_PROFILER_DIR=/mnt/traces/``
+
+The OpenAI server also needs to be started with the ``VLLM_TORCH_PROFILER_DIR`` environment variable set.
+
+When using ``benchmarks/benchmark_serving.py``, you can enable profiling by passing the ``--profile`` flag.
+
+.. warning::
+
+    Only enable profiling in a development environment.
+
+
+Traces can be visualized using https://ui.perfetto.dev/.
+
+.. tip::
+
+    Only send a few requests through vLLM when profiling, as the traces can get quite large. Also, there is no need to untar the traces; they can be viewed directly.
+
+Example commands:
+
+OpenAI Server:
+
+.. code-block:: bash
+
+    VLLM_TORCH_PROFILER_DIR=/mnt/traces/ python -m vllm.entrypoints.openai.api_server --model meta-llama/Meta-Llama-3-70B
+
+benchmark_serving.py:
+
+.. code-block:: bash
+
+    python benchmarks/benchmark_serving.py --backend vllm --model meta-llama/Meta-Llama-3-70B --dataset-name sharegpt --dataset-path sharegpt.json --profile --num-prompts 2

docs/source/index.rst

Lines changed: 1 addition & 0 deletions
@@ -136,6 +136,7 @@ Documentation
    dev/input_processing/model_inputs_index
    dev/multimodal/multimodal_index
    dev/dockerfile/dockerfile
+   dev/profiling/profiling_index
 
 .. toctree::
    :maxdepth: 1

vllm/engine/async_llm_engine.py

Lines changed: 6 additions & 0 deletions
@@ -1266,3 +1266,9 @@ def remove_logger(self, logger_name: str) -> None:
                     logger_name=logger_name))
         else:
             self.engine.remove_logger(logger_name=logger_name)
+
+    async def start_profile(self) -> None:
+        self.engine.model_executor._run_workers("start_profile")
+
+    async def stop_profile(self) -> None:
+        self.engine.model_executor._run_workers("stop_profile")

vllm/engine/protocol.py

Lines changed: 8 additions & 0 deletions
@@ -91,3 +91,11 @@ async def do_log_stats(
     async def check_health(self) -> None:
         """Raise if unhealthy"""
         ...
+
+    async def start_profile(self) -> None:
+        """Start profiling the engine"""
+        ...
+
+    async def stop_profile(self) -> None:
+        """Stop profiling the engine"""
+        ...

vllm/entrypoints/openai/api_server.py

Lines changed: 20 additions & 0 deletions
@@ -305,6 +305,26 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
     assert_never(generator)
 
 
+if envs.VLLM_TORCH_PROFILER_DIR:
+    logger.warning(
+        "Torch Profiler is enabled in the API server. This should ONLY be "
+        "used for local development!")
+
+    @router.post("/start_profile")
+    async def start_profile():
+        logger.info("Starting profiler...")
+        await async_engine_client.start_profile()
+        logger.info("Profiler started.")
+        return Response(status_code=200)
+
+    @router.post("/stop_profile")
+    async def stop_profile():
+        logger.info("Stopping profiler...")
+        await async_engine_client.stop_profile()
+        logger.info("Profiler stopped.")
+        return Response(status_code=200)
+
+
 def build_app(args: Namespace) -> FastAPI:
     app = FastAPI(lifespan=lifespan)
     app.include_router(router)

vllm/entrypoints/openai/rpc/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -46,6 +46,8 @@ class RPCUtilityRequest(Enum):
     DO_LOG_STATS = 7
     IS_SERVER_HEALTHY = 8
     IS_TRACING_ENABLED = 9
+    START_PROFILE = 10
+    STOP_PROFILE = 11
 
 
 RPC_REQUEST_TYPE = Union[RPCGenerateRequest, RPCAbortRequest,

vllm/entrypoints/openai/rpc/client.py

Lines changed: 14 additions & 0 deletions
@@ -400,3 +400,17 @@ async def encode(self, *args,
                      **kwargs) -> AsyncGenerator[EmbeddingRequestOutput, None]:
         raise NotImplementedError(
             "Embeddings not supported with multiprocessing backend")
+
+    async def start_profile(self) -> None:
+        """Start profiling the engine"""
+
+        await self._send_one_way_rpc_request(
+            request=RPCUtilityRequest.START_PROFILE,
+            error_message="RPCRequest START_PROFILE failed.")
+
+    async def stop_profile(self) -> None:
+        """Stop profiling the engine"""
+
+        await self._send_one_way_rpc_request(
+            request=RPCUtilityRequest.STOP_PROFILE,
+            error_message="RPCRequest STOP_PROFILE failed.")

vllm/entrypoints/openai/rpc/server.py

Lines changed: 24 additions & 0 deletions
@@ -124,6 +124,26 @@ async def check_health(self, identity):
         except Exception as e:
             await self.socket.send_multipart([identity, cloudpickle.dumps(e)])
 
+    async def start_profile(self, identity):
+        logger.info("Starting profiler...")
+        await self.engine.start_profile()
+        logger.info("Profiler started.")
+
+        await self.socket.send_multipart([
+            identity,
+            cloudpickle.dumps(VLLM_RPC_SUCCESS_STR),
+        ])
+
+    async def stop_profile(self, identity):
+        logger.info("Stopping profiler...")
+        await self.engine.stop_profile()
+        logger.info("Profiler stopped.")
+
+        await self.socket.send_multipart([
+            identity,
+            cloudpickle.dumps(VLLM_RPC_SUCCESS_STR),
+        ])
+
     def _make_handler_coro(self, identity,
                            message) -> Coroutine[Any, Any, Never]:
         """Route the zmq message to the handler coroutine."""
@@ -153,6 +173,10 @@ def _make_handler_coro(self, identity,
                 return self.check_health(identity)
             elif request == RPCUtilityRequest.IS_TRACING_ENABLED:
                 return self.is_tracing_enabled(identity)
+            elif request == RPCUtilityRequest.START_PROFILE:
+                return self.start_profile(identity)
+            elif request == RPCUtilityRequest.STOP_PROFILE:
+                return self.stop_profile(identity)
             else:
                 raise ValueError(f"Unknown RPCUtilityRequest type: {request}")
 
