2929
3030
3131def build_sampling_params (
32- request : Dict [str , Any ], default_sampling_params : Dict [str , Any ]
32+ request : Dict [str , Any ],
33+ default_sampling_params : Dict [str , Any ],
34+ model_max_len : int | None = None ,
3335) -> SamplingParams :
3436 """
3537 Build SamplingParams from a PreprocessedRequest.
@@ -57,6 +59,18 @@ def build_sampling_params(
5759 continue
5860 setattr (sampling_params , key , value )
5961
62+ # If max_tokens wasn't provided (None or missing) and model_max_len is known, default it to the remaining context length
63+ try :
64+ provided_max_tokens = request .get ("stop_conditions" , {}).get ("max_tokens" , None )
65+ token_ids = request .get ("token_ids" , [])
66+ input_length = len (token_ids )
67+ if model_max_len is not None and (provided_max_tokens is None ):
68+ # Remaining context after the prompt, floored at 1 so the default is never zero or negative
69+ dynamic_default = max (1 , model_max_len - input_length )
70+ sampling_params .max_tokens = dynamic_default
71+ except Exception :
72+ pass
73+
6074 return sampling_params
6175
6276
@@ -65,14 +79,22 @@ class BaseWorkerHandler(ABC):
6579 Request handler for the generate and clear_kv_blocks endpoints.
6680 """
6781
68- def __init__ (self , runtime , component , engine , default_sampling_params ):
82+ def __init__ (
83+ self ,
84+ runtime ,
85+ component ,
86+ engine ,
87+ default_sampling_params ,
88+ model_max_len : int | None = None ,
89+ ):
6990 self .runtime = runtime
7091 self .component = component
7192 self .engine_client = engine
7293 self .default_sampling_params = default_sampling_params
7394 self .kv_publishers : list [ZmqKvEventPublisher ] | None = None
7495 self .engine_monitor = VllmEngineMonitor (runtime , engine )
7596 self .image_loader = ImageLoader ()
97+ self .model_max_len = model_max_len
7698
7799 @abstractmethod
78100 async def generate (self , request , context ) -> AsyncGenerator [dict , None ]:
@@ -212,8 +234,11 @@ def __init__(
212234 component ,
213235 engine ,
214236 default_sampling_params ,
237+ model_max_len : int | None = None ,
215238 ):
216- super ().__init__ (runtime , component , engine , default_sampling_params )
239+ super ().__init__ (
240+ runtime , component , engine , default_sampling_params , model_max_len
241+ )
217242
218243 async def generate (self , request , context ):
219244 # Use context ID for request tracking and correlation
@@ -228,7 +253,9 @@ async def generate(self, request, context):
228253 )
229254
230255 # Build sampling params from request
231- sampling_params = build_sampling_params (request , self .default_sampling_params )
256+ sampling_params = build_sampling_params (
257+ request , self .default_sampling_params , self .model_max_len
258+ )
232259
233260 # Extract disaggregated_params from request (set by prefill router in Rust frontend)
234261 disaggregated_params = request .get ("disaggregated_params" )
@@ -259,8 +286,17 @@ async def generate(self, request, context):
259286
260287
261288class PrefillWorkerHandler (BaseWorkerHandler ):
262- def __init__ (self , runtime , component , engine , default_sampling_params ):
263- super ().__init__ (runtime , component , engine , default_sampling_params )
289+ def __init__ (
290+ self ,
291+ runtime ,
292+ component ,
293+ engine ,
294+ default_sampling_params ,
295+ model_max_len : int | None = None ,
296+ ):
297+ super ().__init__ (
298+ runtime , component , engine , default_sampling_params , model_max_len
299+ )
264300
265301 async def generate (self , request , context ):
266302 # Use context ID for request tracking and correlation with decode phase
@@ -276,7 +312,9 @@ async def generate(self, request, context):
276312 )
277313
278314 # Build sampling params from request using shared utility
279- sampling_params = build_sampling_params (request , self .default_sampling_params )
315+ sampling_params = build_sampling_params (
316+ request , self .default_sampling_params , self .model_max_len
317+ )
280318
281319 # Configure for prefill-only mode with remote decode
282320 if sampling_params .extra_args is None :
0 commit comments