Skip to content

Commit 82f817a

Browse files
committed
Support dynamically computing max_tokens for the vLLM backend
1 parent 4765d88 commit 82f817a

File tree

2 files changed

+31
-8
lines changed

2 files changed

+31
-8
lines changed

components/src/dynamo/vllm/handlers.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323

2424
def build_sampling_params(
25-
request: Dict[str, Any], default_sampling_params: Dict[str, Any]
25+
request: Dict[str, Any], default_sampling_params: Dict[str, Any], model_max_len: int | None = None,
2626
) -> SamplingParams:
2727
"""
2828
Build SamplingParams from a PreprocessedRequest.
@@ -49,6 +49,21 @@ def build_sampling_params(
4949
if key == "stop":
5050
continue
5151
setattr(sampling_params, key, value)
52+
53+
# If max_tokens wasn't provided (None or missing), compute a dynamic default
54+
try:
55+
provided_max_tokens = request.get("stop_conditions", {}).get("max_tokens", None)
56+
token_ids = request.get("token_ids", [])
57+
input_length = len(token_ids)
58+
if (
59+
model_max_len is not None
60+
and (provided_max_tokens is None)
61+
):
62+
# Always allow at least 1 generated token, even when the prompt already fills the context window
63+
dynamic_default = max(1, model_max_len - input_length)
64+
sampling_params.max_tokens = dynamic_default
65+
except Exception:
66+
pass
5267

5368
return sampling_params
5469

@@ -58,13 +73,15 @@ class BaseWorkerHandler(ABC):
5873
Request handler for the generate and clear_kv_blocks endpoints.
5974
"""
6075

61-
def __init__(self, runtime, component, engine, default_sampling_params):
76+
def __init__(self, runtime, component, engine, default_sampling_params, model_max_len: int | None = None):
6277
self.runtime = runtime
6378
self.component = component
6479
self.engine_client = engine
6580
self.default_sampling_params = default_sampling_params
6681
self.kv_publishers: list[ZmqKvEventPublisher] | None = None
6782
self.engine_monitor = VllmEngineMonitor(runtime, engine)
83+
self.model_max_len = model_max_len
84+
6885

6986
@abstractmethod
7087
async def generate(self, request, context) -> AsyncGenerator[dict, None]:
@@ -160,8 +177,9 @@ def __init__(
160177
component,
161178
engine,
162179
default_sampling_params,
180+
model_max_len: int | None = None,
163181
):
164-
super().__init__(runtime, component, engine, default_sampling_params)
182+
super().__init__(runtime, component, engine, default_sampling_params, model_max_len)
165183

166184
async def generate(self, request, context):
167185
# Use context ID for request tracking and correlation
@@ -171,7 +189,7 @@ async def generate(self, request, context):
171189
prompt = TokensPrompt(prompt_token_ids=request["token_ids"])
172190

173191
# Build sampling params from request
174-
sampling_params = build_sampling_params(request, self.default_sampling_params)
192+
sampling_params = build_sampling_params(request, self.default_sampling_params, self.model_max_len)
175193

176194
# Extract disaggregated_params from request (set by prefill router in Rust frontend)
177195
disaggregated_params = request.get("disaggregated_params")
@@ -202,8 +220,8 @@ async def generate(self, request, context):
202220

203221

204222
class PrefillWorkerHandler(BaseWorkerHandler):
205-
def __init__(self, runtime, component, engine, default_sampling_params):
206-
super().__init__(runtime, component, engine, default_sampling_params)
223+
def __init__(self, runtime, component, engine, default_sampling_params, model_max_len: int | None = None):
224+
super().__init__(runtime, component, engine, default_sampling_params, model_max_len)
207225

208226
async def generate(self, request, context):
209227
# Use context ID for request tracking and correlation with decode phase
@@ -214,7 +232,7 @@ async def generate(self, request, context):
214232
prompt = TokensPrompt(prompt_token_ids=token_ids)
215233

216234
# Build sampling params from request using shared utility
217-
sampling_params = build_sampling_params(request, self.default_sampling_params)
235+
sampling_params = build_sampling_params(request, self.default_sampling_params, self.model_max_len)
218236

219237
# Configure for prefill-only mode with remote decode
220238
if sampling_params.extra_args is None:

components/src/dynamo/vllm/main.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,11 @@ async def init_prefill(runtime: DistributedRuntime, config: Config):
307307
engine_client, vllm_config, default_sampling_params = setup_vllm_engine(config)
308308

309309
handler = PrefillWorkerHandler(
310-
runtime, component, engine_client, default_sampling_params
310+
runtime,
311+
component,
312+
engine_client,
313+
default_sampling_params,
314+
getattr(getattr(vllm_config, "model_config", None), "max_model_len", None),
311315
)
312316

313317
# Check if kv event consolidator is enabled (port was allocated in setup_vllm_engine)
@@ -414,6 +418,7 @@ async def init(runtime: DistributedRuntime, config: Config):
414418
component,
415419
engine_client,
416420
default_sampling_params,
421+
getattr(getattr(vllm_config, "model_config", None), "max_model_len", None),
417422
)
418423

419424
# Check if kv event consolidator is enabled (port was allocated in setup_vllm_engine)

0 commit comments

Comments
 (0)