1
- from collections .abc import Callable
1
+ from collections .abc import AsyncGenerator , AsyncIterator , Callable
2
2
from copy import deepcopy
3
3
from dataclasses import dataclass
4
4
from inspect import getdoc , iscoroutinefunction
5
- from types import ModuleType
5
+ from types import ModuleType , SimpleNamespace
6
6
from typing import Any , ClassVar , Generic , cast , overload
7
7
8
8
from a2a .types import AgentCapabilities , AgentCard , AgentSkill
9
9
10
10
from ragbits import agents
11
11
from ragbits .agents .exceptions import (
12
12
AgentInvalidPromptInputError ,
13
+ AgentToolExecutionError ,
13
14
AgentToolNotAvailableError ,
14
15
AgentToolNotSupportedError ,
15
16
)
16
17
from ragbits .core .audit .traces import trace
17
- from ragbits .core .llms .base import LLM , LLMClientOptionsT , LLMResponseWithMetadata
18
+ from ragbits .core .llms .base import LLM , LLMClientOptionsT , LLMResponseWithMetadata , ToolCall
18
19
from ragbits .core .options import Options
19
- from ragbits .core .prompt .base import ChatFormat , SimplePrompt
20
+ from ragbits .core .prompt .base import BasePrompt , ChatFormat , SimplePrompt
20
21
from ragbits .core .prompt .prompt import Prompt , PromptInputT , PromptOutputT
21
22
from ragbits .core .types import NOT_GIVEN , NotGiven
22
23
from ragbits .core .utils .config_handling import ConfigurableComponent
@@ -28,9 +29,10 @@ class ToolCallResult:
28
29
Result of the tool call.
29
30
"""
30
31
32
+ id : str
31
33
name : str
32
34
arguments : dict
33
- output : Any
35
+ result : Any
34
36
35
37
36
38
@dataclass
@@ -58,11 +60,58 @@ class AgentOptions(Options, Generic[LLMClientOptionsT]):
58
60
"""The options for the LLM."""
59
61
60
62
63
+ class AgentResultStreaming (AsyncIterator [str | ToolCall | ToolCallResult ]):
64
+ """
65
+ An async iterator that will collect all yielded items by LLM.generate_streaming(). This object is returned
66
+ by `run_streaming`. It can be used in an `async for` loop to process items as they arrive. After the loop completes,
67
+ all items are available under the same names as in AgentResult class.
68
+ """
69
+
70
+ def __init__ (self , generator : AsyncGenerator [str | ToolCall | ToolCallResult | SimpleNamespace | BasePrompt ]):
71
+ self ._generator = generator
72
+ self .content : str = ""
73
+ self .tool_calls : list [ToolCallResult ] | None = None
74
+ self .metadata : dict = {}
75
+ self .history : ChatFormat
76
+
77
+ def __aiter__ (self ) -> AsyncIterator [str | ToolCall | ToolCallResult ]:
78
+ return self
79
+
80
+ async def __anext__ (self ) -> str | ToolCall | ToolCallResult :
81
+ try :
82
+ item = await self ._generator .__anext__ ()
83
+ match item :
84
+ case str ():
85
+ self .content += item
86
+ case ToolCall ():
87
+ pass
88
+ case ToolCallResult ():
89
+ if self .tool_calls is None :
90
+ self .tool_calls = []
91
+ self .tool_calls .append (item )
92
+ case BasePrompt ():
93
+ item .add_assistant_message (self .content )
94
+ self .history = item .chat
95
+ item = await self ._generator .__anext__ ()
96
+ item = cast (SimpleNamespace , item )
97
+ item .result = {
98
+ "content" : self .content ,
99
+ "metadata" : self .metadata ,
100
+ "tool_calls" : self .tool_calls ,
101
+ }
102
+ raise StopAsyncIteration
103
+ case _:
104
+ raise ValueError (f"Unexpected item: { item } " )
105
+ return item
106
+ except StopAsyncIteration :
107
+ raise
108
+
109
+
61
110
class Agent (
62
111
ConfigurableComponent [AgentOptions [LLMClientOptionsT ]], Generic [LLMClientOptionsT , PromptInputT , PromptOutputT ]
63
112
):
64
113
"""
65
- Agent class that orchestrates the LLM and the prompt.
114
+ Agent class that orchestrates the LLM and the prompt, and can call tools .
66
115
67
116
Current implementation is highly experimental, and the API is subject to change.
68
117
"""
@@ -107,7 +156,7 @@ def __init__(
107
156
@overload
108
157
async def run (
109
158
self : "Agent[LLMClientOptionsT, None, PromptOutputT]" ,
110
- input : str ,
159
+ input : str | None = None ,
111
160
options : AgentOptions [LLMClientOptionsT ] | None = None ,
112
161
) -> AgentResult [PromptOutputT ]: ...
113
162
@@ -118,26 +167,18 @@ async def run(
118
167
options : AgentOptions [LLMClientOptionsT ] | None = None ,
119
168
) -> AgentResult [PromptOutputT ]: ...
120
169
121
- @overload
122
170
async def run (
123
- self : "Agent[LLMClientOptionsT, None, PromptOutputT]" ,
124
- options : AgentOptions [LLMClientOptionsT ] | None = None ,
125
- ) -> AgentResult [PromptOutputT ]: ...
126
-
127
- async def run (self , * args : Any , ** kwargs : Any ) -> AgentResult [PromptOutputT ]:
171
+ self , input : str | PromptInputT | None = None , options : AgentOptions [LLMClientOptionsT ] | None = None
172
+ ) -> AgentResult [PromptOutputT ]:
128
173
"""
129
174
Run the agent. The method is experimental, inputs and outputs may change in the future.
130
175
131
176
Args:
132
- *args: Positional arguments corresponding to the overload signatures.
133
- - If provided, the first positional argument is interpreted as `input`.
134
- - If a second positional argument is provided, it is interpreted as `options`.
135
- **kwargs: Keyword arguments corresponding to the overload signatures.
136
- - `input`: The input for the agent run. Can be:
137
- - str: A string input that will be used as user message.
138
- - PromptInputT: Structured input for use with structured prompt classes.
139
- - None: No input. Only valid when a string prompt was provided during initialization.
140
- - `options`: The options for the agent run.
177
+ input: The input for the agent run. Can be:
178
+ - str: A string input that will be used as user message.
179
+ - PromptInputT: Structured input for use with structured prompt classes.
180
+ - None: No input. Only valid when a string prompt was provided during initialization.
181
+ options: The options for the agent run.
141
182
142
183
Returns:
143
184
The result of the agent run.
@@ -147,8 +188,7 @@ async def run(self, *args: Any, **kwargs: Any) -> AgentResult[PromptOutputT]:
147
188
AgentToolNotAvailableError: If the selected tool is not available.
148
189
AgentInvalidPromptInputError: If the prompt/input combination is invalid.
149
190
"""
150
- input = cast (PromptInputT , args [0 ] if args else kwargs .get ("input" ))
151
- options = args [1 ] if len (args ) > 1 else kwargs .get ("options" )
191
+ input = cast (PromptInputT , input )
152
192
153
193
merged_options = (self .default_options | options ) if options else self .default_options
154
194
llm_options = merged_options .llm_options or None
@@ -170,29 +210,10 @@ async def run(self, *args: Any, **kwargs: Any) -> AgentResult[PromptOutputT]:
170
210
break
171
211
172
212
for tool_call in response .tool_calls :
173
- if tool_call .type != "function" :
174
- raise AgentToolNotSupportedError (tool_call .type )
175
-
176
- if tool_call .name not in self .tools_mapping :
177
- raise AgentToolNotAvailableError (tool_call .name )
178
-
179
- tool = self .tools_mapping [tool_call .name ]
180
- tool_output = (
181
- await tool (** tool_call .arguments ) if iscoroutinefunction (tool ) else tool (** tool_call .arguments )
182
- )
183
- tool_calls .append (
184
- ToolCallResult (
185
- name = tool_call .name ,
186
- arguments = tool_call .arguments ,
187
- output = tool_output ,
188
- )
189
- )
190
- prompt_with_history = prompt_with_history .add_tool_use_message (
191
- id = tool_call .id ,
192
- name = tool_call .name ,
193
- arguments = tool_call .arguments ,
194
- result = tool_output ,
195
- )
213
+ result = await self ._execute_tool (tool_call )
214
+ tool_calls .append (result )
215
+
216
+ prompt_with_history = prompt_with_history .add_tool_use_message (** result .__dict__ )
196
217
197
218
outputs .result = {
198
219
"content" : response .content ,
@@ -212,10 +233,76 @@ async def run(self, *args: Any, **kwargs: Any) -> AgentResult[PromptOutputT]:
212
233
history = prompt_with_history .chat ,
213
234
)
214
235
236
@overload
def run_streaming(
    self: "Agent[LLMClientOptionsT, None, PromptOutputT]",
    input: str | None = None,
    options: AgentOptions[LLMClientOptionsT] | None = None,
) -> AgentResultStreaming: ...

@overload
def run_streaming(
    self: "Agent[LLMClientOptionsT, PromptInputT, PromptOutputT]",
    input: PromptInputT,
    options: AgentOptions[LLMClientOptionsT] | None = None,
) -> AgentResultStreaming: ...

def run_streaming(
    self, input: str | PromptInputT | None = None, options: AgentOptions[LLMClientOptionsT] | None = None
) -> AgentResultStreaming:
    """
    Run the agent in streaming mode.

    The call itself only wraps the internal async generator; the generator body runs once the returned
    object is iterated, so the errors listed below surface during iteration rather than from this call.
    After the `async for` loop completes, the collected items are available on the returned object
    under the same names as in AgentResult class.

    Args:
        input: The input for the agent run.
        options: The options for the agent run.

    Returns:
        An `AgentResultStreaming` object for iteration and collection.

    Raises:
        AgentToolNotSupportedError: If the selected tool type is not supported.
        AgentToolNotAvailableError: If the selected tool is not available.
        AgentToolExecutionError: If the selected tool raises while being executed.
        AgentInvalidPromptInputError: If the prompt/input combination is invalid.
    """
    return AgentResultStreaming(self._stream_internal(input, options))
271
+
272
async def _stream_internal(
    self, input: str | PromptInputT | None = None, options: AgentOptions[LLMClientOptionsT] | None = None
) -> AsyncGenerator[str | ToolCall | ToolCallResult | SimpleNamespace | BasePrompt]:
    """
    Drive the streaming agent loop.

    Every chunk coming from the LLM is forwarded as-is. Whenever a chunk is a tool call, the tool is
    executed, its result is yielded and recorded in the prompt, and the LLM is queried again once the
    current stream ends. When a stream finishes without any tool call, the final prompt is yielded,
    followed by the trace outputs object.

    Args:
        input: The input for the agent run.
        options: The options for the agent run.

    Yields:
        LLM chunks and tool call results, then the final prompt, then the trace outputs object.
    """
    input = cast(PromptInputT, input)
    merged_options = self.default_options
    if options:
        merged_options = merged_options | options
    llm_options = merged_options.llm_options or None

    prompt_with_history = self._get_prompt_with_history(input)
    with trace(input=input, options=merged_options) as outputs:
        # Keep querying the LLM for as long as the previous round triggered at least one tool call.
        tool_was_called = True
        while tool_was_called:
            tool_was_called = False
            stream = self.llm.generate_streaming(
                prompt=prompt_with_history,
                tools=list(self.tools_mapping.values()),
                options=llm_options,
            )
            async for chunk in stream:
                yield chunk

                if not isinstance(chunk, ToolCall):
                    continue

                tool_result = await self._execute_tool(chunk)
                yield tool_result
                # The fields of the tool call result are passed as keyword arguments of the same names.
                prompt_with_history = prompt_with_history.add_tool_use_message(**vars(tool_result))
                tool_was_called = True

        yield prompt_with_history
        # This runs only after the consumer resumes the generator past the prompt yield above.
        if self.keep_history:
            self.history = prompt_with_history.chat
        yield outputs
302
+
215
303
def _get_prompt_with_history (self , input : PromptInputT ) -> SimplePrompt | Prompt [PromptInputT , PromptOutputT ]:
216
304
curr_history = deepcopy (self .history )
217
305
if isinstance (self .prompt , type ) and issubclass (self .prompt , Prompt ):
218
- # If we had actual instance we could just run add_user_message here
219
306
if self .keep_history :
220
307
self .prompt = self .prompt (input , curr_history )
221
308
return self .prompt
@@ -248,6 +335,29 @@ def _get_prompt_with_history(self, input: PromptInputT) -> SimplePrompt | Prompt
248
335
249
336
return SimplePrompt (curr_history )
250
337
338
async def _execute_tool(self, tool_call: ToolCall) -> ToolCallResult:
    """
    Execute the tool selected by the LLM and wrap its output.

    Args:
        tool_call: The tool call requested by the LLM.

    Returns:
        The tool call result carrying the call id, tool name, arguments and the tool output.

    Raises:
        AgentToolNotSupportedError: If the tool call type is not "function".
        AgentToolNotAvailableError: If no tool is registered under the requested name.
        AgentToolExecutionError: If the tool raises while being executed.
    """
    from inspect import isawaitable

    if tool_call.type != "function":
        raise AgentToolNotSupportedError(tool_call.type)

    if tool_call.name not in self.tools_mapping:
        raise AgentToolNotAvailableError(tool_call.name)

    tool = self.tools_mapping[tool_call.name]

    try:
        tool_output = tool(**tool_call.arguments)
        # Decide on the returned value rather than on the callable: this also covers objects with an
        # async `__call__` and sync wrappers that return a coroutine, which `iscoroutinefunction` misses.
        if isawaitable(tool_output):
            tool_output = await tool_output
    except Exception as e:
        raise AgentToolExecutionError(tool_call.name, e) from e

    return ToolCallResult(
        id=tool_call.id,
        name=tool_call.name,
        arguments=tool_call.arguments,
        result=tool_output,
    )
360
+
251
361
def get_agent_card (
252
362
self ,
253
363
name : str ,
@@ -307,70 +417,3 @@ def _extract_agent_skill(func: Callable) -> AgentSkill:
307
417
"""
308
418
doc = getdoc (func ) or ""
309
419
return AgentSkill (name = func .__name__ .replace ("_" , " " ).title (), id = func .__name__ , description = doc , tags = [])
310
-
311
- # TODO: implement run_streaming method according to the comment - https://github.com/deepsense-ai/ragbits/pull/623#issuecomment-2970514478
312
- # @overload
313
- # def run_streaming(
314
- # self: "Agent[LLMClientOptionsT, PromptInputT, str]",
315
- # input: PromptInputT,
316
- # options: AgentOptions[LLMClientOptionsT] | None = None,
317
- # ) -> AsyncGenerator[str | ToolCall, None]: ...
318
-
319
- # @overload
320
- # def run_streaming(
321
- # self: "Agent[LLMClientOptionsT, None, str]",
322
- # options: AgentOptions[LLMClientOptionsT] | None = None,
323
- # ) -> AsyncGenerator[str | ToolCall, None]: ...
324
-
325
- # async def run_streaming(self, *args: Any, **kwargs: Any) -> AsyncGenerator[str | ToolCall, None]: # noqa: D417
326
- # """
327
- # Run the agent. The method is experimental, inputs and outputs may change in the future.
328
-
329
- # Args:
330
- # input: The input for the agent run.
331
- # options: The options for the agent run.
332
-
333
- # Yields:
334
- # Response text chunks or tool calls from the Agent.
335
- # """
336
- # input = cast(PromptInputT, args[0] if args else kwargs.get("input"))
337
- # options = args[1] if len(args) > 1 else kwargs.get("options")
338
-
339
- # merged_options = (self.default_options | options) if options else self.default_options
340
- # tools = merged_options.tools or None
341
- # llm_options = merged_options.llm_options or None
342
-
343
- # prompt = self.prompt(input)
344
- # tools_mapping = {} if not tools else {f.__name__: f for f in tools}
345
-
346
- # while True:
347
- # returned_tool_call = False
348
- # async for chunk in self.llm.generate_streaming(
349
- # prompt=prompt,
350
- # tools=tools, # type: ignore
351
- # options=llm_options,
352
- # ):
353
- # yield chunk
354
-
355
- # if isinstance(chunk, ToolCall):
356
- # if chunk.type != "function":
357
- # raise AgentToolNotSupportedError(chunk.type)
358
-
359
- # if chunk.name not in tools_mapping:
360
- # raise AgentToolNotAvailableError(chunk.name)
361
-
362
- # tool = tools_mapping[chunk.name]
363
- # tool_output = (
364
- # await tool(**chunk.arguments) if iscoroutinefunction(tool) else tool(**chunk.arguments)
365
- # )
366
-
367
- # prompt = prompt.add_tool_use_message(
368
- # tool_call_id=chunk.id,
369
- # tool_name=chunk.name,
370
- # tool_arguments=chunk.arguments,
371
- # tool_call_result=tool_output,
372
- # )
373
- # returned_tool_call = True
374
-
375
- # if not returned_tool_call:
376
- # break
0 commit comments