11from __future__ import annotations
2- from typing import Callable
3- from collections .abc import Iterator , Sequence
2+
3+ from collections .abc import Callable , Iterator , Sequence
44from contextlib import contextmanager
5- from typing import Any , Never , overload
6- from typing_extensions import Generic
5+ from typing import Any , Generic , Never , overload
6+
7+ from restate import Context , TerminalError
78
89from pydantic_ai import models
910from pydantic_ai ._run_context import AgentDepsT
1920from pydantic_ai .toolsets .function import FunctionToolset
2021from pydantic_ai .usage import RunUsage , UsageLimits
2122
22- from restate import Context , TerminalError
23-
2423from ._model import RestateModelWrapper
2524from ._toolset import RestateContextRunToolSet
2625
27- class RestateAgentProvider (Generic [AgentDepsT , OutputDataT ]):
2826
29- def __init__ (self ,
30- wrapped : AbstractAgent [AgentDepsT , OutputDataT ],
31- * ,
32- max_attempts : int = 3 ):
27+ class RestateAgentProvider (Generic [AgentDepsT , OutputDataT ]):
28+ def __init__ (self , wrapped : AbstractAgent [AgentDepsT , OutputDataT ], * , max_attempts : int = 3 ):
3329 if not isinstance (wrapped .model , Model ):
3430 raise TerminalError (
3531 'An agent needs to have a `model` in order to be used with Restate, it cannot be set at agent run time.'
@@ -41,19 +37,18 @@ def __init__(self,
4137 self .max_attempts = max_attempts
4238
4339 def create_agent (self , context : Context ) -> AbstractAgent [AgentDepsT , OutputDataT ]:
44- """
45- Create an agent instance with the given Restate context.
46-
40+ """Create an agent instance with the given Restate context.
41+
4742 Use this method to create an agent that is tied to a specific Restate context.
4843 With this agent, all operations will be executed within the provided context,
4944 enabling features like retries and durable steps.
5045 Note that the agent will automatically wrap tool calls with restate's `ctx.run()`.
51-
46+
5247 Example:
5348 ```python
5449 ...
5550 agent_provider = RestateAgentProvider(weather_agent)
56-
51+
5752 weather = restate.Service('weather')
5853
5954 @weather.handler()
@@ -63,28 +58,33 @@ async def get_weather(ctx: restate.Context, city: str):
6358 return result.output
6459 ...
6560 ```
66-
61+
6762 Args:
6863 context: The Restate context to use for the agent.
6964 auto_wrap_tool_calls: Whether to automatically wrap tool calls with restate's ctx.run() (durable step), True by default.
65+
7066 Returns:
7167 A RestateAgent instance that uses the provided context for its operations.
7268 """
73- get_context : Callable [[AgentDepsT ], Context ] = lambda _unused : context
69+
70+ def get_context (_unused : AgentDepsT ) -> Context :
71+ return context
72+
7473 builder = self
7574 return RestateAgent (builder = builder , get_context = get_context , auto_wrap_tools = True )
7675
77- def create_agent_with_advanced_tools (self , get_context : Callable [[AgentDepsT ], Context ]) -> AbstractAgent [AgentDepsT , OutputDataT ]:
78- """
79- Create an agent instance that is able to obtain Restate context from its dependencies.
80-
76+ def create_agent_with_advanced_tools (
77+ self , get_context : Callable [[AgentDepsT ], Context ]
78+ ) -> AbstractAgent [AgentDepsT , OutputDataT ]:
79+ """Create an agent instance that is able to obtain Restate context from its dependencies.
80+
8181 Use this method, if you are planning to use restate's context inside the tools (for rpc, timers, multi step tools etc.)
8282 To obtain a context inside a tool you can add a dependency that has a `restate_context` attribute, and provide a `get_context` extractor
8383 function that extracts the context from the dependencies at runtime.
8484
8585 Note: that the agent will NOT automatically wrap tool calls with restate's `ctx.run()`
8686 since the tools may use the context in different ways.
87-
87+
8888 Example:
8989 ```python
9090 ...
@@ -93,9 +93,9 @@ def create_agent_with_advanced_tools(self, get_context: Callable[[AgentDepsT], C
9393 WeatherDeps:
9494 ...
9595 restate_context: Context
96-
97- weather_agent = Agent(..., deps_type=WeatherDeps, ...)
98-
96+
97+ weather_agent = Agent(..., deps_type=WeatherDeps, ...)
98+
9999 @weather_agent.tool
100100 async def get_lat_lng(ctx: RunContext[WeatherDeps], location_description: str) -> LatLng:
101101 restate_context = ctx.deps.restate_context
@@ -104,7 +104,7 @@ async def get_lat_lng(ctx: RunContext[WeatherDeps], location_description: str) -
104104 return LatLng(lat, lng)
105105
106106 agent = RestateAgentProvider(weather_agent).create_agent_from_deps(lambda deps: deps.restate_context)
107-
107+
108108 weather = restate.Service('weather')
109109
110110 @weather.handler()
@@ -113,19 +113,21 @@ async def get_weather(ctx: restate.Context, city: str):
113113 return result.output
114114 ...
115115 ```
116-
116+
117117 Args:
118118 get_context: A callable that extracts the Restate context from the agent's dependencies.
119+
119120 Returns:
120121 A RestateAgent instance that uses the provided context extractor to obtain the Restate context at runtime.
121-
122+
122123 """
123124 builder = self
124125 return RestateAgent (builder = builder , get_context = get_context , auto_wrap_tools = False )
125-
126+
126127
127128class RestateAgent (WrapperAgent [AgentDepsT , OutputDataT ]):
128129 """An agent that integrates with the Restate framework for resilient applications."""
130+
129131 def __init__ (
130132 self ,
131133 builder : RestateAgentProvider [AgentDepsT , OutputDataT ],
@@ -147,7 +149,8 @@ def set_context(toolset: AbstractToolset[AgentDepsT]) -> AbstractToolset[AgentDe
147149 return RestateContextRunToolSet (toolset , context )
148150 try :
149151 from pydantic_ai .mcp import MCPServer
150- from ._toolset import RestateMCPServer
152+
153+ from ._toolset import RestateMCPServer
151154 except ImportError :
152155 pass
153156 else :
@@ -163,7 +166,7 @@ def set_context(toolset: AbstractToolset[AgentDepsT]) -> AbstractToolset[AgentDe
163166 self .sequential_tool_calls (),
164167 ):
165168 yield
166-
169+
167170 @overload
168171 async def run (
169172 self ,
@@ -253,8 +256,10 @@ async def main():
253256 The result of the run.
254257 """
255258 if model is not None :
256- raise TerminalError ('An agent needs to have a `model` in order to be used with Restate, it cannot be set at agent run time.' )
257- context = self ._get_context (deps )
259+ raise TerminalError (
260+ 'An agent needs to have a `model` in order to be used with Restate, it cannot be set at agent run time.'
261+ )
262+ context = self ._get_context (deps )
258263 with self ._restate_overrides (context ):
259264 return await super (WrapperAgent , self ).run (
260265 user_prompt = user_prompt ,
@@ -270,9 +275,3 @@ async def main():
270275 toolsets = toolsets ,
271276 event_stream_handler = event_stream_handler ,
272277 )
273-
274-
275-
276-
277-
278-
0 commit comments