19 | 19 | from collections import defaultdict |
20 | 20 | from contextlib import AbstractContextManager, contextmanager, nullcontext |
21 | 21 | from functools import partial |
22 | | -from typing import Any, Iterator, Optional, TypeVar, cast |
| 22 | +from typing import Any, Iterator, Optional, TypedDict, TypeVar, cast |
23 | 23 |
|
24 | 24 | import ray |
25 | 25 | import torch |
|
145 | 145 | TokenizerType = TypeVar("TokenizerType", bound=PreTrainedTokenizerBase) |
146 | 146 |
|
147 | 147 |
|
| 148 | +class MegatronGenerationConfig(TypedDict): |
| 149 | + # Total GPU memory (in GB) allocated for KV cache buffers |
| 150 | + buffer_size_gb: int |
| 151 | + # Fraction of buffer reserved for guaranteed active requests |
| 152 | + buffer_guaranteed_fraction: float |
| 153 | + # Number of CUDA graphs to pre-compile for different batch sizes |
| 154 | + num_cuda_graphs: int |
| 155 | + # Size of each KV cache block in tokens (affects memory granularity) |
| 156 | + block_size_tokens: int |
| 157 | + # Enable CUDA graphs for prefill/context processing |
| 158 | + use_cuda_graphs_for_non_decode_steps: bool |
| 159 | + # Split long prefills into chunks for better memory management |
| 160 | + enable_chunked_prefill: bool |
| 161 | + # Unified memory usage level (0=disabled, higher values enable more aggressive paging) |
| 162 | + unified_memory_level: int |
| 163 | + # Maximum number of tokens to process in a single step. Analogous to vLLM's max_num_batched_tokens. |
| 164 | + # Setting this too high can cause OOM, so tune it together with buffer_size_gb if OOMs occur. Setting it |
| 165 | + # too low means only 512 tokens are processed at a time, which can be slow. |
| 166 | + max_tokens: int |
| 167 | + |
| 168 | + |
148 | 169 | def broadcast_object_across_pp_ranks(obj): |
149 | 170 | """Broadcast an object across pipeline parallel ranks. |
150 | 171 |
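Since every field of MegatronGenerationConfig is now required (the `.get()` fallbacks are removed in the hunk below), a config that reproduces the old behaviour has to spell out each key. The snippet below is illustrative only: the key names come from the TypedDict above, the values are the defaults the removed `.get()` calls applied, and the layout of the surrounding config file is not shown in this diff.

```python
# Illustrative sketch, not an officially documented default config:
# these are simply the values generate() previously fell back to via .get().
mcore_generation_config: MegatronGenerationConfig = {
    "buffer_size_gb": 20,
    "buffer_guaranteed_fraction": 0.1,
    "num_cuda_graphs": 16,
    "block_size_tokens": 256,
    "use_cuda_graphs_for_non_decode_steps": True,
    "enable_chunked_prefill": True,
    "unified_memory_level": 0,
    "max_tokens": 16384,
}
```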
|
@@ -1820,22 +1841,22 @@ def generate( |
1820 | 1841 | ) |
1821 | 1842 | from megatron.core.inference.sampling_params import SamplingParams |
1822 | 1843 |
|
1823 | | - mcore_generation_config = self.cfg["generation"]["mcore_generation_config"] |
1824 | | - buffer_size_gb = mcore_generation_config.get("buffer_size_gb", 20) |
1825 | | - |
1826 | | - num_cuda_graphs = mcore_generation_config.get("num_cuda_graphs", 16) |
1827 | | - block_size_tokens = mcore_generation_config.get("block_size_tokens", 256) |
1828 | | - use_cuda_graphs_for_non_decode_steps = mcore_generation_config.get( |
1829 | | - "use_cuda_graphs_for_non_decode_steps", True |
1830 | | - ) |
1831 | | - enable_chunked_prefill = mcore_generation_config.get( |
1832 | | - "enable_chunked_prefill", True |
| 1844 | + mcore_generation_config = cast( |
| 1845 | + MegatronGenerationConfig, self.cfg["generation"]["mcore_generation_config"] |
1833 | 1846 | ) |
1834 | | - unified_memory_level = mcore_generation_config.get("unified_memory_level", 0) |
1835 | | - buffer_guaranteed_fraction = mcore_generation_config.get( |
1836 | | - "buffer_guaranteed_fraction", 0.1 |
1837 | | - ) |
1838 | | - max_tokens = mcore_generation_config.get("max_tokens", 16384) |
| 1847 | + buffer_size_gb = mcore_generation_config["buffer_size_gb"] |
| 1848 | + |
| 1849 | + num_cuda_graphs = mcore_generation_config["num_cuda_graphs"] |
| 1850 | + block_size_tokens = mcore_generation_config["block_size_tokens"] |
| 1851 | + use_cuda_graphs_for_non_decode_steps = mcore_generation_config[ |
| 1852 | + "use_cuda_graphs_for_non_decode_steps" |
| 1853 | + ] |
| 1854 | + enable_chunked_prefill = mcore_generation_config["enable_chunked_prefill"] |
| 1855 | + unified_memory_level = mcore_generation_config["unified_memory_level"] |
| 1856 | + buffer_guaranteed_fraction = mcore_generation_config[ |
| 1857 | + "buffer_guaranteed_fraction" |
| 1858 | + ] |
| 1859 | + max_tokens = mcore_generation_config["max_tokens"] |
1839 | 1860 |
|
1840 | 1861 | model_config = self.model.config |
1841 | 1862 | model_config.cuda_graph_impl = "local" |
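One behavioural consequence of switching from `.get()` with inline defaults to `cast()` plus direct indexing: `cast()` performs no runtime checking, so a key missing from mcore_generation_config now surfaces as a KeyError inside generate() instead of silently picking up a default. If fail-fast validation is wanted, a minimal sketch (hypothetical helper, not part of this change, reusing the names from the diff above) could look like:

```python
def validate_mcore_generation_config(raw: dict) -> MegatronGenerationConfig:
    """Hypothetical helper: reject configs with missing keys up front.

    typing.cast() only informs the type checker, so without a check like
    this a missing key raises KeyError deep inside generate().
    """
    missing = MegatronGenerationConfig.__required_keys__ - raw.keys()
    if missing:
        raise ValueError(f"mcore_generation_config is missing keys: {sorted(missing)}")
    return cast(MegatronGenerationConfig, raw)
```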
|