@@ -26,6 +26,7 @@
 import argparse
 import logging
 from collections.abc import Iterable
+from typing import Optional
 
 import apache_beam as beam
 from apache_beam.ml.inference.base import PredictionResult
@@ -37,6 +38,12 @@
 from apache_beam.options.pipeline_options import SetupOptions
 from apache_beam.runners.runner import PipelineResult
 
+# Defaults avoid CUDA OOM on ~16GB GPUs (e.g. NVIDIA T4) with vLLM V1: the engine
+# warms the sampler with many dummy sequences unless max_num_seqs is reduced, and
+# the default gpu_memory_utilization can leave no free VRAM for that step.
+_DEFAULT_VLLM_MAX_NUM_SEQS = 32
+_DEFAULT_VLLM_GPU_MEMORY_UTILIZATION = 0.72
+
 COMPLETION_EXAMPLES = [
     "Hello, my name is",
     "The president of the United States is",
@@ -112,33 +119,72 @@ def parse_known_args(argv):
       required=False,
       default=None,
       help='Chat template to use for chat example.')
+  parser.add_argument(
+      '--vllm_max_num_seqs',
+      dest='vllm_max_num_seqs',
+      type=int,
+      default=_DEFAULT_VLLM_MAX_NUM_SEQS,
+      help=(
+          'Passed to the vLLM OpenAI server as --max-num-seqs. '
+          'Lower values use less GPU memory during startup and inference; '
+          'required for many ~16GB GPUs (see --vllm_gpu_memory_utilization).'))
+  parser.add_argument(
+      '--vllm_gpu_memory_utilization',
+      dest='vllm_gpu_memory_utilization',
+      type=float,
+      default=_DEFAULT_VLLM_GPU_MEMORY_UTILIZATION,
+      help=(
+          'Passed to the vLLM OpenAI server as --gpu-memory-utilization '
+          '(fraction of total GPU memory for KV cache). Lower this if the '
+          'engine fails to start with CUDA out of memory.'))
   return parser.parse_known_args(argv)
 
 
+def build_vllm_server_kwargs(known_args) -> dict[str, str]:
+  """Returns CLI flags for ``VLLMCompletionsModelHandler(..., vllm_server_kwargs=...)``."""
+  return {
+      'max-num-seqs': str(known_args.vllm_max_num_seqs),
+      'gpu-memory-utilization': str(known_args.vllm_gpu_memory_utilization),
+  }
+
+
 class PostProcessor(beam.DoFn):
   def process(self, element: PredictionResult) -> Iterable[str]:
     yield str(element.example) + ": " + str(element.inference)
 
 
 def run(
-    argv=None, save_main_session=True, test_pipeline=None) -> PipelineResult:
+    argv=None,
+    save_main_session=True,
+    test_pipeline=None,
+    vllm_server_kwargs: Optional[dict[str, str]] = None) -> PipelineResult:
   """
   Args:
     argv: Command line arguments defined for this example.
     save_main_session: Used for internal testing.
     test_pipeline: Used for internal testing.
+    vllm_server_kwargs: Optional override for vLLM server options. When None,
+      options are taken from argv (``--vllm_max_num_seqs``,
+      ``--vllm_gpu_memory_utilization``). When set, argv tuning flags for the
+      server are ignored in favor of this dict (e.g. for programmatic use).
   """
   known_args, pipeline_args = parse_known_args(argv)
   pipeline_options = PipelineOptions(pipeline_args)
   pipeline_options.view_as(SetupOptions).save_main_session = save_main_session
 
-  model_handler = VLLMCompletionsModelHandler(model_name=known_args.model)
+  effective_vllm_kwargs = (
+      vllm_server_kwargs if vllm_server_kwargs is not None else
+      build_vllm_server_kwargs(known_args))
+
+  model_handler = VLLMCompletionsModelHandler(
+      model_name=known_args.model, vllm_server_kwargs=effective_vllm_kwargs)
   input_examples = COMPLETION_EXAMPLES
 
   if known_args.chat:
     model_handler = VLLMChatModelHandler(
         model_name=known_args.model,
-        chat_template_path=known_args.chat_template)
+        chat_template_path=known_args.chat_template,
+        vllm_server_kwargs=dict(effective_vllm_kwargs))
     input_examples = CHAT_EXAMPLES
 
   pipeline = test_pipeline
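
A quick, self-contained check of the mapping the new helper performs. The Namespace below stands in for the result of parse_known_args (only the two tuning fields matter here), and build_vllm_server_kwargs is the function added in this patch; the values are stringified because the handler forwards them to the vLLM OpenAI server's command line:

    from argparse import Namespace

    # Stand-in for parse_known_args output, using the patch's default values.
    args = Namespace(vllm_max_num_seqs=32, vllm_gpu_memory_utilization=0.72)

    # Keys are the vLLM server's own flag names; values are strings.
    assert build_vllm_server_kwargs(args) == {
        'max-num-seqs': '32',
        'gpu-memory-utilization': '0.72',
    }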
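And a sketch of the programmatic override path described in run()'s docstring. The module name vllm_text_completion is an assumption (the diff does not show the file path), and the argv shown is illustrative; the dict keys are again the vLLM server's flag names:

    import vllm_text_completion  # hypothetical module name for this example file

    # When vllm_server_kwargs is passed explicitly, the --vllm_* tuning flags
    # in argv are ignored in favor of this dict.
    result = vllm_text_completion.run(
        argv=['--model', 'facebook/opt-125m'],
        vllm_server_kwargs={
            'max-num-seqs': '16',             # tighter than the default 32
            'gpu-memory-utilization': '0.6',  # tighter than the default 0.72
        })
    result.wait_until_finish()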