Skip to content

Commit e2f1b21

Browse files
committed
Initial support for restate
1 parent 79ef2bf commit e2f1b21

File tree

6 files changed

+470
-0
lines changed

6 files changed

+470
-0
lines changed
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from ._model import RestateModelWrapper
2+
from ._agent import RestateAgent, RestateAgentProvider
3+
from ._serde import PydanticTypeAdapter
4+
from ._toolset import RestateContextRunToolSet
5+
6+
__all__ = ['RestateModelWrapper', 'RestateAgent', 'RestateAgentProvider', 'PydanticTypeAdapter', 'RestateContextRunToolSet']
Lines changed: 278 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,278 @@
1+
from __future__ import annotations
2+
from typing import Callable
3+
from collections.abc import Iterator, Sequence
4+
from contextlib import contextmanager
5+
from typing import Any, Never, overload
6+
from typing_extensions import Generic
7+
from dataclasses import dataclass
8+
9+
from pydantic_ai import models
10+
from pydantic_ai._run_context import AgentDepsT
11+
from pydantic_ai.agent.abstract import AbstractAgent, EventStreamHandler, RunOutputDataT
12+
from pydantic_ai.agent.wrapper import WrapperAgent
13+
from pydantic_ai.messages import ModelMessage, UserContent
14+
from pydantic_ai.models import Model
15+
from pydantic_ai.output import OutputDataT, OutputSpec
16+
from pydantic_ai.run import AgentRunResult
17+
from pydantic_ai.settings import ModelSettings
18+
from pydantic_ai.tools import DeferredToolResults
19+
from pydantic_ai.toolsets.abstract import AbstractToolset
20+
from pydantic_ai.toolsets.function import FunctionToolset
21+
from pydantic_ai.usage import RunUsage, UsageLimits
22+
23+
from restate import Context, TerminalError
24+
25+
from ._model import RestateModelWrapper
26+
from ._toolset import RestateContextRunToolSet
27+
28+
class RestateAgentProvider(Generic[AgentDepsT, OutputDataT]):
29+
30+
def __init__(self,
31+
wrapped: AbstractAgent[AgentDepsT, OutputDataT],
32+
*,
33+
max_attempts: int = 3):
34+
if not isinstance(wrapped.model, Model):
35+
raise TerminalError(
36+
'An agent needs to have a `model` in order to be used with Restate, it cannot be set at agent run time.'
37+
)
38+
# here we collect all the configuration that will be passed to the RestateAgent
39+
# the actual context will be provided at runtime.
40+
self.wrapped = wrapped
41+
self.max_attempts = max_attempts
42+
43+
def create_agent(self, context: Context) -> AbstractAgent[AgentDepsT, OutputDataT]:
44+
"""
45+
Create an agent instance with the given Restate context.
46+
47+
Use this method to create an agent that is tied to a specific Restate context.
48+
With this agent, all operations will be executed within the provided context,
49+
enabling features like retries and durable steps.
50+
Note that the agent will automatically wrap tool calls with restate's `ctx.run()`.
51+
52+
Example:
53+
```python
54+
...
55+
agent_provider = RestateAgentProvider(weather_agent)
56+
57+
weather = restate.Service('weather')
58+
59+
@weather.handler()
60+
async def get_weather(ctx: restate.Context, city: str):
61+
agent = agent_provider.create_agent_from_context(ctx)
62+
result = await agent.run(f'What is the weather in {city}?')
63+
return result.output
64+
...
65+
```
66+
67+
Args:
68+
context: The Restate context to use for the agent.
69+
auto_wrap_tool_calls: Whether to automatically wrap tool calls with restate's ctx.run() (durable step), True by default.
70+
Returns:
71+
A RestateAgent instance that uses the provided context for its operations.
72+
"""
73+
get_context = lambda _unused: context
74+
builder = self
75+
return RestateAgent(builder=builder, get_context=get_context, auto_wrap_tools=True)
76+
77+
def create_agent_with_advanced_tools(self, get_context: Callable[[AgentDepsT], Context]) -> AbstractAgent[AgentDepsT, OutputDataT]:
78+
"""
79+
Create an agent instance that is able to obtain Restate context from its dependencies.
80+
81+
Use this method, if you are planning to use restate's context inside the tools (for rpc, timers, multi step tools etc.)
82+
To obtain a context inside a tool you can add a dependency that has a `restate_context` attribute, and provide a `get_context` extractor
83+
function that extracts the context from the dependencies at runtime.
84+
85+
Note: that the agent will NOT automatically wrap tool calls with restate's `ctx.run()`
86+
since the tools may use the context in different ways.
87+
88+
Example:
89+
```python
90+
...
91+
92+
@dataclass
93+
WeatherDeps:
94+
...
95+
restate_context: Context
96+
97+
weather_agent = Agent(..., deps_type=WeatherDeps, ...)
98+
99+
@weather_agent.tool
100+
async def get_lat_lng(ctx: RunContext[WeatherDeps], location_description: str) -> LatLng:
101+
restate_context = ctx.deps.restate_context
102+
lat = await restate_context.run(...) # <---- note the direct usage of the restate context
103+
lng = await restate_context.run(...)
104+
return LatLng(lat, lng)
105+
106+
agent = RestateAgentProvider(weather_agent).create_agent_from_deps(lambda deps: deps.restate_context)
107+
108+
weather = restate.Service('weather')
109+
110+
@weather.handler()
111+
async def get_weather(ctx: restate.Context, city: str):
112+
result = await agent.run(f'What is the weather in {city}?', deps=WeatherDeps(restate_context=ctx, ...))
113+
return result.output
114+
...
115+
```
116+
117+
Args:
118+
get_context: A callable that extracts the Restate context from the agent's dependencies.
119+
Returns:
120+
A RestateAgent instance that uses the provided context extractor to obtain the Restate context at runtime.
121+
122+
"""
123+
builder = self
124+
return RestateAgent(builder=builder, get_context=get_context, auto_wrap_tools=False)
125+
126+
127+
class RestateAgent(WrapperAgent[AgentDepsT, OutputDataT]):
128+
"""An agent that integrates with the Restate framework for resilient applications."""
129+
def __init__(
130+
self,
131+
builder: RestateAgentProvider[AgentDepsT, OutputDataT],
132+
get_context: Callable[[AgentDepsT], Context],
133+
auto_wrap_tools: bool,
134+
):
135+
super().__init__(builder.wrapped)
136+
self._builder = builder
137+
self._get_context = get_context
138+
self._auto_wrap_tools = auto_wrap_tools
139+
140+
@contextmanager
141+
def _restate_overrides(self, context: Context) -> Iterator[None]:
142+
model = RestateModelWrapper(self._builder.wrapped.model, context, max_attempts=self._builder.max_attempts)
143+
144+
def set_context(toolset: AbstractToolset[AgentDepsT]) -> AbstractToolset[AgentDepsT]:
145+
"""Set the Restate context for the toolset, wrapping tools if needed."""
146+
if isinstance(toolset, FunctionToolset) and self._auto_wrap_tools:
147+
return RestateContextRunToolSet(toolset, context)
148+
try:
149+
from pydantic_ai.mcp import MCPServer
150+
from ._toolset import RestateMCPServer
151+
except ImportError:
152+
pass
153+
else:
154+
if isinstance(toolset, MCPServer):
155+
return RestateMCPServer(toolset, context)
156+
157+
return toolset
158+
159+
toolsets = [toolset.visit_and_replace(set_context) for toolset in self._builder.wrapped.toolsets]
160+
161+
with (
162+
super().override(model=model, toolsets=toolsets, tools=[]),
163+
self.sequential_tool_calls(),
164+
):
165+
yield
166+
167+
@overload
168+
async def run(
169+
self,
170+
user_prompt: str | Sequence[UserContent] | None = None,
171+
*,
172+
output_type: None = None,
173+
message_history: list[ModelMessage] | None = None,
174+
deferred_tool_results: DeferredToolResults | None = None,
175+
model: models.Model | models.KnownModelName | str | None = None,
176+
deps: AgentDepsT = None,
177+
model_settings: ModelSettings | None = None,
178+
usage_limits: UsageLimits | None = None,
179+
usage: RunUsage | None = None,
180+
infer_name: bool = True,
181+
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
182+
event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
183+
) -> AgentRunResult[OutputDataT]: ...
184+
185+
@overload
186+
async def run(
187+
self,
188+
user_prompt: str | Sequence[UserContent] | None = None,
189+
*,
190+
output_type: OutputSpec[RunOutputDataT],
191+
message_history: list[ModelMessage] | None = None,
192+
deferred_tool_results: DeferredToolResults | None = None,
193+
model: models.Model | models.KnownModelName | str | None = None,
194+
deps: AgentDepsT = None,
195+
model_settings: ModelSettings | None = None,
196+
usage_limits: UsageLimits | None = None,
197+
usage: RunUsage | None = None,
198+
infer_name: bool = True,
199+
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
200+
event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
201+
) -> AgentRunResult[RunOutputDataT]: ...
202+
203+
async def run(
204+
self,
205+
user_prompt: str | Sequence[UserContent] | None = None,
206+
*,
207+
output_type: OutputSpec[RunOutputDataT] | None = None,
208+
message_history: list[ModelMessage] | None = None,
209+
deferred_tool_results: DeferredToolResults | None = None,
210+
model: models.Model | models.KnownModelName | str | None = None,
211+
deps: AgentDepsT = None,
212+
model_settings: ModelSettings | None = None,
213+
usage_limits: UsageLimits | None = None,
214+
usage: RunUsage | None = None,
215+
infer_name: bool = True,
216+
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
217+
event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
218+
**_deprecated_kwargs: Never,
219+
) -> AgentRunResult[Any]:
220+
"""Run the agent with a user prompt in async mode.
221+
222+
This method builds an internal agent graph (using system prompts, tools and result schemas) and then
223+
runs the graph to completion. The result of the run is returned.
224+
225+
Example:
226+
```python
227+
from pydantic_ai import Agent
228+
229+
agent = Agent('openai:gpt-4o')
230+
231+
async def main():
232+
agent_run = await agent.run('What is the capital of France?')
233+
print(agent_run.output)
234+
#> The capital of France is Paris.
235+
```
236+
237+
Args:
238+
user_prompt: User input to start/continue the conversation.
239+
output_type: Custom output type to use for this run, `output_type` may only be used if the agent has no
240+
output validators since output validators would expect an argument that matches the agent's output type.
241+
message_history: History of the conversation so far.
242+
deferred_tool_results: Optional results for deferred tool calls in the message history.
243+
model: Optional model to use for this run, required if `model` was not set when creating the agent.
244+
deps: Optional dependencies to use for this run.
245+
model_settings: Optional settings to use for this model's request.
246+
usage_limits: Optional limits on model request count or token usage.
247+
usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
248+
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
249+
toolsets: Optional additional toolsets for this run.
250+
event_stream_handler: Optional event stream handler to use for this run.
251+
252+
Returns:
253+
The result of the run.
254+
"""
255+
if model is not None:
256+
raise TerminalError('An agent needs to have a `model` in order to be used with Restate, it cannot be set at agent run time.')
257+
context = self._get_context(deps)
258+
with self._restate_overrides(context):
259+
return await super(WrapperAgent, self).run(
260+
user_prompt=user_prompt,
261+
output_type=output_type,
262+
message_history=message_history,
263+
deferred_tool_results=deferred_tool_results,
264+
model=model,
265+
deps=deps,
266+
model_settings=model_settings,
267+
usage_limits=usage_limits,
268+
usage=usage,
269+
infer_name=infer_name,
270+
toolsets=toolsets,
271+
event_stream_handler=event_stream_handler,
272+
)
273+
274+
275+
276+
277+
278+
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
from typing import Any, Optional
2+
3+
from pydantic_ai.durable_exec.restate._serde import PydanticTypeAdapter
4+
from pydantic_ai.messages import ModelResponse
5+
from pydantic_ai.models import Model
6+
from pydantic_ai.models.wrapper import WrapperModel
7+
8+
from restate import Context, RunOptions
9+
10+
class RestateModelWrapper(WrapperModel):
11+
12+
def __init__(self,
13+
wrapped: Model,
14+
context: Context,
15+
max_attempts: Optional[int] = None):
16+
super().__init__(wrapped)
17+
self.options = RunOptions[ModelResponse](serde=PydanticTypeAdapter(ModelResponse), max_attempts=max_attempts)
18+
self.context = context
19+
20+
async def request(self, *args: Any, **kwargs: Any) -> ModelResponse:
21+
return await self.context.run_typed("Model call", self.wrapped.request, self.options, *args, **kwargs)
22+
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
2+
import typing
3+
from restate.serde import Serde
4+
from pydantic import TypeAdapter
5+
6+
T = typing.TypeVar('T')
7+
8+
class PydanticTypeAdapter(Serde[T]):
9+
"""A serializer/deserializer for Pydantic models."""
10+
11+
def __init__(self, model_type: typing.Type[T]):
12+
"""
13+
Initializes a new instance of the PydanticTypeAdaptorSerde class.
14+
15+
Args:
16+
model_type (typing.Type[T]): The Pydantic model type to serialize/deserialize.
17+
"""
18+
self._model_type = TypeAdapter(model_type)
19+
20+
def deserialize(self, buf: bytes) -> typing.Optional[T]:
21+
"""
22+
Deserializes a bytearray to a Pydantic model.
23+
24+
Args:
25+
buf (bytearray): The bytearray to deserialize.
26+
27+
Returns:
28+
typing.Optional[T]: The deserialized Pydantic model.
29+
"""
30+
if not buf:
31+
return None
32+
return self._model_type.validate_json(buf.decode('utf-8')) # raises if invalid
33+
34+
def serialize(self, obj: typing.Optional[T]) -> bytes:
35+
"""
36+
Serializes a Pydantic model to a bytearray.
37+
38+
Args:
39+
obj (typing.Optional[T]): The Pydantic model to serialize.
40+
41+
Returns:
42+
bytes: The serialized bytearray.
43+
"""
44+
if obj is None:
45+
return b''
46+
tpe = TypeAdapter(type(obj))
47+
return tpe.dump_json(obj)
48+
49+

0 commit comments

Comments
 (0)