@@ -558,14 +558,14 @@ def _stream_generate(
             temperature = gen_params.get(
                 "temperature", DEFAULT_TEMPERATURE)
             top_p = gen_params.get("top_p", DEFAULT_TOP_P)
-            top_k = gen_params.get("top_k", 40)  # Default to 40 for faster generation
+            top_k = gen_params.get("top_k", 40)  # Default to 40 for better quality
             repetition_penalty = gen_params.get("repetition_penalty", 1.1)
         else:
             # Use provided individual parameters or defaults
             max_length = max_length or min(DEFAULT_MAX_LENGTH, 512)  # Limit default max_length
-            temperature = temperature or 0.7  # Lower temperature for faster generation
+            temperature = temperature or 0.7  # Use same temperature as non-streaming
             top_p = top_p or DEFAULT_TOP_P
-            top_k = 40  # Default to 40 for faster generation
+            top_k = 40  # Default to 40 for better quality
             repetition_penalty = 1.1

         # Get the actual device of the model
@@ -582,15 +582,21 @@ def _stream_generate(
         attention_mask = inputs["attention_mask"]

         # Generate fewer tokens at once for more responsive streaming
-        # Using smaller chunks makes it appear more interactive
-        tokens_to_generate_per_step = 3  # Reduced from 8 to 3 for more responsive streaming
+        # Using smaller chunks makes it appear more interactive while maintaining quality
+        tokens_to_generate_per_step = 2  # Reduced from 3 to 2 for better quality control
+
+        # Track generated text for quality control
+        generated_text = ""
+
+        # Define stop sequences for proper termination
+        stop_sequences = ["</s>", "<|endoftext|>", "<|im_end|>", "<|assistant|>"]

         with torch.no_grad():
             for step in range(0, max_length, tokens_to_generate_per_step):
                 # Calculate how many tokens to generate in this step
                 current_tokens_to_generate = min(tokens_to_generate_per_step, max_length - step)

-                # Generate parameters
+                # Generate parameters - use the same high-quality parameters as non-streaming
                 generate_params = {
                     "input_ids": input_ids,
                     "attention_mask": attention_mask,
@@ -601,7 +607,6 @@ def _stream_generate(
601607 "do_sample" : True ,
602608 "pad_token_id" : self .tokenizer .eos_token_id ,
603609 "repetition_penalty" : repetition_penalty ,
604- # Remove early_stopping to fix the warning
605610 "num_beams" : 1 # Explicitly set to 1 to avoid warnings
606611 }
607612
@@ -623,9 +628,37 @@ def _stream_generate(
                 if not new_text or new_text.isspace():
                     break

+                # Add to generated text for quality control
+                generated_text += new_text
+
+                # Check for stop sequences
+                should_stop = False
+                for stop_seq in stop_sequences:
+                    if stop_seq in generated_text:
+                        # We've reached a stop sequence, stop generation
+                        should_stop = True
+                        break
+
+                # Check for repetition (a sign of poor quality)
+                if len(generated_text) > 50:
+                    # Check for repeating patterns of 10+ characters
+                    last_50_chars = generated_text[-50:]
+                    for pattern_len in range(10, 20):
+                        if pattern_len < len(last_50_chars) // 2:
+                            pattern = last_50_chars[-pattern_len:]
+                            if pattern in last_50_chars[:-pattern_len]:
+                                # Detected repetition, stop generation
+                                logger.warning("Detected repetition in streaming generation, stopping")
+                                should_stop = True
+                                break
+
                 # Yield the new text
                 yield new_text

+                # Stop if needed
+                if should_stop:
+                    break
+
                 # Update input_ids and attention_mask for next iteration
                 input_ids = outputs
                 attention_mask = torch.ones_like(input_ids)
@@ -666,6 +699,7 @@ def _stream_generate(
                 if not new_text or new_text.isspace():
                     break

+                generated_text += new_text
                 yield new_text

                 input_ids = outputs
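
For reference, the repetition check added in the hunks above can be read as the following standalone helper. This is only a sketch: the name detect_repetition and its parameter defaults are illustrative and do not appear in the commit itself.

def detect_repetition(text: str, window: int = 50, min_len: int = 10, max_len: int = 20) -> bool:
    """Return True if the tail of text repeats a recent pattern of min_len..max_len characters."""
    if len(text) <= window:
        return False
    tail = text[-window:]
    for pattern_len in range(min_len, max_len):
        if pattern_len < len(tail) // 2:
            pattern = tail[-pattern_len:]
            # A pattern that also occurs earlier in the recent window counts as a loop
            if pattern in tail[:-pattern_len]:
                return True
    return False

# Example: a looping tail such as "yes, yes, yes, ..." is flagged once enough text accumulates.
assert detect_repetition("The answer is " + "yes, " * 20)
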
@@ -699,6 +733,24 @@ async def async_stream_generate(self, inputs: Dict[str, torch.Tensor] = None, ge
         # Get model-specific generation parameters
         from .config import get_model_generation_params
         gen_params = get_model_generation_params(self.current_model)
+
+        # Set optimized defaults for streaming that match non-streaming quality
+        # Use the same parameters as non-streaming for consistency
+        if not kwargs.get("max_length") and not kwargs.get("max_new_tokens"):
+            # Use a reasonable default max_length
+            gen_params["max_length"] = min(gen_params.get("max_length", DEFAULT_MAX_LENGTH), 512)
+
+        if not kwargs.get("temperature"):
+            # Use the same temperature as non-streaming
+            gen_params["temperature"] = min(gen_params.get("temperature", DEFAULT_TEMPERATURE), 0.7)
+
+        if not kwargs.get("top_k"):
+            # Add top_k for better quality
+            gen_params["top_k"] = 40
+
+        if not kwargs.get("repetition_penalty"):
+            # Add repetition penalty to avoid loops
+            gen_params["repetition_penalty"] = 1.1

         # Update with provided kwargs
         for key, value in kwargs.items():
@@ -718,10 +770,56 @@ async def async_stream_generate(self, inputs: Dict[str, torch.Tensor] = None, ge
         for key in inputs:
             inputs[key] = inputs[key].to(model_device)

-        # Now stream tokens using the prepared inputs and parameters
-        for token in self._stream_generate(inputs, gen_params=gen_params):
+        # Check if we need to clear CUDA cache before generation
+        if torch.cuda.is_available():
+            current_mem = torch.cuda.memory_allocated() / (1024 * 1024 * 1024)  # GB
+            total_mem = torch.cuda.get_device_properties(0).total_memory / (1024 * 1024 * 1024)  # GB
+            if current_mem > 0.8 * total_mem:  # If using >80% of GPU memory
+                # Clear cache to avoid OOM
+                torch.cuda.empty_cache()
+                logger.info("Cleared CUDA cache before streaming generation to avoid out of memory error")
+
+        # Create a custom stream generator with improved quality
+        async def improved_stream_generator():
+            # Use the same stopping conditions as non-streaming
+            stop_sequences = ["</s>", "<|endoftext|>", "<|im_end|>", "<|assistant|>"]
+            accumulated_text = ""
+
+            # Use a generator that produces high-quality chunks
+            try:
+                for token_chunk in self._stream_generate(inputs, gen_params=gen_params):
+                    accumulated_text += token_chunk
+
+                    # Check for stop sequences
+                    should_stop = False
+                    for stop_seq in stop_sequences:
+                        if stop_seq in accumulated_text:
+                            # Truncate at stop sequence
+                            accumulated_text = accumulated_text.split(stop_seq)[0]
+                            should_stop = True
+                            break
+
+                    # Yield the token chunk
+                    yield token_chunk
+
+                    # Stop if we've reached a stop sequence
+                    if should_stop:
+                        break
+
+                    # Also stop if we've generated too much text (safety measure)
+                    if len(accumulated_text) > gen_params.get("max_length", 512) * 4:  # Character estimate
+                        logger.warning("Stream generation exceeded maximum length - stopping")
+                        break
+
+                    await asyncio.sleep(0)
+            except Exception as e:
+                logger.error(f"Error in stream generation: {str(e)}")
+                # Don't propagate the error to avoid breaking the stream
+                # Just stop generating
+
+        # Use the improved generator
+        async for token in improved_stream_generator():
             yield token
-            await asyncio.sleep(0)

     def get_model_info(self) -> Dict[str, Any]:
         """Get information about the currently loaded model"""