|
| 1 | +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. |
| 2 | + |
| 3 | +""" |
| 4 | +Batch dimensions utilities. |
| 5 | +
|
| 6 | +This module contains utilities for managing batch dimensions, |
| 7 | +including the InferenceBatchDimensions dataclass and CUDAGraphBatchDimensionBuilder for generating |
| 8 | +and matching CUDA graph batch dimensions. |
| 9 | +""" |
| 10 | + |
| 11 | +import math |
| 12 | +from dataclasses import dataclass |
| 13 | +from typing import List, Optional, Tuple |
| 14 | + |
| 15 | + |
@dataclass(order=True, frozen=True)
class InferenceBatchDimensions:
    """Immutable description of a dynamic-inference batch.

    Attributes:
        token_count : number of total input tokens
        prefill_req_count : number of prefill requests
        decode_req_count : number of decode requests

    Because of ``order=True``, instances compare field-by-field: first by
    token_count, then prefill_req_count, then decode_req_count.
    """

    token_count: int = 0
    prefill_req_count: int = 0
    decode_req_count: int = 0

    def _as_tuple(self) -> Tuple[int, int, int]:
        """Return the field values in declaration order (for hashing/equality)."""
        return (self.token_count, self.prefill_req_count, self.decode_req_count)

    def __str__(self):
        """Return a compact human-readable summary of the batch dimensions."""
        return f"[{self.token_count}]: {self.prefill_req_count} P + {self.decode_req_count} D"

    def is_applicable_for_batch_dim(
        self, real_batch_dim: "InferenceBatchDimensions", strict: bool = False
    ) -> bool:
        """Check whether this (graph) batch dimension can serve ``real_batch_dim``.

        A graph batch dimension is applicable when its token and request
        budgets cover the real batch. When ``strict`` is False, unused
        prefill slots may additionally absorb decode requests; when True,
        prefill slots only count toward prefill requests.
        """
        enough_tokens = self.token_count >= real_batch_dim.token_count
        if real_batch_dim.prefill_req_count == 0:
            # A pure-decode batch must be served by a decode-only dimension.
            return (
                enough_tokens
                and self.prefill_req_count == 0
                and self.decode_req_count >= real_batch_dim.decode_req_count
            )
        enough_prefill = self.prefill_req_count >= real_batch_dim.prefill_req_count
        if strict:
            return (
                enough_tokens
                and enough_prefill
                and self.decode_req_count >= real_batch_dim.decode_req_count
            )
        # Non-strict: spare prefill slots may also hold decode requests,
        # so only the combined request budget has to fit.
        return enough_tokens and enough_prefill and self.req_count >= real_batch_dim.req_count

    def is_valid(self, max_requests: int, max_sequence_length: int) -> bool:
        """Check whether this batch dimension fits the resource constraints.

        Args:
            max_requests: Maximum number of requests allowed
            max_sequence_length: Maximum sequence length of a single request

        Returns:
            True if the dimensions are feasible, False otherwise
        """
        total_reqs = self.prefill_req_count + self.decode_req_count
        # Request counts must be non-negative and within the request budget.
        if self.prefill_req_count < 0 or self.decode_req_count < 0:
            return False
        if total_reqs > max_requests:
            return False
        # Every request contributes at least one token.
        if self.token_count < total_reqs:
            return False
        # Prefill requests contribute at most max_sequence_length tokens each;
        # decode requests contribute exactly one token each.
        if self.token_count > self.prefill_req_count * max_sequence_length + self.decode_req_count:
            return False
        return True

    def __hash__(self):
        """Hash on the field tuple so instances can key cuda-graph lookup dicts."""
        return hash(self._as_tuple())

    def __eq__(self, other: "InferenceBatchDimensions") -> bool:
        """Field-wise equality; ``None`` always compares unequal."""
        if other is None:
            return False
        return self._as_tuple() == (
            other.token_count,
            other.prefill_req_count,
            other.decode_req_count,
        )

    @property
    def req_count(self) -> int:
        """Total number of requests (prefill + decode)."""
        return self.prefill_req_count + self.decode_req_count
