1414import os
1515import transformers
1616import warnings
17+ from .chain .streaming_observer import ConsoleStreamObserver , StreamingManager
1718try :
1819 from verl .protocol import DataProto
1920except ImportError :
@@ -43,6 +44,7 @@ def __init__(
4344 log_file : str = "agent" ,
4445 project_name : str = None ,
4546 run_name : str = None ,
47+ streaming : str = "console" ,
4648 ** kwargs # To pass other unused arguments
4749 ):
4850 """
@@ -68,6 +70,12 @@ def __init__(
6870 self .jinja_template = get_template (self .template ).jinja_template ()
6971 self .project_name = project_name
7072 self .run_name = run_name
73+ self .streaming_manager = StreamingManager ()
74+ if streaming == "console" :
75+ self .streaming_manager .add_observer (ConsoleStreamObserver ())
76+ else :
77+ # TODO: Support other streaming modes
78+ raise ValueError (f"Streaming mode { streaming } is not supported." )
7179 super ().__init__ ()
7280 if kwargs :
7381 warnings .warn (f"Unused arguments for agent initialization: { kwargs } " )
@@ -118,6 +126,27 @@ async def generate_async(self, messages_list_or_inputs: List[List[Dict]], **args
118126 List of responses.
119127 """
120128 return await self .llm_engine .generate_async (messages_list_or_inputs , ** args )
129+
130+ async def generate_streaming (self , messages_list_or_inputs : List [List [Dict ]], streaming_callback = None , ** args ):
131+ """
132+ Generate responses with streaming support. This method yields response chunks as they are generated.
133+
134+ Args:
135+ messages_list_or_inputs: List of messages to generate responses for.
136+ streaming_callback: Optional callback function for streaming chunks.
137+ **args: Additional arguments for generation.
138+
139+ Yields:
140+ str: Response chunks as they are generated.
141+ """
142+ if hasattr (self .llm_engine , 'generate_streaming' ):
143+ async for chunk in self .llm_engine .generate_streaming (messages_list_or_inputs , streaming_callback = streaming_callback , ** args ):
144+ yield chunk
145+ else :
146+ # Fallback to non-streaming generation
147+ responses = await self .generate_async (messages_list_or_inputs , ** args )
148+ for response in responses :
149+ yield response
121150
122151 @property
123152 def timing_data (self ):
0 commit comments