@@ -35,6 +35,7 @@ class GenerationRequest(BaseModel):
     top_p: float = Field(default=DEFAULT_TOP_P, ge=0.0, le=1.0)
     top_k: int = Field(default=80, ge=1, le=1000)  # Added top_k parameter
     repetition_penalty: float = Field(default=1.15, ge=1.0, le=2.0)  # Added repetition_penalty parameter
+    max_time: Optional[float] = Field(default=None, ge=0.0, description="Maximum time in seconds for generation")
     system_prompt: Optional[str] = Field(default=DEFAULT_SYSTEM_INSTRUCTIONS)
     stream: bool = Field(default=False)
 
@@ -47,6 +48,7 @@ class BatchGenerationRequest(BaseModel):
     top_p: float = Field(default=DEFAULT_TOP_P, ge=0.0, le=1.0)
     top_k: int = Field(default=80, ge=1, le=1000)  # Added top_k parameter
     repetition_penalty: float = Field(default=1.15, ge=1.0, le=2.0)  # Added repetition_penalty parameter
+    max_time: Optional[float] = Field(default=None, ge=0.0, description="Maximum time in seconds for generation")
     system_prompt: Optional[str] = Field(default=DEFAULT_SYSTEM_INSTRUCTIONS)
 
 
@@ -64,6 +66,7 @@ class ChatRequest(BaseModel):
     top_p: float = Field(default=DEFAULT_TOP_P, ge=0.0, le=1.0)
     top_k: int = Field(default=80, ge=1, le=1000)  # Added top_k parameter
     repetition_penalty: float = Field(default=1.15, ge=1.0, le=2.0)  # Added repetition_penalty parameter
+    max_time: Optional[float] = Field(default=None, ge=0.0, description="Maximum time in seconds for generation")
     stream: bool = Field(default=False)
 
 
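All three request models gain the same optional field, so validation is uniform: Pydantic enforces the ge=0.0 bound, and None means no time limit. A minimal sketch of that contract (the model is trimmed to the new field here; the trimming is mine, not the PR's):

    from typing import Optional
    from pydantic import BaseModel, Field, ValidationError

    class GenerationRequest(BaseModel):
        max_time: Optional[float] = Field(
            default=None, ge=0.0,
            description="Maximum time in seconds for generation")

    GenerationRequest()              # max_time=None -> no time cap
    GenerationRequest(max_time=5.0)  # valid: cap generation at ~5 seconds
    try:
        GenerationRequest(max_time=-1.0)  # rejected by the ge=0.0 constraint
    except ValidationError as e:
        print(e)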
@@ -129,7 +132,7 @@ async def generate_text(request: GenerationRequest) -> GenerationResponse:
         # Return a streaming response
         return StreamingResponse(
             generate_stream(request.prompt, request.max_tokens, request.temperature,
-                            request.top_p, request.system_prompt),
+                            request.top_p, request.system_prompt, request.max_time),
             media_type="text/event-stream"
         )
 
@@ -144,7 +147,8 @@ async def generate_text(request: GenerationRequest) -> GenerationResponse:
144147 "top_p" : request .top_p if request .top_p is not None else 0.92 , # Optimized default
145148 "top_k" : request .top_k if request .top_k is not None else 80 , # Optimized default
146149 "repetition_penalty" : request .repetition_penalty if request .repetition_penalty is not None else 1.15 , # Optimized default
147- "do_sample" : model_params .get ("do_sample" , True ) # Pass do_sample from model params
150+ "do_sample" : model_params .get ("do_sample" , True ), # Pass do_sample from model params
151+ "max_time" : request .max_time # Pass max_time parameter
148152 }
149153
150154 # Merge model-specific params with request params
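Assuming the merged generation params are ultimately forwarded to a Hugging Face transformers model.generate(...) call (the call site is outside this diff), max_time maps onto transformers' built-in time-based stopping criterion: generation halts after roughly that many seconds, finishing the token step in progress, so output can slightly overshoot the cap. A standalone sketch with an illustrative model:

    import time
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tok = AutoTokenizer.from_pretrained("gpt2")           # illustrative model
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    inputs = tok("Once upon a time", return_tensors="pt")

    start = time.monotonic()
    out = model.generate(**inputs, max_new_tokens=512,
                         do_sample=True, max_time=2.0)    # stop after ~2 s
    print(f"{out.shape[-1]} tokens in {time.monotonic() - start:.1f}s")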
@@ -212,7 +216,7 @@ async def chat_completion(request: ChatRequest) -> ChatResponse:
     # If streaming is requested, return a streaming response
     if request.stream:
         return StreamingResponse(
-            stream_chat(formatted_prompt, request.max_tokens, request.temperature, request.top_p),
+            stream_chat(formatted_prompt, request.max_tokens, request.temperature, request.top_p, request.max_time),
             media_type="text/event-stream"
         )
 
@@ -227,7 +231,8 @@ async def chat_completion(request: ChatRequest) -> ChatResponse:
227231 "top_p" : request .top_p if request .top_p is not None else 0.92 , # Optimized default
228232 "top_k" : request .top_k if request .top_k is not None else 80 , # Optimized default
229233 "repetition_penalty" : request .repetition_penalty if request .repetition_penalty is not None else 1.15 , # Optimized default
230- "do_sample" : model_params .get ("do_sample" , True ) # Pass do_sample from model params
234+ "do_sample" : model_params .get ("do_sample" , True ), # Pass do_sample from model params
235+ "max_time" : request .max_time # Pass max_time parameter
231236 }
232237
233238 # Merge model-specific params with request params
@@ -292,7 +297,8 @@ async def generate_stream(
     max_tokens: int,
     temperature: float,
     top_p: float,
-    system_prompt: Optional[str]
+    system_prompt: Optional[str],
+    max_time: Optional[float] = None
 ) -> AsyncGenerator[str, None]:
     """
     Generate text in a streaming fashion and return as server-sent events
@@ -309,7 +315,8 @@ async def generate_stream(
309315 "top_p" : top_p ,
310316 "top_k" : 80 , # Optimized top_k for high-quality streaming
311317 "repetition_penalty" : 1.15 , # Optimized repetition_penalty for high-quality streaming
312- "do_sample" : model_params .get ("do_sample" , True ) # Pass do_sample from model params
318+ "do_sample" : model_params .get ("do_sample" , True ), # Pass do_sample from model params
319+ "max_time" : max_time # Pass max_time parameter
313320 }
314321
315322 # Merge model-specific params with request params
@@ -361,7 +368,8 @@ async def stream_chat(
     formatted_prompt: str,
     max_tokens: int,
     temperature: float,
-    top_p: float
+    top_p: float,
+    max_time: Optional[float] = None
 ) -> AsyncGenerator[str, None]:
     """
     Stream chat completion responses as server-sent events
@@ -378,7 +386,8 @@ async def stream_chat(
378386 "top_p" : top_p ,
379387 "top_k" : 80 , # Optimized top_k for high-quality streaming
380388 "repetition_penalty" : 1.15 , # Optimized repetition_penalty for high-quality streaming
381- "do_sample" : model_params .get ("do_sample" , True ) # Pass do_sample from model params
389+ "do_sample" : model_params .get ("do_sample" , True ), # Pass do_sample from model params
390+ "max_time" : max_time # Pass max_time parameter
382391 }
383392
384393 # Merge model-specific params with request params
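On the streaming paths, the time cap simply ends the server-sent-event stream early; the client just receives fewer data: events. A sketch of a consumer (the /chat endpoint path and message format are assumptions, since the route definitions sit outside this diff):

    import requests

    with requests.post(
        "http://localhost:8000/chat",
        json={"messages": [{"role": "user", "content": "Hi"}],
              "stream": True, "max_time": 5.0},
        stream=True,
    ) as resp:
        for line in resp.iter_lines():
            if line.startswith(b"data: "):
                print(line[len(b"data: "):].decode())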
@@ -438,7 +447,8 @@ async def batch_generate(request: BatchGenerationRequest) -> BatchGenerationResp
438447 "top_p" : request .top_p if request .top_p is not None else 0.92 , # Optimized default
439448 "top_k" : request .top_k if request .top_k is not None else 80 , # Optimized default
440449 "repetition_penalty" : request .repetition_penalty if request .repetition_penalty is not None else 1.15 , # Optimized default
441- "do_sample" : model_params .get ("do_sample" , True ) # Pass do_sample from model params
450+ "do_sample" : model_params .get ("do_sample" , True ), # Pass do_sample from model params
451+ "max_time" : request .max_time # Pass max_time parameter
442452 }
443453
444454 # Merge model-specific params with request params
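End to end, a client can now bound latency per request. A hypothetical call against the batch endpoint (path, port, and payload field names are assumptions, not taken from this diff):

    import requests

    resp = requests.post(
        "http://localhost:8000/batch_generate",
        json={
            "prompts": ["Summarize relativity.", "Explain entropy."],
            "max_tokens": 256,
            "max_time": 10.0,  # new: cap each generation at ~10 seconds
        },
        timeout=60,
    )
    print(resp.json())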