Skip to content

Commit 75dc9b1

Browse files
authored
Merge pull request #2 from Agent-One-Lab/streaming
Add Streaming capability for rollout.
2 parents c98b2b3 + 4a49bad commit 75dc9b1

File tree

15 files changed

+1147
-203
lines changed

15 files changed

+1147
-203
lines changed

agents/agents/agents/agent_base.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import os
1515
import transformers
1616
import warnings
17+
from .chain.streaming_observer import ConsoleStreamObserver, StreamingManager
1718
try:
1819
from verl.protocol import DataProto
1920
except ImportError:
@@ -43,6 +44,7 @@ def __init__(
4344
log_file: str = "agent",
4445
project_name: str = None,
4546
run_name: str = None,
47+
streaming: str = "console",
4648
**kwargs # To pass other unused arguments
4749
):
4850
"""
@@ -68,6 +70,12 @@ def __init__(
6870
self.jinja_template = get_template(self.template).jinja_template()
6971
self.project_name = project_name
7072
self.run_name = run_name
73+
self.streaming_manager = StreamingManager()
74+
if streaming == "console":
75+
self.streaming_manager.add_observer(ConsoleStreamObserver())
76+
else:
77+
# TODO: Support other streaming modes
78+
raise ValueError(f"Streaming mode {streaming} is not supported.")
7179
super().__init__()
7280
if kwargs:
7381
warnings.warn(f"Unused arguments for agent initialization: {kwargs}")
@@ -118,6 +126,27 @@ async def generate_async(self, messages_list_or_inputs: List[List[Dict]], **args
118126
List of responses.
119127
"""
120128
return await self.llm_engine.generate_async(messages_list_or_inputs, **args)
129+
130+
async def generate_streaming(self, messages_list_or_inputs: List[List[Dict]], streaming_callback=None, **args):
    """
    Generate responses with streaming support, yielding chunks as they are produced.

    If the underlying LLM engine exposes a ``generate_streaming`` method, this
    delegates to it and re-yields every chunk. Otherwise it falls back to the
    non-streaming :meth:`generate_async` and yields each whole response.

    Args:
        messages_list_or_inputs: List of message lists to generate responses for.
        streaming_callback: Optional callable invoked with each chunk. May be a
            plain function or a coroutine function; awaitable results are awaited.
        **args: Additional keyword arguments forwarded to the generation backend.

    Yields:
        str: Response chunks as they are generated (whole responses when the
        engine does not support streaming).
    """
    if hasattr(self.llm_engine, 'generate_streaming'):
        # Engine natively supports streaming: forward the callback and re-yield.
        async for chunk in self.llm_engine.generate_streaming(
            messages_list_or_inputs, streaming_callback=streaming_callback, **args
        ):
            yield chunk
    else:
        # Fallback to non-streaming generation.
        responses = await self.generate_async(messages_list_or_inputs, **args)
        for response in responses:
            # Bug fix: the original fallback silently ignored streaming_callback,
            # so callback-based observers saw nothing on non-streaming engines.
            # Invoke it here for each emitted response; support async callbacks
            # via the awaitable protocol without importing inspect/asyncio.
            if streaming_callback is not None:
                result = streaming_callback(response)
                if result is not None and hasattr(result, '__await__'):
                    await result
            yield response
121150

122151
@property
123152
def timing_data(self):

0 commit comments

Comments
 (0)