| 126 | + |
class CUDAGraphBatchDimensionBuilder:
    """Builder for creating and managing CUDA graph batch dimensions.

    This class provides static methods for generating lists of CUDA graph batch dimensions
    and matching the best batch dimension for a given real batch dimension.
    """

    # Constant for rounding token counts when generating CUDA graph batch dimensions
    CUDA_GRAPH_ROUNDER = 8

    @staticmethod
    def generate_cuda_graph_batch_dimensions_list(
        tp_size: int,
        num_cuda_graphs: Optional[int],
        cuda_graph_max_tokens: int,
        cuda_graph_mixed_prefill_count: Optional[int],
        max_requests: int,
        max_tokens: int,
        max_sequence_length: int,
        use_cuda_graphs_for_non_decode_steps: bool,
    ) -> Tuple[List[InferenceBatchDimensions], Optional[List[int]]]:
        """
        Generate CUDA graph batch dimensions.

        This function constructs CUDA graph batch dimensions for different token counts
        and request patterns, then filters them based on resource constraints.
        The construction process involves:

        Construction Rules:
        1. Token count generation: Creates token counts from step_size to max_tokens,
           rounded to multiples of 8
        2. Tensor parallelism alignment: Ensures step_size is divisible by tensor parallel size
        3. Batch dimension creation: For each token count, creates three types of batch dimensions:
           - Decode-only: (token_count, 0, token_count) - all tokens used for decode requests
           - Mixed prefill+decode: (token_count, prefill_req_count, token_count - prefill_req_count)
           - Prefill-only:
             (token_count, max(prefill_req_count, ceil(token_count/(max_seq_len-1))), 0)

        Filtering Rules:
        1. Request limit: prefill_req_count + decode_req_count <= max_requests
        2. Non-negative counts: Both prefill_req_count and decode_req_count must be >= 0
        3. Token sufficiency: token_count >= prefill_req_count + decode_req_count

        Sorting Rules for Attention Metadata Construction:
        1. Batch dimensions are sorted by prefill token count (token_count - decode_req_count)
           in descending order

        Args:
            tp_size: Tensor parallel size
            num_cuda_graphs: Number of CUDA graphs to generate (None disables generation)
            cuda_graph_max_tokens: Maximum tokens for CUDA graphs
            cuda_graph_mixed_prefill_count: Number of mixed prefill requests for CUDA graphs
            max_requests: Maximum number of requests
            max_tokens: Maximum total tokens
            max_sequence_length: Maximum sequence length
            use_cuda_graphs_for_non_decode_steps: Whether to use CUDA graphs for non-decode steps

        Returns:
            Tuple containing:
            - List of InferenceBatchDimensions objects,
              sorted by prefill token count in descending order
            - Optional list of CUDA graph token counts (None when num_cuda_graphs is None)
        """

        def add_if_valid(token_count: int, prefill_req_count: int, decode_req_count: int) -> None:
            """Helper to create and append batch dimension to list only if it's valid."""
            # Invalid combinations (negative counts, over-budget requests, token
            # shortfall/overflow) are silently dropped by is_valid().
            batch_dim = InferenceBatchDimensions(token_count, prefill_req_count, decode_req_count)
            if batch_dim.is_valid(max_requests, max_sequence_length):
                cuda_graph_batch_dimensions_list.append(batch_dim)

        # Cuda graph token-counts
        # (i.e., token counts used by cuda-graph steps, both decode and non-decode).
        cuda_graph_token_counts = None
        if num_cuda_graphs is not None:

            # Ensure valid num_cuda_graphs.
            # Fall back to max_tokens when cuda_graph_max_tokens is unset or out of range.
            if (
                cuda_graph_max_tokens is None
                or cuda_graph_max_tokens > max_tokens
                or cuda_graph_max_tokens <= 0
            ):
                cuda_graph_max_tokens = max_tokens
            # Clamp to [1, cuda_graph_max_tokens]: at least one graph, at most one per token.
            num_cuda_graphs = min(max(num_cuda_graphs, 1), cuda_graph_max_tokens)

            # Cuda graph step size.
            cuda_graph_step_size = cuda_graph_max_tokens / num_cuda_graphs
            # Round the (truncated) step size up to a multiple of CUDA_GRAPH_ROUNDER.
            cuda_graph_step_size = CUDAGraphBatchDimensionBuilder.CUDA_GRAPH_ROUNDER * int(
                math.ceil(
                    int(cuda_graph_step_size) / CUDAGraphBatchDimensionBuilder.CUDA_GRAPH_ROUNDER
                )
            )
            # Make sure divisible by TP size
            cuda_graph_step_size = math.ceil(cuda_graph_step_size / tp_size) * tp_size

            # Cuda graph token counts.
            if num_cuda_graphs == 1:
                cuda_graph_token_counts = [cuda_graph_max_tokens]
            else:
                cuda_graph_token_counts = list(
                    range(cuda_graph_step_size, cuda_graph_max_tokens, cuda_graph_step_size)
                )
                # Always include the max-token graph so the largest batches can be served.
                if (
                    len(cuda_graph_token_counts) == 0
                    or cuda_graph_token_counts[-1] != cuda_graph_max_tokens
                ):
                    cuda_graph_token_counts.append(cuda_graph_max_tokens)
            # Largest token count first.
            cuda_graph_token_counts.reverse()

        cuda_graph_batch_dimensions_list = []
        if num_cuda_graphs is None:
            # CUDA graphs disabled: return an empty dimension list.
            cuda_graph_batch_dimensions_list = []
        elif (
            not cuda_graph_mixed_prefill_count
            or cuda_graph_mixed_prefill_count <= 0
            or not use_cuda_graphs_for_non_decode_steps
        ):  # decode only
            for size in cuda_graph_token_counts:
                # Decode-only graph: one token per decode request, capped by request budget.
                add_if_valid(
                    token_count=min(size, max_requests),
                    prefill_req_count=0,
                    decode_req_count=min(size, max_requests),
                )
        else:
            for size in cuda_graph_token_counts:
                # Decode-only variant.
                add_if_valid(
                    token_count=min(size, max_requests),
                    prefill_req_count=0,
                    decode_req_count=min(size, max_requests),
                )
                # Mixed variant: reserve cuda_graph_mixed_prefill_count prefill slots,
                # remaining request budget goes to decode requests.
                add_if_valid(
                    token_count=size,
                    prefill_req_count=min(cuda_graph_mixed_prefill_count, max_requests),
                    decode_req_count=min(size, max_requests)
                    - min(cuda_graph_mixed_prefill_count, max_requests),
                )
                # We need to ensure the prefill requests are shorter than the max sequence length,
                # considering the one decode token is used for prefill request construction
                prefill_only_minimal_num = max(
                    cuda_graph_mixed_prefill_count,
                    math.ceil(size / max(1, max_sequence_length - 1)),
                )
                if prefill_only_minimal_num < max_requests:
                    # Prefill-only variant.
                    add_if_valid(
                        token_count=size,
                        prefill_req_count=max(prefill_only_minimal_num, min(max_requests, size)),
                        decode_req_count=0,
                    )

        # Remove duplicates and sort by prefill token count
        cuda_graph_batch_dimensions_list = list(set(cuda_graph_batch_dimensions_list))
        cuda_graph_batch_dimensions_list.sort(
            key=lambda x: ((x.token_count - x.decode_req_count), x.decode_req_count), reverse=True
        )

        return cuda_graph_batch_dimensions_list, cuda_graph_token_counts

    @staticmethod
    def match_graph_config(
        real_batch_dim: InferenceBatchDimensions,
        cuda_graph_batch_dimensions_list: List[InferenceBatchDimensions],
        strict: bool = False,
    ) -> Optional[InferenceBatchDimensions]:
        """
        Matches the best CUDA graph batch dimension for the given real batch dimension.

        Args:
            real_batch_dim: The real batch dimension to match
            cuda_graph_batch_dimensions_list: List of available CUDA graph batch dimensions
            strict: If False, prefill slots can be used for prefill or decode requests.
                If True, prefill slots can only be used for prefill requests.

        Returns:
            The best matching CUDA graph batch dimension, or None if no applicable match is found
        """
        # first filter out batch dimensions with smaller token count, prefill req count,
        # or decode req count, as they are not applicable
        graph_batch_dims_applicable = [
            graph_batch_dim
            for graph_batch_dim in cuda_graph_batch_dimensions_list
            if graph_batch_dim.is_applicable_for_batch_dim(real_batch_dim, strict=strict)
        ]
        if len(graph_batch_dims_applicable) == 0:
            return None
        # then find the best batch dimension
        # min() uses the dataclass field ordering (token_count first), so the
        # smallest applicable graph is picked to minimize padding waste.
        best_batch_dim = min(graph_batch_dims_applicable)
        return best_batch_dim
0 commit comments