Skip to content

Commit de5a382

Browse files
maxpillmhordynski
authored andcommitted
feat(agents): AgentDependencies in AgentRunContext #780 (#781)
1 parent 81b96bb commit de5a382

File tree

7 files changed

+175
-12
lines changed

7 files changed

+175
-12
lines changed

docs/how-to/agents/define_and_use_agents.md

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,41 @@ In this scenario, the agent recognizes that the follow-up question "What about T
8484
AgentResult(content='The current temperature in Tokyo is 10°C.', ...)
8585
```
8686

87+
## Binding dependencies via AgentRunContext
88+
You can bind your external dependencies before the first access and safely use them in tools. After first attribute lookup, the dependencies container freezes to prevent mutation during a run.
89+
90+
```python
91+
from dataclasses import dataclass
92+
from ragbits.agents import Agent, AgentRunContext
93+
from ragbits.core.llms.mock import MockLLM, MockLLMOptions
94+
95+
@dataclass
96+
class Deps:
97+
api_host: str
98+
99+
def get_api_host(context: AgentRunContext | None) -> str:
100+
"""Return the API host taken from the bound dependencies in context."""
101+
assert context is not None
102+
return context.deps.api_host
103+
104+
async def main() -> None:
105+
llm = MockLLM(
106+
default_options=MockLLMOptions(
107+
response="Using dependencies from context.",
108+
tool_calls=[{"name": "get_api_host", "arguments": "{}", "id": "example", "type": "function"}],
109+
)
110+
)
111+
agent = Agent(llm=llm, prompt="Retrieve API host", tools=[get_api_host])
112+
113+
context = AgentRunContext()
114+
context.deps.value = Deps(api_host="https://api.local")
115+
116+
result = await agent.run("What host are we using?", context=context)
117+
print(result.tool_calls[0].result)
118+
```
119+
120+
See the runnable example in `examples/agents/dependencies.py`.
121+
87122
## Streaming agent responses
88123
For use cases where you want to process partial outputs from the LLM as they arrive (e.g., in chat UIs), the [`Agent`][ragbits.agents.Agent] class supports streaming through the `run_streaming()` method.
89124

examples/agents/dependencies.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
from dataclasses import dataclass
2+
3+
from ragbits.agents import Agent, AgentRunContext
4+
from ragbits.core.llms.mock import MockLLM, MockLLMOptions
5+
6+
7+
@dataclass
8+
class Deps:
9+
"""Simple dependency container used by the agent tools."""
10+
11+
api_host: str
12+
13+
14+
def get_api_host(context: AgentRunContext | None) -> str:
15+
"""Return the API host taken from the bound dependencies in context."""
16+
if context is None:
17+
raise ValueError("Missing AgentRunContext")
18+
return context.deps.api_host
19+
20+
21+
async def main() -> None:
22+
"""Bind dependencies, run the agent, and access them inside a tool."""
23+
llm = MockLLM(
24+
default_options=MockLLMOptions(
25+
response="Using dependencies from context.",
26+
tool_calls=[
27+
{
28+
"name": "get_api_host",
29+
"arguments": "{}",
30+
"id": "example",
31+
"type": "function",
32+
}
33+
],
34+
)
35+
)
36+
37+
agent = Agent(llm=llm, prompt="Retrieve API host", tools=[get_api_host])
38+
39+
context = AgentRunContext()
40+
context.deps.value = Deps(api_host="https://api.local")
41+
42+
result = await agent.run("What host are we using?", context=context)
43+
print(result.tool_calls[0].result)
44+
45+
46+
if __name__ == "__main__":
47+
import asyncio
48+
49+
asyncio.run(main())

packages/ragbits-agents/CHANGELOG.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33
## Unreleased
44
- Add tool_choice parameter to agent interface (#738)
55
- Add PydanticAI agents support (#755)
6-
- Add unique ID to each agent instance for better tracking and identification
7-
- Add audit tracing for tool calls to improve observability and debugging
6+
- Add unique ID to each agent instance for better tracking and identification (#790)
7+
- Add audit tracing for tool calls to improve observability and debugging (#790)
8+
- Add AgentDependencies type (#781)
89

910
## 1.2.2 (2025-08-08)
1011

packages/ragbits-agents/src/ragbits/agents/__init__.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,17 @@
1-
from ragbits.agents._main import Agent, AgentOptions, AgentResult, AgentResultStreaming, AgentRunContext, ToolCallResult
1+
from ragbits.agents._main import (
2+
Agent,
3+
AgentDependencies,
4+
AgentOptions,
5+
AgentResult,
6+
AgentResultStreaming,
7+
AgentRunContext,
8+
ToolCallResult,
9+
)
210
from ragbits.agents.types import QuestionAnswerAgent, QuestionAnswerPromptInput, QuestionAnswerPromptOutput
311

412
__all__ = [
513
"Agent",
14+
"AgentDependencies",
615
"AgentOptions",
716
"AgentResult",
817
"AgentResultStreaming",

packages/ragbits-agents/src/ragbits/agents/_main.py

Lines changed: 62 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from datetime import timedelta
88
from inspect import iscoroutinefunction
99
from types import ModuleType, SimpleNamespace
10-
from typing import ClassVar, Generic, cast, overload
10+
from typing import ClassVar, Generic, TypeVar, cast, overload
1111

1212
from pydantic import BaseModel, Field
1313
from typing_extensions import Self
@@ -82,11 +82,70 @@ class AgentOptions(Options, Generic[LLMClientOptionsT]):
8282
or None, agent will run forever"""
8383

8484

85-
class AgentRunContext(BaseModel):
85+
DepsT = TypeVar("DepsT")
86+
87+
88+
class AgentDependencies(BaseModel, Generic[DepsT]):
8689
"""
87-
Context for the agent run.
90+
Container for agent runtime dependencies.
91+
92+
Becomes immutable after first attribute access.
8893
"""
8994

95+
model_config = {"arbitrary_types_allowed": True}
96+
97+
_frozen: bool
98+
_value: DepsT | None
99+
100+
def __init__(self, value: DepsT | None = None) -> None:
101+
super().__init__()
102+
self._value = value
103+
self._frozen = False
104+
105+
def __setattr__(self, name: str, value: object) -> None:
106+
is_frozen = False
107+
if name != "_frozen":
108+
try:
109+
is_frozen = object.__getattribute__(self, "_frozen")
110+
except AttributeError:
111+
is_frozen = False
112+
113+
if is_frozen and name not in {"_frozen"}:
114+
raise RuntimeError("Dependencies are immutable after first access")
115+
116+
super().__setattr__(name, value)
117+
118+
@property
119+
def value(self) -> DepsT | None:
120+
return self._value
121+
122+
@value.setter
123+
def value(self, value: DepsT) -> None:
124+
if self._frozen:
125+
raise RuntimeError("Dependencies are immutable after first access")
126+
self._value = value
127+
128+
def _freeze(self) -> None:
129+
if not self._frozen:
130+
self._frozen = True
131+
132+
def __getattr__(self, name: str) -> object:
133+
value = object.__getattribute__(self, "_value")
134+
if value is None:
135+
raise AttributeError(name)
136+
self._freeze()
137+
return getattr(value, name)
138+
139+
def __contains__(self, key: str) -> bool:
140+
value = object.__getattribute__(self, "_value")
141+
return hasattr(value, key) if value is not None else False
142+
143+
144+
class AgentRunContext(BaseModel, Generic[DepsT]):
145+
"""Context for the agent run."""
146+
147+
deps: AgentDependencies[DepsT] = Field(default_factory=lambda: AgentDependencies())
148+
"""Container for external dependencies."""
90149
usage: Usage = Field(default_factory=Usage)
91150
"""The usage of the agent."""
92151

packages/ragbits-agents/tests/unit/test_agent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,7 @@ async def test_agent_run_tools_with_context(
325325

326326
@pytest.mark.parametrize("method", [_run, _run_streaming])
327327
async def test_agent_run_context_is_updated(llm_without_tool_call: MockLLM, method: Callable):
328-
context = AgentRunContext()
328+
context: AgentRunContext = AgentRunContext()
329329
agent: Agent = Agent(
330330
llm=llm_without_tool_call,
331331
prompt="NOT IMPORTANT",

packages/ragbits-core/src/ragbits/core/utils/function_schema.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import logging
77
from collections.abc import Callable, Generator
88
from types import UnionType
9-
from typing import Any, get_args, get_origin, get_type_hints
9+
from typing import Annotated, Any, Union, get_args, get_origin, get_type_hints
1010

1111
from griffe import Docstring, DocstringSectionKind
1212
from pydantic import BaseModel, Field, create_model
@@ -201,10 +201,20 @@ def _is_context_variable(ann: Any) -> bool: # noqa: ANN401
201201
"""
202202
Check if a type annotation is AgentRunContext.
203203
"""
204-
if hasattr(ann, "__name__") and ann.__name__ == "AgentRunContext":
204+
if ann is None:
205+
return False
206+
207+
origin = get_origin(ann)
208+
if origin is not None:
209+
if getattr(origin, "__name__", None) == "AgentRunContext":
210+
return True
211+
if origin in (Union, UnionType):
212+
return any(_is_context_variable(arg) for arg in get_args(ann))
213+
if origin is Annotated:
214+
args = get_args(ann)
215+
return bool(args) and _is_context_variable(args[0])
216+
217+
if getattr(ann, "__name__", None) == "AgentRunContext":
205218
return True
206219

207-
if get_origin(ann) is UnionType:
208-
return any(_is_context_variable(arg) for arg in get_args(ann))
209-
210220
return isinstance(ann, str) and "AgentRunContext" in ann

0 commit comments

Comments
 (0)