
Commit 69520bc

Add logging for cudagraph related info (vllm-project#29825)
Signed-off-by: Yong Hoon Shin <[email protected]>
1 parent 3a77514 commit 69520bc

9 files changed: +161 -6 lines changed

vllm/compilation/cuda_graph.py

Lines changed: 94 additions & 0 deletions
@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 import dataclasses
+from collections import Counter
 from collections.abc import Callable
 from contextlib import ExitStack
 from typing import Any
@@ -22,6 +23,99 @@
 logger = init_logger(__name__)


+@dataclasses.dataclass(frozen=True)
+class CUDAGraphStat:
+    num_unpadded_tokens: int
+    num_padded_tokens: int
+    num_paddings: int
+    runtime_mode: str
+
+
+class CUDAGraphLogging:
+    """Aggregate and log cudagraph metrics"""
+
+    COLUMN_HEADERS = [
+        "Unpadded Tokens",
+        "Padded Tokens",
+        "Num Paddings",
+        "Runtime Mode",
+        "Count",
+    ]
+
+    def __init__(self, cg_mode: CUDAGraphMode, cg_capture_sizes: list[int] | None):
+        self.reset()
+        self.cg_mode = str(cg_mode)
+        self.cg_capture_sizes = str(cg_capture_sizes or [])
+
+        self.settings_header = (
+            "**CUDAGraph Config Settings:**\n\n"
+            f"- Mode: {self.cg_mode}\n"
+            f"- Capture sizes: {self.cg_capture_sizes}\n\n"
+            "**CUDAGraph Stats:**\n\n"
+        )
+
+    def reset(self):
+        self.stats = []
+
+    def observe(self, cudagraph_stat: CUDAGraphStat):
+        self.stats.append(cudagraph_stat)
+
+    def generate_metric_table(self) -> str:
+        stats_counts = Counter(self.stats)
+
+        # Convert stats to rows of strings, in descending order of observed frequencies
+        rows = []
+        for stat, count in sorted(
+            stats_counts.items(), key=lambda item: item[1], reverse=True
+        ):
+            rows.append(
+                [
+                    str(stat.num_unpadded_tokens),
+                    str(stat.num_padded_tokens),
+                    str(stat.num_paddings),
+                    stat.runtime_mode,
+                    str(count),
+                ]
+            )
+
+        # Calculate column widths (max of header and data)
+        col_widths = []
+        for i, header_text in enumerate(self.COLUMN_HEADERS):
+            max_width = len(header_text)
+            for row in rows:
+                max_width = max(max_width, len(row[i]))
+            col_widths.append(max_width)
+
+        table_header_list = [
+            h.ljust(w) for h, w in zip(self.COLUMN_HEADERS, col_widths)
+        ]
+        table_header = "| " + " | ".join(table_header_list) + " |\n"
+
+        table_separator = "|" + "|".join("-" * (w + 2) for w in col_widths) + "|\n"
+
+        # Create data rows with proper alignment
+        data_rows = []
+        for row in rows:
+            formatted_row = [
+                str(val).ljust(width) for val, width in zip(row, col_widths)
+            ]
+            data_rows.append("| " + " | ".join(formatted_row) + " |")
+
+        return (
+            self.settings_header
+            + table_header
+            + table_separator
+            + "\n".join(data_rows)
+            + "\n"
+        )
+
+    def log(self, log_fn=logger.info):
+        if not self.stats:
+            return
+        log_fn(self.generate_metric_table())
+        self.reset()
+
+
 @dataclasses.dataclass
 class CUDAGraphEntry:
     batch_descriptor: BatchDescriptor
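
For reference, the new `CUDAGraphStat` / `CUDAGraphLogging` pair can be exercised on its own. The sketch below is illustrative only: the first constructor argument is normally a `CUDAGraphMode` enum (the constructor just stringifies it, so a placeholder string is used here), and the mode names, capture sizes, and token counts are made-up values.

```python
from vllm.compilation.cuda_graph import CUDAGraphLogging, CUDAGraphStat

# Placeholder mode string and capture sizes; in the metrics logger these come from
# compilation_config.cudagraph_mode and compilation_config.cudagraph_capture_sizes.
cg_logging = CUDAGraphLogging(cg_mode="FULL", cg_capture_sizes=[1, 2, 4, 8])

# One CUDAGraphStat per forward pass; identical (frozen) stats are tallied together.
for _ in range(3):
    cg_logging.observe(
        CUDAGraphStat(
            num_unpadded_tokens=3, num_padded_tokens=4, num_paddings=1, runtime_mode="FULL"
        )
    )
cg_logging.observe(
    CUDAGraphStat(
        num_unpadded_tokens=7, num_padded_tokens=8, num_paddings=1, runtime_mode="FULL"
    )
)

# Prints the settings header followed by a Markdown-style table,
# rows sorted by observed frequency (most frequent first).
print(cg_logging.generate_metric_table())
```

In the server this table is emitted through `CUDAGraphLogging.log`, which also resets the accumulated stats after each logging interval.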

vllm/config/observability.py

Lines changed: 4 additions & 0 deletions
@@ -55,6 +55,10 @@ def show_hidden_metrics(self) -> bool:
     kv_cache_metrics_sample: float = Field(default=0.01, gt=0, le=1)
     """Sampling rate for KV cache metrics (0.0, 1.0]. Default 0.01 = 1% of blocks."""

+    cudagraph_metrics: bool = False
+    """Enable CUDA graph metrics (number of padded/unpadded tokens, runtime cudagraph
+    dispatch modes, and their observed frequencies at every logging interval)."""
+
     @cached_property
     def collect_model_forward_time(self) -> bool:
         """Whether to collect model forward time for the request."""

vllm/engine/arg_utils.py

Lines changed: 6 additions & 0 deletions
@@ -518,6 +518,7 @@ class EngineArgs:
     kv_cache_metrics_sample: float = get_field(
         ObservabilityConfig, "kv_cache_metrics_sample"
     )
+    cudagraph_metrics: bool = ObservabilityConfig.cudagraph_metrics
     scheduling_policy: SchedulerPolicy = SchedulerConfig.policy
     scheduler_cls: str | type[object] | None = SchedulerConfig.scheduler_cls

@@ -1021,6 +1022,10 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
             "--kv-cache-metrics-sample",
             **observability_kwargs["kv_cache_metrics_sample"],
         )
+        observability_group.add_argument(
+            "--cudagraph-metrics",
+            **observability_kwargs["cudagraph_metrics"],
+        )

         # Scheduler arguments
         scheduler_kwargs = get_kwargs(SchedulerConfig)
@@ -1698,6 +1703,7 @@ def create_engine_config(
             collect_detailed_traces=self.collect_detailed_traces,
             kv_cache_metrics=self.kv_cache_metrics,
             kv_cache_metrics_sample=self.kv_cache_metrics_sample,
+            cudagraph_metrics=self.cudagraph_metrics,
         )

         # Compilation config overrides
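
With this plumbing the metrics stay opt-in. A hedged sketch of enabling them from Python is below; the model name is a placeholder, and on the command line the new `--cudagraph-metrics` flag added above serves the same purpose (its exact boolean syntax follows vLLM's generated argparse kwargs).

```python
from vllm.engine.arg_utils import EngineArgs

engine_args = EngineArgs(
    model="facebook/opt-125m",   # placeholder model name
    cudagraph_metrics=True,      # new EngineArgs field, wired into ObservabilityConfig
)
```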

vllm/v1/core/sched/scheduler.py

Lines changed: 7 additions & 1 deletion
@@ -7,6 +7,7 @@
 from typing import Any

 from vllm import envs
+from vllm.compilation.cuda_graph import CUDAGraphStat
 from vllm.config import VllmConfig
 from vllm.distributed.ec_transfer.ec_connector.base import (
     ECConnectorMetadata,
@@ -1037,6 +1038,7 @@ def update_from_output(
         pooler_outputs = model_runner_output.pooler_output
         num_nans_in_logits = model_runner_output.num_nans_in_logits
         kv_connector_output = model_runner_output.kv_connector_output
+        cudagraph_stats = model_runner_output.cudagraph_stats

         outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list)
         spec_decoding_stats: SpecDecodingStats | None = None
@@ -1219,7 +1221,9 @@ def update_from_output(
         finished_req_ids.clear()

         if (
-            stats := self.make_stats(spec_decoding_stats, kv_connector_stats)
+            stats := self.make_stats(
+                spec_decoding_stats, kv_connector_stats, cudagraph_stats
+            )
         ) is not None:
             # Return stats to only one of the front-ends.
             if (eco := next(iter(engine_core_outputs.values()), None)) is None:
@@ -1420,6 +1424,7 @@ def make_stats(
         self,
         spec_decoding_stats: SpecDecodingStats | None = None,
         kv_connector_stats: KVConnectorStats | None = None,
+        cudagraph_stats: CUDAGraphStat | None = None,
     ) -> SchedulerStats | None:
         if not self.log_stats:
             return None
@@ -1444,6 +1449,7 @@ def make_stats(
             kv_cache_eviction_events=eviction_events,
             spec_decoding_stats=spec_stats,
             kv_connector_stats=connector_stats_payload,
+            cudagraph_stats=cudagraph_stats,
         )

     def make_spec_decoding_stats(

vllm/v1/metrics/loggers.py

Lines changed: 14 additions & 0 deletions
@@ -10,6 +10,7 @@
 from prometheus_client import Counter, Gauge, Histogram

 import vllm.envs as envs
+from vllm.compilation.cuda_graph import CUDAGraphLogging
 from vllm.config import SupportsMetricsInfo, VllmConfig
 from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
     KVConnectorLogging,
@@ -106,6 +107,12 @@ def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
         self.spec_decoding_logging = SpecDecodingLogging()
         kv_transfer_config = self.vllm_config.kv_transfer_config
         self.kv_connector_logging = KVConnectorLogging(kv_transfer_config)
+        self.cudagraph_logging = None
+        if self.vllm_config.observability_config.cudagraph_metrics:
+            self.cudagraph_logging = CUDAGraphLogging(
+                self.vllm_config.compilation_config.cudagraph_mode,
+                self.vllm_config.compilation_config.cudagraph_capture_sizes,
+            )
         self.last_prompt_throughput: float = 0.0
         self.last_generation_throughput: float = 0.0
         self.engine_is_idle = False
@@ -161,6 +168,11 @@ def record(
             self.spec_decoding_logging.observe(scheduler_stats.spec_decoding_stats)
         if kv_connector_stats := scheduler_stats.kv_connector_stats:
             self.kv_connector_logging.observe(kv_connector_stats)
+        if (
+            self.cudagraph_logging is not None
+            and scheduler_stats.cudagraph_stats is not None
+        ):
+            self.cudagraph_logging.observe(scheduler_stats.cudagraph_stats)
         if not self.aggregated:
             self.last_scheduler_stats = scheduler_stats
         if mm_cache_stats:
@@ -240,6 +252,8 @@ def log(self):

         self.spec_decoding_logging.log(log_fn=log_fn)
         self.kv_connector_logging.log(log_fn=log_fn)
+        if self.cudagraph_logging is not None:
+            self.cudagraph_logging.log(log_fn=log_fn)

     def log_engine_initialized(self):
         if self.vllm_config.cache_config.num_gpu_blocks:

vllm/v1/metrics/stats.py

Lines changed: 3 additions & 0 deletions
@@ -7,6 +7,7 @@
 from typing import TYPE_CHECKING, Any

 import vllm.envs as envs
+from vllm.compilation.cuda_graph import CUDAGraphStat
 from vllm.v1.spec_decode.metrics import SpecDecodingStats

 if TYPE_CHECKING:
@@ -183,6 +184,8 @@ class SchedulerStats:
     waiting_lora_adapters: dict[str, int] = field(default_factory=dict)
     running_lora_adapters: dict[str, int] = field(default_factory=dict)

+    cudagraph_stats: CUDAGraphStat | None = None
+

 @dataclass
 class RequestStateStats:

vllm/v1/outputs.py

Lines changed: 4 additions & 0 deletions
@@ -8,6 +8,7 @@
 import numpy as np
 import torch

+from vllm.compilation.cuda_graph import CUDAGraphStat
 from vllm.v1.core.sched.output import SchedulerOutput

 if TYPE_CHECKING:
@@ -169,6 +170,9 @@
     # req_id -> num_nans_in_logits
     num_nans_in_logits: dict[str, int] | None = None

+    # information related to cudagraph execution
+    cudagraph_stats: CUDAGraphStat | None = None
+

 # ModelRunnerOutput wrapper for async scheduling.
 class AsyncModelRunnerOutput(ABC):

vllm/v1/worker/gpu_model_runner.py

Lines changed: 28 additions & 4 deletions
@@ -27,7 +27,7 @@
 )
 from vllm.attention.layer import Attention, MLAAttention
 from vllm.compilation.counter import compilation_counter
-from vllm.compilation.cuda_graph import CUDAGraphWrapper
+from vllm.compilation.cuda_graph import CUDAGraphStat, CUDAGraphWrapper
 from vllm.compilation.monitor import set_cudagraph_capturing_enabled
 from vllm.config import (
     CompilationMode,
@@ -257,6 +257,7 @@ class ExecuteModelState(NamedTuple):
     sample_hidden_states: torch.Tensor
     aux_hidden_states: list[torch.Tensor] | None
     ec_connector_output: ECConnectorOutput | None
+    cudagraph_stats: CUDAGraphStat | None


 class GPUModelRunner(
@@ -2755,7 +2756,11 @@ def _determine_batch_execution_and_padding(
         force_uniform_decode: bool | None = None,
         force_has_lora: bool | None = None,
     ) -> tuple[
-        CUDAGraphMode, BatchDescriptor, UBatchSlices | None, torch.Tensor | None
+        CUDAGraphMode,
+        BatchDescriptor,
+        UBatchSlices | None,
+        torch.Tensor | None,
+        CUDAGraphStat | None,
     ]:
         num_tokens_padded = self._pad_for_sequence_parallelism(num_tokens)
         uniform_decode = (
@@ -2820,7 +2825,22 @@ def _determine_batch_execution_and_padding(
             # num_tokens_across_dp will no-longer be valid
             assert batch_descriptor.num_tokens == num_tokens_padded

-        return cudagraph_mode, batch_descriptor, ubatch_slices, num_tokens_across_dp
+        cudagraph_stats = None
+        if self.vllm_config.observability_config.cudagraph_metrics:
+            cudagraph_stats = CUDAGraphStat(
+                num_unpadded_tokens=num_tokens,
+                num_padded_tokens=batch_descriptor.num_tokens,
+                num_paddings=batch_descriptor.num_tokens - num_tokens,
+                runtime_mode=str(cudagraph_mode),
+            )
+
+        return (
+            cudagraph_mode,
+            batch_descriptor,
+            ubatch_slices,
+            num_tokens_across_dp,
+            cudagraph_stats,
+        )

     @torch.inference_mode()
     def execute_model(
@@ -2918,6 +2938,7 @@ def execute_model(
             batch_desc,
             ubatch_slices,
             num_tokens_across_dp,
+            cudagraph_stats,
         ) = self._determine_batch_execution_and_padding(
             num_tokens=num_tokens_unpadded,
             num_reqs=num_reqs,
@@ -3067,6 +3088,7 @@ def execute_model(
                 sample_hidden_states,
                 aux_hidden_states,
                 ec_connector_output,
+                cudagraph_stats,
             )
             self.kv_connector_output = kv_connector_output
             return None
@@ -3102,6 +3124,7 @@ def sample_tokens(
             sample_hidden_states,
             aux_hidden_states,
             ec_connector_output,
+            cudagraph_stats,
         ) = self.execute_model_state
         # Clear ephemeral state.
         self.execute_model_state = None
@@ -3217,6 +3240,7 @@ def propose_draft_token_ids(sampled_token_ids):
             if self.supports_mm_inputs
             else None,
             num_nans_in_logits=num_nans_in_logits,
+            cudagraph_stats=cudagraph_stats,
         )

         if not self.use_async_scheduling:
@@ -3937,7 +3961,7 @@ def _dummy_run(

         num_sampled_tokens = np.ones(num_reqs, dtype=np.int32)

-        _cudagraph_mode, batch_desc, ubatch_slices, num_tokens_across_dp = (
+        _cudagraph_mode, batch_desc, ubatch_slices, num_tokens_across_dp, _ = (
             self._determine_batch_execution_and_padding(
                 num_tokens=num_tokens_unpadded,
                 num_reqs=num_reqs,
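
To make the new return value concrete, here is an illustrative stat as `_determine_batch_execution_and_padding` would now build it when a hypothetical 13-token batch is padded up to a 16-token capture size; all numbers are made up, and the runtime mode string comes from `str(cudagraph_mode)` at runtime.

```python
from vllm.compilation.cuda_graph import CUDAGraphStat

stat = CUDAGraphStat(
    num_unpadded_tokens=13,   # num_tokens before padding
    num_padded_tokens=16,     # batch_descriptor.num_tokens after padding
    num_paddings=16 - 13,     # padded minus unpadded, as computed in the diff above
    runtime_mode="FULL",      # placeholder value for str(cudagraph_mode)
)
```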

vllm/v1/worker/gpu_worker.py

Lines changed: 1 addition & 1 deletion
@@ -564,7 +564,7 @@ def execute_model(
         # TODO(lucas): This is pretty gross; ideally we should only ever call
         # `_determine_batch_execution_and_padding` once (will get called again
         # in `execute_model`) but this requires a larger refactor of PP.
-        _, batch_desc, _, _ = (
+        _, batch_desc, _, _, _ = (
             self.model_runner._determine_batch_execution_and_padding(
                 num_tokens=num_scheduled_tokens,
                 num_reqs=len(num_scheduled_tokens_np),
