2929
3030
3131def build_sampling_params (
32- request : Dict [str , Any ], default_sampling_params : Dict [str , Any ]
32+ request : Dict [str , Any ], default_sampling_params : Dict [str , Any ], model_max_len : int | None = None ,
3333) -> SamplingParams :
3434 """
3535 Build SamplingParams from a PreprocessedRequest.
@@ -56,6 +56,21 @@ def build_sampling_params(
5656 if key == "stop" :
5757 continue
5858 setattr (sampling_params , key , value )
59+
60+ # If max_tokens wasn't provided (None or missing), compute a dynamic default
61+ try :
62+ provided_max_tokens = request .get ("stop_conditions" , {}).get ("max_tokens" , None )
63+ token_ids = request .get ("token_ids" , [])
64+ input_length = len (token_ids )
65+ if (
66+ model_max_len is not None
67+ and (provided_max_tokens is None )
68+ ):
69+ # Ensure at least 1 token generation by default when possible
70+ dynamic_default = max (1 , model_max_len - input_length )
71+ sampling_params .max_tokens = dynamic_default
72+ except Exception :
73+ pass
5974
6075 return sampling_params
6176
@@ -65,14 +80,16 @@ class BaseWorkerHandler(ABC):
6580 Request handler for the generate and clear_kv_blocks endpoints.
6681 """
6782
68- def __init__ (self , runtime , component , engine , default_sampling_params ):
83+ def __init__ (self , runtime , component , engine , default_sampling_params , model_max_len : int | None = None ):
6984 self .runtime = runtime
7085 self .component = component
7186 self .engine_client = engine
7287 self .default_sampling_params = default_sampling_params
7388 self .kv_publishers : list [ZmqKvEventPublisher ] | None = None
7489 self .engine_monitor = VllmEngineMonitor (runtime , engine )
7590 self .image_loader = ImageLoader ()
91+ self .model_max_len = model_max_len
92+
7693
7794 @abstractmethod
7895 async def generate (self , request , context ) -> AsyncGenerator [dict , None ]:
@@ -212,8 +229,9 @@ def __init__(
212229 component ,
213230 engine ,
214231 default_sampling_params ,
232+ model_max_len : int | None = None ,
215233 ):
216- super ().__init__ (runtime , component , engine , default_sampling_params )
234+ super ().__init__ (runtime , component , engine , default_sampling_params , model_max_len )
217235
218236 async def generate (self , request , context ):
219237 # Use context ID for request tracking and correlation
@@ -228,7 +246,7 @@ async def generate(self, request, context):
228246 )
229247
230248 # Build sampling params from request
231- sampling_params = build_sampling_params (request , self .default_sampling_params )
249+ sampling_params = build_sampling_params (request , self .default_sampling_params , self . model_max_len )
232250
233251 # Extract disaggregated_params from request (set by prefill router in Rust frontend)
234252 disaggregated_params = request .get ("disaggregated_params" )
@@ -259,8 +277,8 @@ async def generate(self, request, context):
259277
260278
261279class PrefillWorkerHandler (BaseWorkerHandler ):
262- def __init__ (self , runtime , component , engine , default_sampling_params ):
263- super ().__init__ (runtime , component , engine , default_sampling_params )
280+ def __init__ (self , runtime , component , engine , default_sampling_params , model_max_len : int | None = None ):
281+ super ().__init__ (runtime , component , engine , default_sampling_params , model_max_len )
264282
265283 async def generate (self , request , context ):
266284 # Use context ID for request tracking and correlation with decode phase
@@ -276,7 +294,7 @@ async def generate(self, request, context):
276294 )
277295
278296 # Build sampling params from request using shared utility
279- sampling_params = build_sampling_params (request , self .default_sampling_params )
297+ sampling_params = build_sampling_params (request , self .default_sampling_params , self . model_max_len )
280298
281299 # Configure for prefill-only mode with remote decode
282300 if sampling_params .extra_args is None :
0 commit comments