Skip to content

Commit 5cb6890

Browse files
committed
Support a dynamic default max_tokens for the vLLM backend
Signed-off-by: bin <[email protected]>
1 parent aebf168 commit 5cb6890

File tree

2 files changed

+51
-8
lines changed

2 files changed

+51
-8
lines changed

components/src/dynamo/vllm/handlers.py

Lines changed: 45 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,9 @@
2929

3030

3131
def build_sampling_params(
32-
request: Dict[str, Any], default_sampling_params: Dict[str, Any]
32+
request: Dict[str, Any],
33+
default_sampling_params: Dict[str, Any],
34+
model_max_len: int | None = None,
3335
) -> SamplingParams:
3436
"""
3537
Build SamplingParams from a PreprocessedRequest.
@@ -57,6 +59,18 @@ def build_sampling_params(
5759
continue
5860
setattr(sampling_params, key, value)
5961

62+
# If max_tokens wasn't provided (None or missing), compute a dynamic default
63+
try:
64+
provided_max_tokens = request.get("stop_conditions", {}).get("max_tokens", None)
65+
token_ids = request.get("token_ids", [])
66+
input_length = len(token_ids)
67+
if model_max_len is not None and (provided_max_tokens is None):
68+
# Clamp to 1 so the default always allows at least one generated token
69+
dynamic_default = max(1, model_max_len - input_length)
70+
sampling_params.max_tokens = dynamic_default
71+
except Exception:
72+
pass
73+
6074
return sampling_params
6175

6276

@@ -65,14 +79,22 @@ class BaseWorkerHandler(ABC):
6579
Request handler for the generate and clear_kv_blocks endpoints.
6680
"""
6781

68-
def __init__(self, runtime, component, engine, default_sampling_params):
82+
def __init__(
83+
self,
84+
runtime,
85+
component,
86+
engine,
87+
default_sampling_params,
88+
model_max_len: int | None = None,
89+
):
6990
self.runtime = runtime
7091
self.component = component
7192
self.engine_client = engine
7293
self.default_sampling_params = default_sampling_params
7394
self.kv_publishers: list[ZmqKvEventPublisher] | None = None
7495
self.engine_monitor = VllmEngineMonitor(runtime, engine)
7596
self.image_loader = ImageLoader()
97+
self.model_max_len = model_max_len
7698

7799
@abstractmethod
78100
async def generate(self, request, context) -> AsyncGenerator[dict, None]:
@@ -212,8 +234,11 @@ def __init__(
212234
component,
213235
engine,
214236
default_sampling_params,
237+
model_max_len: int | None = None,
215238
):
216-
super().__init__(runtime, component, engine, default_sampling_params)
239+
super().__init__(
240+
runtime, component, engine, default_sampling_params, model_max_len
241+
)
217242

218243
async def generate(self, request, context):
219244
# Use context ID for request tracking and correlation
@@ -228,7 +253,9 @@ async def generate(self, request, context):
228253
)
229254

230255
# Build sampling params from request
231-
sampling_params = build_sampling_params(request, self.default_sampling_params)
256+
sampling_params = build_sampling_params(
257+
request, self.default_sampling_params, self.model_max_len
258+
)
232259

233260
# Extract disaggregated_params from request (set by prefill router in Rust frontend)
234261
disaggregated_params = request.get("disaggregated_params")
@@ -259,8 +286,17 @@ async def generate(self, request, context):
259286

260287

261288
class PrefillWorkerHandler(BaseWorkerHandler):
262-
def __init__(self, runtime, component, engine, default_sampling_params):
263-
super().__init__(runtime, component, engine, default_sampling_params)
289+
def __init__(
290+
self,
291+
runtime,
292+
component,
293+
engine,
294+
default_sampling_params,
295+
model_max_len: int | None = None,
296+
):
297+
super().__init__(
298+
runtime, component, engine, default_sampling_params, model_max_len
299+
)
264300

265301
async def generate(self, request, context):
266302
# Use context ID for request tracking and correlation with decode phase
@@ -276,7 +312,9 @@ async def generate(self, request, context):
276312
)
277313

278314
# Build sampling params from request using shared utility
279-
sampling_params = build_sampling_params(request, self.default_sampling_params)
315+
sampling_params = build_sampling_params(
316+
request, self.default_sampling_params, self.model_max_len
317+
)
280318

281319
# Configure for prefill-only mode with remote decode
282320
if sampling_params.extra_args is None:

components/src/dynamo/vllm/main.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,11 @@ async def init_prefill(runtime: DistributedRuntime, config: Config):
317317
engine_client, vllm_config, default_sampling_params = setup_vllm_engine(config)
318318

319319
handler = PrefillWorkerHandler(
320-
runtime, component, engine_client, default_sampling_params
320+
runtime,
321+
component,
322+
engine_client,
323+
default_sampling_params,
324+
getattr(getattr(vllm_config, "model_config", None), "max_model_len", None),
321325
)
322326

323327
# Check if kv event consolidator is enabled (port was allocated in setup_vllm_engine)
@@ -424,6 +428,7 @@ async def init(runtime: DistributedRuntime, config: Config):
424428
component,
425429
engine_client,
426430
default_sampling_params,
431+
getattr(getattr(vllm_config, "model_config", None), "max_model_len", None),
427432
)
428433

429434
# Check if kv event consolidator is enabled (port was allocated in setup_vllm_engine)

0 commit comments

Comments
 (0)