Skip to content

Commit 01279ad

Browse files
ds-sebastianchwilczynski and ds-sebastianchwilczynski
authored
feat: support run_streaming for Agent (#650)
Co-authored-by: ds-sebastianchwilczynski <[email protected]>
1 parent ffceb1c commit 01279ad

File tree

8 files changed

+263
-159
lines changed

8 files changed

+263
-159
lines changed

examples/agents/a2a/agent_orchestrator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ async def execute_agent_task(self, task_index: int) -> str:
140140

141141
tool_calls = None
142142
if result.tool_calls:
143-
tool_calls = [{"name": tc.name, "arguments": tc.arguments, "output": tc.output} for tc in result.tool_calls]
143+
tool_calls = [{"name": tc.name, "arguments": tc.arguments, "output": tc.result} for tc in result.tool_calls]
144144

145145
return json.dumps(
146146
{

packages/ragbits-agents/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
- Update Agent run method docstring (#565)
88
- Fix AgentResult typing (#600)
99
- Support history handling in Agent (#648)
10+
- Support run_streaming in Agent (#650)
1011
- Support A2A protocol (#649)
1112

1213
## 1.0.0 (2025-06-04)

packages/ragbits-agents/src/ragbits/agents/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
1-
from ragbits.agents._main import Agent, AgentOptions, AgentResult, ToolCallResult
1+
from ragbits.agents._main import Agent, AgentOptions, AgentResult, AgentResultStreaming, ToolCallResult
22
from ragbits.agents.types import QuestionAnswerAgent, QuestionAnswerPromptInput, QuestionAnswerPromptOutput
33

44
__all__ = [
55
"Agent",
66
"AgentOptions",
77
"AgentResult",
8+
"AgentResultStreaming",
89
"QuestionAnswerAgent",
910
"QuestionAnswerPromptInput",
1011
"QuestionAnswerPromptOutput",

packages/ragbits-agents/src/ragbits/agents/_main.py

Lines changed: 158 additions & 115 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,23 @@
1-
from collections.abc import Callable
1+
from collections.abc import AsyncGenerator, AsyncIterator, Callable
22
from copy import deepcopy
33
from dataclasses import dataclass
44
from inspect import getdoc, iscoroutinefunction
5-
from types import ModuleType
5+
from types import ModuleType, SimpleNamespace
66
from typing import Any, ClassVar, Generic, cast, overload
77

88
from a2a.types import AgentCapabilities, AgentCard, AgentSkill
99

1010
from ragbits import agents
1111
from ragbits.agents.exceptions import (
1212
AgentInvalidPromptInputError,
13+
AgentToolExecutionError,
1314
AgentToolNotAvailableError,
1415
AgentToolNotSupportedError,
1516
)
1617
from ragbits.core.audit.traces import trace
17-
from ragbits.core.llms.base import LLM, LLMClientOptionsT, LLMResponseWithMetadata
18+
from ragbits.core.llms.base import LLM, LLMClientOptionsT, LLMResponseWithMetadata, ToolCall
1819
from ragbits.core.options import Options
19-
from ragbits.core.prompt.base import ChatFormat, SimplePrompt
20+
from ragbits.core.prompt.base import BasePrompt, ChatFormat, SimplePrompt
2021
from ragbits.core.prompt.prompt import Prompt, PromptInputT, PromptOutputT
2122
from ragbits.core.types import NOT_GIVEN, NotGiven
2223
from ragbits.core.utils.config_handling import ConfigurableComponent
@@ -28,9 +29,10 @@ class ToolCallResult:
2829
Result of the tool call.
2930
"""
3031

32+
id: str
3133
name: str
3234
arguments: dict
33-
output: Any
35+
result: Any
3436

3537

3638
@dataclass
@@ -58,11 +60,58 @@ class AgentOptions(Options, Generic[LLMClientOptionsT]):
5860
"""The options for the LLM."""
5961

6062

63+
class AgentResultStreaming(AsyncIterator[str | ToolCall | ToolCallResult]):
64+
"""
65+
An async iterator that will collect all yielded items by LLM.generate_streaming(). This object is returned
66+
by `run_streaming`. It can be used in an `async for` loop to process items as they arrive. After the loop completes,
67+
all items are available under the same names as in AgentResult class.
68+
"""
69+
70+
def __init__(self, generator: AsyncGenerator[str | ToolCall | ToolCallResult | SimpleNamespace | BasePrompt]):
71+
self._generator = generator
72+
self.content: str = ""
73+
self.tool_calls: list[ToolCallResult] | None = None
74+
self.metadata: dict = {}
75+
self.history: ChatFormat
76+
77+
def __aiter__(self) -> AsyncIterator[str | ToolCall | ToolCallResult]:
78+
return self
79+
80+
async def __anext__(self) -> str | ToolCall | ToolCallResult:
81+
try:
82+
item = await self._generator.__anext__()
83+
match item:
84+
case str():
85+
self.content += item
86+
case ToolCall():
87+
pass
88+
case ToolCallResult():
89+
if self.tool_calls is None:
90+
self.tool_calls = []
91+
self.tool_calls.append(item)
92+
case BasePrompt():
93+
item.add_assistant_message(self.content)
94+
self.history = item.chat
95+
item = await self._generator.__anext__()
96+
item = cast(SimpleNamespace, item)
97+
item.result = {
98+
"content": self.content,
99+
"metadata": self.metadata,
100+
"tool_calls": self.tool_calls,
101+
}
102+
raise StopAsyncIteration
103+
case _:
104+
raise ValueError(f"Unexpected item: {item}")
105+
return item
106+
except StopAsyncIteration:
107+
raise
108+
109+
61110
class Agent(
62111
ConfigurableComponent[AgentOptions[LLMClientOptionsT]], Generic[LLMClientOptionsT, PromptInputT, PromptOutputT]
63112
):
64113
"""
65-
Agent class that orchestrates the LLM and the prompt.
114+
Agent class that orchestrates the LLM and the prompt, and can call tools.
66115
67116
Current implementation is highly experimental, and the API is subject to change.
68117
"""
@@ -107,7 +156,7 @@ def __init__(
107156
@overload
108157
async def run(
109158
self: "Agent[LLMClientOptionsT, None, PromptOutputT]",
110-
input: str,
159+
input: str | None = None,
111160
options: AgentOptions[LLMClientOptionsT] | None = None,
112161
) -> AgentResult[PromptOutputT]: ...
113162

@@ -118,26 +167,18 @@ async def run(
118167
options: AgentOptions[LLMClientOptionsT] | None = None,
119168
) -> AgentResult[PromptOutputT]: ...
120169

121-
@overload
122170
async def run(
123-
self: "Agent[LLMClientOptionsT, None, PromptOutputT]",
124-
options: AgentOptions[LLMClientOptionsT] | None = None,
125-
) -> AgentResult[PromptOutputT]: ...
126-
127-
async def run(self, *args: Any, **kwargs: Any) -> AgentResult[PromptOutputT]:
171+
self, input: str | PromptInputT | None = None, options: AgentOptions[LLMClientOptionsT] | None = None
172+
) -> AgentResult[PromptOutputT]:
128173
"""
129174
Run the agent. The method is experimental, inputs and outputs may change in the future.
130175
131176
Args:
132-
*args: Positional arguments corresponding to the overload signatures.
133-
- If provided, the first positional argument is interpreted as `input`.
134-
- If a second positional argument is provided, it is interpreted as `options`.
135-
**kwargs: Keyword arguments corresponding to the overload signatures.
136-
- `input`: The input for the agent run. Can be:
137-
- str: A string input that will be used as user message.
138-
- PromptInputT: Structured input for use with structured prompt classes.
139-
- None: No input. Only valid when a string prompt was provided during initialization.
140-
- `options`: The options for the agent run.
177+
input: The input for the agent run. Can be:
178+
- str: A string input that will be used as user message.
179+
- PromptInputT: Structured input for use with structured prompt classes.
180+
- None: No input. Only valid when a string prompt was provided during initialization.
181+
options: The options for the agent run.
141182
142183
Returns:
143184
The result of the agent run.
@@ -147,8 +188,7 @@ async def run(self, *args: Any, **kwargs: Any) -> AgentResult[PromptOutputT]:
147188
AgentToolNotAvailableError: If the selected tool is not available.
148189
AgentInvalidPromptInputError: If the prompt/input combination is invalid.
149190
"""
150-
input = cast(PromptInputT, args[0] if args else kwargs.get("input"))
151-
options = args[1] if len(args) > 1 else kwargs.get("options")
191+
input = cast(PromptInputT, input)
152192

153193
merged_options = (self.default_options | options) if options else self.default_options
154194
llm_options = merged_options.llm_options or None
@@ -170,29 +210,10 @@ async def run(self, *args: Any, **kwargs: Any) -> AgentResult[PromptOutputT]:
170210
break
171211

172212
for tool_call in response.tool_calls:
173-
if tool_call.type != "function":
174-
raise AgentToolNotSupportedError(tool_call.type)
175-
176-
if tool_call.name not in self.tools_mapping:
177-
raise AgentToolNotAvailableError(tool_call.name)
178-
179-
tool = self.tools_mapping[tool_call.name]
180-
tool_output = (
181-
await tool(**tool_call.arguments) if iscoroutinefunction(tool) else tool(**tool_call.arguments)
182-
)
183-
tool_calls.append(
184-
ToolCallResult(
185-
name=tool_call.name,
186-
arguments=tool_call.arguments,
187-
output=tool_output,
188-
)
189-
)
190-
prompt_with_history = prompt_with_history.add_tool_use_message(
191-
id=tool_call.id,
192-
name=tool_call.name,
193-
arguments=tool_call.arguments,
194-
result=tool_output,
195-
)
213+
result = await self._execute_tool(tool_call)
214+
tool_calls.append(result)
215+
216+
prompt_with_history = prompt_with_history.add_tool_use_message(**result.__dict__)
196217

197218
outputs.result = {
198219
"content": response.content,
@@ -212,10 +233,76 @@ async def run(self, *args: Any, **kwargs: Any) -> AgentResult[PromptOutputT]:
212233
history=prompt_with_history.chat,
213234
)
214235

236+
@overload
237+
def run_streaming(
238+
self: "Agent[LLMClientOptionsT, None, PromptOutputT]",
239+
input: str | None = None,
240+
options: AgentOptions[LLMClientOptionsT] | None = None,
241+
) -> AgentResultStreaming: ...
242+
243+
@overload
244+
def run_streaming(
245+
self: "Agent[LLMClientOptionsT, PromptInputT, PromptOutputT]",
246+
input: PromptInputT,
247+
options: AgentOptions[LLMClientOptionsT] | None = None,
248+
) -> AgentResultStreaming: ...
249+
250+
def run_streaming(
251+
self, input: str | PromptInputT | None = None, options: AgentOptions[LLMClientOptionsT] | None = None
252+
) -> AgentResultStreaming:
253+
"""
254+
This method returns an `AgentResultStreaming` object that can be asynchronously
255+
iterated over. After the loop completes, all items are available under the same names as in AgentResult class.
256+
257+
Args:
258+
input: The input for the agent run.
259+
options: The options for the agent run.
260+
261+
Returns:
262+
An `AgentResultStreaming` object for iteration and collection.
263+
264+
Raises:
265+
AgentToolNotSupportedError: If the selected tool type is not supported.
266+
AgentToolNotAvailableError: If the selected tool is not available.
267+
AgentInvalidPromptInputError: If the prompt/input combination is invalid.
268+
"""
269+
generator = self._stream_internal(input, options)
270+
return AgentResultStreaming(generator)
271+
272+
async def _stream_internal(
273+
self, input: str | PromptInputT | None = None, options: AgentOptions[LLMClientOptionsT] | None = None
274+
) -> AsyncGenerator[str | ToolCall | ToolCallResult | SimpleNamespace | BasePrompt]:
275+
input = cast(PromptInputT, input)
276+
merged_options = (self.default_options | options) if options else self.default_options
277+
llm_options = merged_options.llm_options or None
278+
279+
prompt_with_history = self._get_prompt_with_history(input)
280+
with trace(input=input, options=merged_options) as outputs:
281+
while True:
282+
returned_tool_call = False
283+
async for chunk in self.llm.generate_streaming(
284+
prompt=prompt_with_history,
285+
tools=list(self.tools_mapping.values()),
286+
options=llm_options,
287+
):
288+
yield chunk
289+
290+
if isinstance(chunk, ToolCall):
291+
result = await self._execute_tool(chunk)
292+
yield result
293+
prompt_with_history = prompt_with_history.add_tool_use_message(**result.__dict__)
294+
returned_tool_call = True
295+
296+
if not returned_tool_call:
297+
break
298+
yield prompt_with_history
299+
if self.keep_history:
300+
self.history = prompt_with_history.chat
301+
yield outputs
302+
215303
def _get_prompt_with_history(self, input: PromptInputT) -> SimplePrompt | Prompt[PromptInputT, PromptOutputT]:
216304
curr_history = deepcopy(self.history)
217305
if isinstance(self.prompt, type) and issubclass(self.prompt, Prompt):
218-
# If we had actual instance we could just run add_user_message here
219306
if self.keep_history:
220307
self.prompt = self.prompt(input, curr_history)
221308
return self.prompt
@@ -248,6 +335,29 @@ def _get_prompt_with_history(self, input: PromptInputT) -> SimplePrompt | Prompt
248335

249336
return SimplePrompt(curr_history)
250337

338+
async def _execute_tool(self, tool_call: ToolCall) -> ToolCallResult:
339+
if tool_call.type != "function":
340+
raise AgentToolNotSupportedError(tool_call.type)
341+
342+
if tool_call.name not in self.tools_mapping:
343+
raise AgentToolNotAvailableError(tool_call.name)
344+
345+
tool = self.tools_mapping[tool_call.name]
346+
347+
try:
348+
tool_output = (
349+
await tool(**tool_call.arguments) if iscoroutinefunction(tool) else tool(**tool_call.arguments)
350+
)
351+
except Exception as e:
352+
raise AgentToolExecutionError(tool_call.name, e) from e
353+
354+
return ToolCallResult(
355+
id=tool_call.id,
356+
name=tool_call.name,
357+
arguments=tool_call.arguments,
358+
result=tool_output,
359+
)
360+
251361
def get_agent_card(
252362
self,
253363
name: str,
@@ -307,70 +417,3 @@ def _extract_agent_skill(func: Callable) -> AgentSkill:
307417
"""
308418
doc = getdoc(func) or ""
309419
return AgentSkill(name=func.__name__.replace("_", " ").title(), id=func.__name__, description=doc, tags=[])
310-
311-
# TODO: implement run_streaming method according to the comment - https://github.com/deepsense-ai/ragbits/pull/623#issuecomment-2970514478
312-
# @overload
313-
# def run_streaming(
314-
# self: "Agent[LLMClientOptionsT, PromptInputT, str]",
315-
# input: PromptInputT,
316-
# options: AgentOptions[LLMClientOptionsT] | None = None,
317-
# ) -> AsyncGenerator[str | ToolCall, None]: ...
318-
319-
# @overload
320-
# def run_streaming(
321-
# self: "Agent[LLMClientOptionsT, None, str]",
322-
# options: AgentOptions[LLMClientOptionsT] | None = None,
323-
# ) -> AsyncGenerator[str | ToolCall, None]: ...
324-
325-
# async def run_streaming(self, *args: Any, **kwargs: Any) -> AsyncGenerator[str | ToolCall, None]: # noqa: D417
326-
# """
327-
# Run the agent. The method is experimental, inputs and outputs may change in the future.
328-
329-
# Args:
330-
# input: The input for the agent run.
331-
# options: The options for the agent run.
332-
333-
# Yields:
334-
# Response text chunks or tool calls from the Agent.
335-
# """
336-
# input = cast(PromptInputT, args[0] if args else kwargs.get("input"))
337-
# options = args[1] if len(args) > 1 else kwargs.get("options")
338-
339-
# merged_options = (self.default_options | options) if options else self.default_options
340-
# tools = merged_options.tools or None
341-
# llm_options = merged_options.llm_options or None
342-
343-
# prompt = self.prompt(input)
344-
# tools_mapping = {} if not tools else {f.__name__: f for f in tools}
345-
346-
# while True:
347-
# returned_tool_call = False
348-
# async for chunk in self.llm.generate_streaming(
349-
# prompt=prompt,
350-
# tools=tools, # type: ignore
351-
# options=llm_options,
352-
# ):
353-
# yield chunk
354-
355-
# if isinstance(chunk, ToolCall):
356-
# if chunk.type != "function":
357-
# raise AgentToolNotSupportedError(chunk.type)
358-
359-
# if chunk.name not in tools_mapping:
360-
# raise AgentToolNotAvailableError(chunk.name)
361-
362-
# tool = tools_mapping[chunk.name]
363-
# tool_output = (
364-
# await tool(**chunk.arguments) if iscoroutinefunction(tool) else tool(**chunk.arguments)
365-
# )
366-
367-
# prompt = prompt.add_tool_use_message(
368-
# tool_call_id=chunk.id,
369-
# tool_name=chunk.name,
370-
# tool_arguments=chunk.arguments,
371-
# tool_call_result=tool_output,
372-
# )
373-
# returned_tool_call = True
374-
375-
# if not returned_tool_call:
376-
# break

0 commit comments

Comments
 (0)