Skip to content

Commit 0b2e0cb

Browse files
committed
Support a dynamic default max_tokens for the vLLM backend
Signed-off-by: bin <[email protected]>
1 parent aebf168 commit 0b2e0cb

File tree

2 files changed: +31 additions, -8 deletions

components/src/dynamo/vllm/handlers.py

Lines changed: 25 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -29,7 +29,7 @@
2929

3030

3131
def build_sampling_params(
32-
request: Dict[str, Any], default_sampling_params: Dict[str, Any]
32+
request: Dict[str, Any], default_sampling_params: Dict[str, Any], model_max_len: int | None = None,
3333
) -> SamplingParams:
3434
"""
3535
Build SamplingParams from a PreprocessedRequest.
@@ -56,6 +56,21 @@ def build_sampling_params(
5656
if key == "stop":
5757
continue
5858
setattr(sampling_params, key, value)
59+
60+
# If max_tokens wasn't provided (None or missing), derive a dynamic default from the model's max length minus the prompt length
61+
try:
62+
provided_max_tokens = request.get("stop_conditions", {}).get("max_tokens", None)
63+
token_ids = request.get("token_ids", [])
64+
input_length = len(token_ids)
65+
if (
66+
model_max_len is not None
67+
and (provided_max_tokens is None)
68+
):
69+
# Clamp to at least 1 so the default stays positive even when the prompt already fills model_max_len
70+
dynamic_default = max(1, model_max_len - input_length)
71+
sampling_params.max_tokens = dynamic_default
72+
except Exception:
73+
pass
5974

6075
return sampling_params
6176

@@ -65,14 +80,16 @@ class BaseWorkerHandler(ABC):
6580
Request handler for the generate and clear_kv_blocks endpoints.
6681
"""
6782

68-
def __init__(self, runtime, component, engine, default_sampling_params):
83+
def __init__(self, runtime, component, engine, default_sampling_params, model_max_len: int | None = None):
6984
self.runtime = runtime
7085
self.component = component
7186
self.engine_client = engine
7287
self.default_sampling_params = default_sampling_params
7388
self.kv_publishers: list[ZmqKvEventPublisher] | None = None
7489
self.engine_monitor = VllmEngineMonitor(runtime, engine)
7590
self.image_loader = ImageLoader()
91+
self.model_max_len = model_max_len
92+
7693

7794
@abstractmethod
7895
async def generate(self, request, context) -> AsyncGenerator[dict, None]:
@@ -212,8 +229,9 @@ def __init__(
212229
component,
213230
engine,
214231
default_sampling_params,
232+
model_max_len: int | None = None,
215233
):
216-
super().__init__(runtime, component, engine, default_sampling_params)
234+
super().__init__(runtime, component, engine, default_sampling_params, model_max_len)
217235

218236
async def generate(self, request, context):
219237
# Use context ID for request tracking and correlation
@@ -228,7 +246,7 @@ async def generate(self, request, context):
228246
)
229247

230248
# Build sampling params from request
231-
sampling_params = build_sampling_params(request, self.default_sampling_params)
249+
sampling_params = build_sampling_params(request, self.default_sampling_params, self.model_max_len)
232250

233251
# Extract disaggregated_params from request (set by prefill router in Rust frontend)
234252
disaggregated_params = request.get("disaggregated_params")
@@ -259,8 +277,8 @@ async def generate(self, request, context):
259277

260278

261279
class PrefillWorkerHandler(BaseWorkerHandler):
262-
def __init__(self, runtime, component, engine, default_sampling_params):
263-
super().__init__(runtime, component, engine, default_sampling_params)
280+
def __init__(self, runtime, component, engine, default_sampling_params, model_max_len: int | None = None):
281+
super().__init__(runtime, component, engine, default_sampling_params, model_max_len)
264282

265283
async def generate(self, request, context):
266284
# Use context ID for request tracking and correlation with decode phase
@@ -276,7 +294,7 @@ async def generate(self, request, context):
276294
)
277295

278296
# Build sampling params from request using shared utility
279-
sampling_params = build_sampling_params(request, self.default_sampling_params)
297+
sampling_params = build_sampling_params(request, self.default_sampling_params, self.model_max_len)
280298

281299
# Configure for prefill-only mode with remote decode
282300
if sampling_params.extra_args is None:

components/src/dynamo/vllm/main.py

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -317,7 +317,11 @@ async def init_prefill(runtime: DistributedRuntime, config: Config):
317317
engine_client, vllm_config, default_sampling_params = setup_vllm_engine(config)
318318

319319
handler = PrefillWorkerHandler(
320-
runtime, component, engine_client, default_sampling_params
320+
runtime,
321+
component,
322+
engine_client,
323+
default_sampling_params,
324+
getattr(getattr(vllm_config, "model_config", None), "max_model_len", None),
321325
)
322326

323327
# Check if kv event consolidator is enabled (port was allocated in setup_vllm_engine)
@@ -424,6 +428,7 @@ async def init(runtime: DistributedRuntime, config: Config):
424428
component,
425429
engine_client,
426430
default_sampling_params,
431+
getattr(getattr(vllm_config, "model_config", None), "max_model_len", None),
427432
)
428433

429434
# Check if kv event consolidator is enabled (port was allocated in setup_vllm_engine)

0 commit comments

Comments
 (0)