Skip to content

Commit 0634924

Browse files
authored
implement graph config (#2203)
1 parent 5ab6392 commit 0634924

File tree

11 files changed

+882
-377
lines changed

11 files changed

+882
-377
lines changed

examples/inference/gpt/gpt_dynamic_inference.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
from mamba_builders import mamba_builder
5656

5757
from megatron.core.utils import configure_nvtx_profiling
58+
import logging
5859

5960
import json
6061

@@ -196,6 +197,8 @@ def get_inference_context(
196197
use_cuda_graphs_for_non_decode_steps=not args.decode_only_cuda_graphs,
197198
use_flashinfer_fused_rope=args.use_flashinfer_fused_rope,
198199
unified_memory_level=args.inference_dynamic_batching_unified_memory_level,
200+
cuda_graph_max_tokens=args.inference_dynamic_batching_cuda_graph_max_tokens,
201+
cuda_graph_mixed_prefill_count=args.inference_dynamic_batching_cuda_graph_mixed_prefill_count,
199202
metrics_writer=metrics_writer,
200203
)
201204

@@ -278,7 +281,7 @@ def run_inference(
278281
total_output_tokens = 0
279282
attempted_step_count = 0
280283
if args.cuda_graph_impl == "local":
281-
cuda_graph_request_count_map = {r:0 for r in engine.context.cuda_graph_request_counts}
284+
cuda_graph_request_count_map = {}
282285
else:
283286
cuda_graph_request_count_map = None
284287

@@ -354,7 +357,7 @@ def _add_request():
354357
# Record cuda_graph_request_count.
355358
cuda_graph_request_count = result["cuda_graph_request_count"]
356359
if args.cuda_graph_impl == "local" and cuda_graph_request_count is not None:
357-
cuda_graph_request_count_map[cuda_graph_request_count] += 1
360+
cuda_graph_request_count_map[cuda_graph_request_count] = cuda_graph_request_count_map.get(cuda_graph_request_count, 0) + 1
358361

359362
# Update requests.
360363
active_request_ids = result["active_request_ids"]
@@ -421,6 +424,10 @@ def main():
421424
if os.environ.get("NSIGHT_PREFIX"):
422425
torch.cuda.cudart().cudaProfilerStart()
423426

427+
level_str = os.getenv("LOG_LEVEL", "INFO").upper()
428+
level = getattr(logging, level_str, logging.INFO)
429+
logging.basicConfig(level=level, force=True)
430+
424431
configure_nvtx_profiling(True)
425432

426433
args = get_args()

examples/inference/gpt/utils.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -379,12 +379,7 @@ def build_dynamic_engine_setup_prefix(
379379
"""
380380
# CUDA graph config
381381
if args.cuda_graph_impl == "local":
382-
cg_str = (
383-
"graphs "
384-
f"[{len(context.cuda_graph_token_counts)}] "
385-
f"{context.cuda_graph_token_counts[0]}:"
386-
f"{context.cuda_graph_token_counts[-1]}"
387-
)
382+
cg_str = f"graphs {len(context.cuda_graph_batch_dimensions_list)}"
388383
else:
389384
cg_str = "--"
390385

Lines changed: 312 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,312 @@
1+
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2+
3+
"""
4+
Batch dimensions utilities.
5+
6+
This module contains utilities for managing batch dimensions,
7+
including the InferenceBatchDimensions dataclass and CUDAGraphBatchDimensionBuilder for generating
8+
and matching CUDA graph batch dimensions.
9+
"""
10+
11+
import math
12+
from dataclasses import dataclass
13+
from typing import List, Optional, Tuple
14+
15+
16+
@dataclass(order=True, frozen=True)
class InferenceBatchDimensions:
    """Batch dimensions for dynamic inference.

    Attributes:
        token_count: Number of total input tokens.
        prefill_req_count: Number of prefill requests.
        decode_req_count: Number of decode requests.

    Instances are ordered by token_count, then prefill_req_count, then
    decode_req_count (via ``order=True``), and are hashable so they can be
    used as dictionary keys for CUDA graph quick matching.
    """

    token_count: int = 0
    prefill_req_count: int = 0
    decode_req_count: int = 0

    def __str__(self) -> str:
        """Return a compact human-readable form, e.g. ``"[8]: 2 P + 4 D"``."""
        return f"[{self.token_count}]: {self.prefill_req_count} P + {self.decode_req_count} D"

    def is_applicable_for_batch_dim(
        self, real_batch_dim: "InferenceBatchDimensions", strict: bool = False
    ) -> bool:
        """Check whether this (graph) batch dimension can serve ``real_batch_dim``.

        A graph batch dimension is applicable when it has enough token and
        request budget to handle the real batch dimensions.

        Args:
            real_batch_dim: The actual batch dimensions to be served.
            strict: If False, prefill slots can be used for prefill or decode
                requests. If True, prefill slots can only be used for prefill
                requests.

        Returns:
            True if this batch dimension can cover ``real_batch_dim``.
        """
        if real_batch_dim.prefill_req_count == 0:
            # Decode-only batches must be matched by decode-only graphs to
            # keep the decode-only property.
            return (
                self.token_count >= real_batch_dim.token_count
                and self.decode_req_count >= real_batch_dim.decode_req_count
                and self.prefill_req_count == 0
            )
        if strict:
            return (
                self.token_count >= real_batch_dim.token_count
                and self.prefill_req_count >= real_batch_dim.prefill_req_count
                and self.decode_req_count >= real_batch_dim.decode_req_count
            )
        else:
            # Non-strict: spare prefill slots may absorb decode requests, so
            # only the combined request budget has to fit.
            return (
                self.token_count >= real_batch_dim.token_count
                and self.prefill_req_count >= real_batch_dim.prefill_req_count
                and self.prefill_req_count + self.decode_req_count
                >= real_batch_dim.prefill_req_count + real_batch_dim.decode_req_count
            )

    def is_valid(self, max_requests: int, max_sequence_length: int) -> bool:
        """Check if the batch dimension is valid based on resource constraints.

        Args:
            max_requests: Maximum number of requests allowed.
            max_sequence_length: Maximum sequence length of a single request.

        Returns:
            True if the config is valid, False otherwise.
        """
        # Check if total requests exceed maximum.
        if self.prefill_req_count + self.decode_req_count > max_requests:
            return False

        # Check for negative request counts.
        if self.prefill_req_count < 0 or self.decode_req_count < 0:
            return False

        # Every request needs at least one token.
        if self.token_count < self.prefill_req_count + self.decode_req_count:
            return False

        # Prefill requests must fit within the max sequence length
        # (each decode request contributes exactly one token).
        if self.token_count > self.prefill_req_count * max_sequence_length + self.decode_req_count:
            return False

        return True

    def __hash__(self):
        """Hash over the field tuple; used for dict keys in quick matching."""
        return hash((self.token_count, self.prefill_req_count, self.decode_req_count))

    def __eq__(self, other: object) -> bool:
        """Field-wise equality with other InferenceBatchDimensions instances.

        Returning NotImplemented for foreign types (instead of reaching for
        ``other.token_count``, which raised AttributeError for arbitrary
        objects) lets Python fall back to its default handling, so both
        ``dim == None`` and ``dim == 5`` evaluate to False.
        """
        if not isinstance(other, InferenceBatchDimensions):
            return NotImplemented
        return (self.token_count, self.prefill_req_count, self.decode_req_count) == (
            other.token_count,
            other.prefill_req_count,
            other.decode_req_count,
        )

    @property
    def req_count(self) -> int:
        """Total number of requests (prefill + decode)."""
        return self.prefill_req_count + self.decode_req_count
125+
126+
127+
class CUDAGraphBatchDimensionBuilder:
    """Builder for creating and managing CUDA graph batch dimensions.

    This class provides static methods for generating lists of CUDA graph batch dimensions
    and matching the best batch dimension for a given real batch dimension.
    """

    # Constant for rounding token counts when generating CUDA graph batch dimensions.
    CUDA_GRAPH_ROUNDER = 8

    @staticmethod
    def generate_cuda_graph_batch_dimensions_list(
        tp_size: int,
        num_cuda_graphs: Optional[int],
        cuda_graph_max_tokens: int,
        cuda_graph_mixed_prefill_count: Optional[int],
        max_requests: int,
        max_tokens: int,
        max_sequence_length: int,
        use_cuda_graphs_for_non_decode_steps: bool,
    ) -> Tuple[List[InferenceBatchDimensions], Optional[List[int]]]:
        """
        Generate CUDA graph batch dimensions.

        This function constructs CUDA graph batch dimensions for different token counts
        and request patterns, then filters them based on resource constraints.
        The construction process involves:

        Construction Rules:
        1. Token count generation: Creates token counts from step_size to max_tokens,
           rounded to multiples of 8
        2. Tensor parallelism alignment: Ensures step_size is divisible by tensor parallel size
        3. Batch dimension creation: For each token count, creates three types of batch dimensions:
           - Decode-only: (token_count, 0, token_count) - all tokens used for decode requests
           - Mixed prefill+decode: (token_count, prefill_req_count, token_count - prefill_req_count)
           - Prefill-only:
             (token_count, max(prefill_req_count, ceil(token_count/(max_seq_len-1))), 0)

        Filtering Rules:
        1. Request limit: prefill_req_count + decode_req_count <= max_requests
        2. Non-negative counts: Both prefill_req_count and decode_req_count must be >= 0
        3. Token sufficiency: token_count >= prefill_req_count + decode_req_count

        Sorting Rules for Attention Metadata Construction:
        1. Batch dimensions are sorted by prefill token count (token_count - decode_req_count)
           in descending order

        Args:
            tp_size: Tensor parallel size
            num_cuda_graphs: Number of CUDA graphs to generate; None disables graph generation
            cuda_graph_max_tokens: Maximum tokens for CUDA graphs
            cuda_graph_mixed_prefill_count: Number of mixed prefill requests for CUDA graphs
            max_requests: Maximum number of requests
            max_tokens: Maximum total tokens
            max_sequence_length: Maximum sequence length
            use_cuda_graphs_for_non_decode_steps: Whether to use CUDA graphs for non-decode steps

        Returns:
            Tuple containing:
            - List of InferenceBatchDimensions objects,
              sorted by prefill token count in descending order
            - Optional list of CUDA graph token counts
        """

        def add_if_valid(token_count: int, prefill_req_count: int, decode_req_count: int) -> None:
            """Helper to create and append batch dimension to list only if it's valid."""
            batch_dim = InferenceBatchDimensions(token_count, prefill_req_count, decode_req_count)
            if batch_dim.is_valid(max_requests, max_sequence_length):
                cuda_graph_batch_dimensions_list.append(batch_dim)

        # Cuda graph token-counts
        # (i.e., token counts used by cuda-graph steps, both decode and non-decode).
        cuda_graph_token_counts = None
        if num_cuda_graphs is not None:

            # Ensure valid num_cuda_graphs: fall back to max_tokens when the
            # requested graph token budget is unset, non-positive, or too large,
            # then clamp the graph count to [1, cuda_graph_max_tokens].
            if (
                cuda_graph_max_tokens is None
                or cuda_graph_max_tokens > max_tokens
                or cuda_graph_max_tokens <= 0
            ):
                cuda_graph_max_tokens = max_tokens
            num_cuda_graphs = min(max(num_cuda_graphs, 1), cuda_graph_max_tokens)

            # Cuda graph step size, rounded up to a multiple of CUDA_GRAPH_ROUNDER (8).
            cuda_graph_step_size = cuda_graph_max_tokens / num_cuda_graphs
            cuda_graph_step_size = CUDAGraphBatchDimensionBuilder.CUDA_GRAPH_ROUNDER * int(
                math.ceil(
                    int(cuda_graph_step_size) / CUDAGraphBatchDimensionBuilder.CUDA_GRAPH_ROUNDER
                )
            )
            # Make sure divisible by TP size.
            cuda_graph_step_size = math.ceil(cuda_graph_step_size / tp_size) * tp_size

            # Cuda graph token counts: multiples of the step size up to the max,
            # with the max itself always included, in descending order.
            if num_cuda_graphs == 1:
                cuda_graph_token_counts = [cuda_graph_max_tokens]
            else:
                cuda_graph_token_counts = list(
                    range(cuda_graph_step_size, cuda_graph_max_tokens, cuda_graph_step_size)
                )
                if (
                    len(cuda_graph_token_counts) == 0
                    or cuda_graph_token_counts[-1] != cuda_graph_max_tokens
                ):
                    cuda_graph_token_counts.append(cuda_graph_max_tokens)
                cuda_graph_token_counts.reverse()

        cuda_graph_batch_dimensions_list = []
        if num_cuda_graphs is None:
            # Graphs disabled: return an empty batch dimension list.
            cuda_graph_batch_dimensions_list = []
        elif (
            not cuda_graph_mixed_prefill_count
            or cuda_graph_mixed_prefill_count <= 0
            or not use_cuda_graphs_for_non_decode_steps
        ):  # decode only
            for size in cuda_graph_token_counts:
                add_if_valid(
                    token_count=min(size, max_requests),
                    prefill_req_count=0,
                    decode_req_count=min(size, max_requests),
                )
        else:
            # Mixed mode: per token count, emit decode-only, mixed
            # prefill+decode, and (when feasible) prefill-only dimensions.
            for size in cuda_graph_token_counts:
                add_if_valid(
                    token_count=min(size, max_requests),
                    prefill_req_count=0,
                    decode_req_count=min(size, max_requests),
                )
                add_if_valid(
                    token_count=size,
                    prefill_req_count=min(cuda_graph_mixed_prefill_count, max_requests),
                    decode_req_count=min(size, max_requests)
                    - min(cuda_graph_mixed_prefill_count, max_requests),
                )
                # We need to ensure the prefill requests are shorter than the max sequence length,
                # considering the one decode token is used for prefill request construction.
                prefill_only_minimal_num = max(
                    cuda_graph_mixed_prefill_count,
                    math.ceil(size / max(1, max_sequence_length - 1)),
                )
                if prefill_only_minimal_num < max_requests:
                    add_if_valid(
                        token_count=size,
                        prefill_req_count=max(prefill_only_minimal_num, min(max_requests, size)),
                        decode_req_count=0,
                    )

        # Remove duplicates and sort by prefill token count (descending), as
        # documented above for attention metadata construction; decode request
        # count breaks ties.
        cuda_graph_batch_dimensions_list = list(set(cuda_graph_batch_dimensions_list))
        cuda_graph_batch_dimensions_list.sort(
            key=lambda x: ((x.token_count - x.decode_req_count), x.decode_req_count), reverse=True
        )

        return cuda_graph_batch_dimensions_list, cuda_graph_token_counts

    @staticmethod
    def match_graph_config(
        real_batch_dim: InferenceBatchDimensions,
        cuda_graph_batch_dimensions_list: List[InferenceBatchDimensions],
        strict: bool = False,
    ) -> Optional[InferenceBatchDimensions]:
        """
        Matches the best CUDA graph batch dimension for the given real batch dimension.

        Args:
            real_batch_dim: The real batch dimension to match
            cuda_graph_batch_dimensions_list: List of available CUDA graph batch dimensions
            strict: If False, prefill slots can be used for prefill or decode requests.
                If True, prefill slots can only be used for prefill requests.

        Returns:
            The best matching CUDA graph batch dimension, or None if no applicable match is found
        """
        # first filter out batch dimensions with smaller token count, prefill req count,
        # or decode req count, as they are not applicable
        graph_batch_dims_applicable = [
            graph_batch_dim
            for graph_batch_dim in cuda_graph_batch_dimensions_list
            if graph_batch_dim.is_applicable_for_batch_dim(real_batch_dim, strict=strict)
        ]
        if len(graph_batch_dims_applicable) == 0:
            return None
        # then find the best batch dimension: min() uses the dataclass field
        # ordering (token_count first), i.e. the smallest graph that still fits.
        best_batch_dim = min(graph_batch_dims_applicable)
        return best_batch_dim

0 commit comments

Comments
 (0)