@@ -280,6 +280,7 @@ async def generate(
         prompt: str,
         stream: bool = False,
         max_length: Optional[int] = None,
+        max_new_tokens: Optional[int] = None,
         temperature: Optional[float] = None,
         top_p: Optional[float] = None,
         top_k: Optional[int] = None,
@@ -307,6 +308,10 @@ async def generate(
         from .config import get_model_generation_params
         gen_params = get_model_generation_params(self.current_model)
 
+        # Handle max_new_tokens parameter (map to max_length)
+        if max_new_tokens is not None:
+            max_length = max_new_tokens
+
         # Override with user-provided parameters if specified
         if max_length is not None:
             try:
@@ -423,8 +428,40 @@ def _stream_generate(
             logger.error(f"Streaming generation failed: {str(e)}")
             raise HTTPException(status_code=500, detail=f"Streaming generation failed: {str(e)}")
 
-    async def async_stream_generate(self, inputs: Dict[str, torch.Tensor], gen_params: Dict[str, Any]):
-        """Convert the synchronous stream generator to an async generator."""
+    async def async_stream_generate(self, inputs: Dict[str, torch.Tensor] = None, gen_params: Dict[str, Any] = None, prompt: str = None, system_prompt: Optional[str] = None, **kwargs):
+        """Convert the synchronous stream generator to an async generator.
+
+        This can be called either with:
+        1. inputs and gen_params directly (internal use)
+        2. prompt, system_prompt and other kwargs (from the generate_stream adapter)
+        """
+        # If called with a prompt, prepare inputs and parameters
+        if prompt is not None:
+            # Get appropriate system instructions
+            from .config import system_instructions
+            instructions = str(system_instructions.get_instructions(self.current_model)) if not system_prompt else str(system_prompt)
+
+            # Format prompt with system instructions
+            formatted_prompt = f"""<|system|>{instructions}</|system|>\n<|user|>{prompt}</|user|>\n<|assistant|>"""
+
+            # Get model-specific generation parameters
+            from .config import get_model_generation_params
+            gen_params = get_model_generation_params(self.current_model)
+
+            # Update with provided kwargs
+            for key, value in kwargs.items():
+                if key in ["max_length", "temperature", "top_p", "top_k", "repetition_penalty"]:
+                    gen_params[key] = value
+                elif key == "max_new_tokens":
+                    # Handle the max_new_tokens parameter by mapping to max_length
+                    gen_params["max_length"] = value
+
+            # Tokenize the prompt
+            inputs = self.tokenizer(formatted_prompt, return_tensors="pt")
+            for key in inputs:
+                inputs[key] = inputs[key].to(self.device)
+
+        # Now stream tokens using the prepared inputs and parameters
         for token in self._stream_generate(inputs, gen_params=gen_params):
             yield token
             await asyncio.sleep(0)
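
Note: with this change, async_stream_generate accepts either the original (inputs, gen_params) pair or a raw prompt. A minimal caller sketch, assuming a manager instance with a model already loaded (the name `manager` and the parameter values are illustrative only, not part of this patch):

    import asyncio

    async def demo(manager):
        # Prompt-based path: tokenization and parameter merging happen inside the method
        async for token in manager.async_stream_generate(
            prompt="Explain attention in one sentence.",
            max_new_tokens=64,   # mapped onto max_length internally
            temperature=0.7,
        ):
            print(token, end="", flush=True)

    # asyncio.run(demo(manager))  # assumes no event loop is already running
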
@@ -564,6 +601,11 @@ async def generate_text(self, prompt: str, system_prompt: Optional[str] = None,
564601 """
565602 # Make sure we're not streaming when generating text
566603 kwargs ["stream" ] = False
604+
605+ # Handle max_new_tokens parameter by mapping to max_length if needed
606+ if "max_new_tokens" in kwargs and "max_length" not in kwargs :
607+ kwargs ["max_length" ] = kwargs .pop ("max_new_tokens" )
608+
567609 # Directly await the generate method to return the string result
568610 return await self .generate (prompt = prompt , system_instructions = system_prompt , ** kwargs )
569611
@@ -572,7 +614,14 @@ async def generate_stream(self, prompt: str, system_prompt: Optional[str] = None
         Calls the async_stream_generate method with proper parameters."""
         # Ensure streaming is enabled
         kwargs["stream"] = True
-        return self.async_stream_generate(prompt=prompt, system_prompt=system_prompt, **kwargs)
+
+        # Handle max_new_tokens parameter by mapping to max_length
+        if "max_new_tokens" in kwargs and "max_length" not in kwargs:
+            kwargs["max_length"] = kwargs.pop("max_new_tokens")
+
+        # Call async_stream_generate with the prompt and parameters
+        async for token in self.async_stream_generate(prompt=prompt, system_prompt=system_prompt, **kwargs):
+            yield token
 
     def is_model_loaded(self, model_id: str) -> bool:
         """Check if a specific model is loaded.