2222
2323
2424def build_sampling_params (
25- request : Dict [str , Any ], default_sampling_params : Dict [str , Any ]
25+ request : Dict [str , Any ], default_sampling_params : Dict [str , Any ], model_max_len : int | None = None ,
2626) -> SamplingParams :
2727 """
2828 Build SamplingParams from a PreprocessedRequest.
@@ -49,6 +49,21 @@ def build_sampling_params(
4949 if key == "stop" :
5050 continue
5151 setattr (sampling_params , key , value )
52+
# When the request leaves max_tokens unset (missing or None), default it to
# the room left in the model's context window after the prompt. Nothing is
# changed when model_max_len is unknown or the request set max_tokens itself.
if model_max_len is not None:
    # `or {}` also covers a "stop_conditions" key that is present but None;
    # chaining .get() on that None used to raise AttributeError, which the
    # old blanket `except Exception: pass` hid, so no default was applied.
    stop_conditions = request.get("stop_conditions") or {}
    if stop_conditions.get("max_tokens") is None:
        # Same guard for a "token_ids" key that is present but None.
        input_length = len(request.get("token_ids") or [])
        # Always allow at least one generated token, even when the prompt
        # already fills (or exceeds) the context window.
        sampling_params.max_tokens = max(1, model_max_len - input_length)
5267
5368 return sampling_params
5469
@@ -58,13 +73,15 @@ class BaseWorkerHandler(ABC):
5873 Request handler for the generate and clear_kv_blocks endpoints.
5974 """
6075
def __init__(
    self,
    runtime,
    component,
    engine,
    default_sampling_params,
    model_max_len: int | None = None,
):
    """Store the state shared by the worker handlers.

    Args:
        runtime: Kept as ``self.runtime`` and passed to the engine monitor.
        component: Kept as ``self.component``.
        engine: Kept as ``self.engine_client`` and passed to the engine
            monitor.
        default_sampling_params: Kept for the handlers' later calls to
            ``build_sampling_params``.
        model_max_len: Optional model context length. Handlers forward it
            to ``build_sampling_params``, which uses it to derive a default
            ``max_tokens`` when the request does not set one. ``None``
            (the default) disables that behaviour.
    """
    self.runtime = runtime
    self.component = component
    self.engine_client = engine
    self.default_sampling_params = default_sampling_params
    # Not populated here; starts out as None.
    self.kv_publishers: list[ZmqKvEventPublisher] | None = None
    self.engine_monitor = VllmEngineMonitor(runtime, engine)
    self.model_max_len = model_max_len
6885
6986 @abstractmethod
7087 async def generate (self , request , context ) -> AsyncGenerator [dict , None ]:
@@ -160,8 +177,9 @@ def __init__(
160177 component ,
161178 engine ,
162179 default_sampling_params ,
180+ model_max_len : int | None = None ,
163181 ):
164- super ().__init__ (runtime , component , engine , default_sampling_params )
182+ super ().__init__ (runtime , component , engine , default_sampling_params , model_max_len )
165183
166184 async def generate (self , request , context ):
167185 # Use context ID for request tracking and correlation
@@ -171,7 +189,7 @@ async def generate(self, request, context):
171189 prompt = TokensPrompt (prompt_token_ids = request ["token_ids" ])
172190
173191 # Build sampling params from request
174- sampling_params = build_sampling_params (request , self .default_sampling_params )
192+ sampling_params = build_sampling_params (request , self .default_sampling_params , self . model_max_len )
175193
176194 # Extract disaggregated_params from request (set by prefill router in Rust frontend)
177195 disaggregated_params = request .get ("disaggregated_params" )
@@ -202,8 +220,8 @@ async def generate(self, request, context):
202220
203221
204222class PrefillWorkerHandler (BaseWorkerHandler ):
def __init__(
    self,
    runtime,
    component,
    engine,
    default_sampling_params,
    model_max_len: int | None = None,
):
    """Forward every argument unchanged to ``BaseWorkerHandler.__init__``.

    ``model_max_len`` is optional and defaults to ``None``, so callers that
    pass only the first four arguments keep working.
    """
    super().__init__(
        runtime, component, engine, default_sampling_params, model_max_len
    )
207225
208226 async def generate (self , request , context ):
209227 # Use context ID for request tracking and correlation with decode phase
@@ -214,7 +232,7 @@ async def generate(self, request, context):
214232 prompt = TokensPrompt (prompt_token_ids = token_ids )
215233
216234 # Build sampling params from request using shared utility
217- sampling_params = build_sampling_params (request , self .default_sampling_params )
235+ sampling_params = build_sampling_params (request , self .default_sampling_params , self . model_max_len )
218236
219237 # Configure for prefill-only mode with remote decode
220238 if sampling_params .extra_args is None :
0 commit comments