Skip to content

Commit cea3363

Browse files
committed
Apply formatting and lint
1 parent 3222a24 commit cea3363

File tree

5 files changed

+86
-92
lines changed

5 files changed

+86
-92
lines changed
Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
1-
from ._model import RestateModelWrapper
21
from ._agent import RestateAgent, RestateAgentProvider
2+
from ._model import RestateModelWrapper
33
from ._serde import PydanticTypeAdapter
44
from ._toolset import RestateContextRunToolSet
55

6-
__all__ = ['RestateModelWrapper', 'RestateAgent', 'RestateAgentProvider', 'PydanticTypeAdapter', 'RestateContextRunToolSet']
6+
__all__ = [
7+
'RestateModelWrapper',
8+
'RestateAgent',
9+
'RestateAgentProvider',
10+
'PydanticTypeAdapter',
11+
'RestateContextRunToolSet',
12+
]

pydantic_ai_slim/pydantic_ai/durable_exec/restate/_agent.py

Lines changed: 39 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
from __future__ import annotations
2-
from typing import Callable
3-
from collections.abc import Iterator, Sequence
2+
3+
from collections.abc import Callable, Iterator, Sequence
44
from contextlib import contextmanager
5-
from typing import Any, Never, overload
6-
from typing_extensions import Generic
5+
from typing import Any, Generic, Never, overload
6+
7+
from restate import Context, TerminalError
78

89
from pydantic_ai import models
910
from pydantic_ai._run_context import AgentDepsT
@@ -19,17 +20,12 @@
1920
from pydantic_ai.toolsets.function import FunctionToolset
2021
from pydantic_ai.usage import RunUsage, UsageLimits
2122

22-
from restate import Context, TerminalError
23-
2423
from ._model import RestateModelWrapper
2524
from ._toolset import RestateContextRunToolSet
2625

27-
class RestateAgentProvider(Generic[AgentDepsT, OutputDataT]):
2826

29-
def __init__(self,
30-
wrapped: AbstractAgent[AgentDepsT, OutputDataT],
31-
*,
32-
max_attempts: int = 3):
27+
class RestateAgentProvider(Generic[AgentDepsT, OutputDataT]):
28+
def __init__(self, wrapped: AbstractAgent[AgentDepsT, OutputDataT], *, max_attempts: int = 3):
3329
if not isinstance(wrapped.model, Model):
3430
raise TerminalError(
3531
'An agent needs to have a `model` in order to be used with Restate, it cannot be set at agent run time.'
@@ -41,19 +37,18 @@ def __init__(self,
4137
self.max_attempts = max_attempts
4238

4339
def create_agent(self, context: Context) -> AbstractAgent[AgentDepsT, OutputDataT]:
44-
"""
45-
Create an agent instance with the given Restate context.
46-
40+
"""Create an agent instance with the given Restate context.
41+
4742
Use this method to create an agent that is tied to a specific Restate context.
4843
With this agent, all operations will be executed within the provided context,
4944
enabling features like retries and durable steps.
5045
Note that the agent will automatically wrap tool calls with restate's `ctx.run()`.
51-
46+
5247
Example:
5348
```python
5449
...
5550
agent_provider = RestateAgentProvider(weather_agent)
56-
51+
5752
weather = restate.Service('weather')
5853
5954
@weather.handler()
@@ -63,28 +58,33 @@ async def get_weather(ctx: restate.Context, city: str):
6358
return result.output
6459
...
6560
```
66-
61+
6762
Args:
6863
context: The Restate context to use for the agent.
6964
auto_wrap_tool_calls: Whether to automatically wrap tool calls with restate's ctx.run() (durable step), True by default.
65+
7066
Returns:
7167
A RestateAgent instance that uses the provided context for its operations.
7268
"""
73-
get_context: Callable[[AgentDepsT], Context] = lambda _unused: context
69+
70+
def get_context(_unused: AgentDepsT) -> Context:
71+
return context
72+
7473
builder = self
7574
return RestateAgent(builder=builder, get_context=get_context, auto_wrap_tools=True)
7675

77-
def create_agent_with_advanced_tools(self, get_context: Callable[[AgentDepsT], Context]) -> AbstractAgent[AgentDepsT, OutputDataT]:
78-
"""
79-
Create an agent instance that is able to obtain Restate context from its dependencies.
80-
76+
def create_agent_with_advanced_tools(
77+
self, get_context: Callable[[AgentDepsT], Context]
78+
) -> AbstractAgent[AgentDepsT, OutputDataT]:
79+
"""Create an agent instance that is able to obtain Restate context from its dependencies.
80+
8181
Use this method, if you are planning to use restate's context inside the tools (for rpc, timers, multi step tools etc.)
8282
To obtain a context inside a tool you can add a dependency that has a `restate_context` attribute, and provide a `get_context` extractor
8383
function that extracts the context from the dependencies at runtime.
8484
8585
Note: that the agent will NOT automatically wrap tool calls with restate's `ctx.run()`
8686
since the tools may use the context in different ways.
87-
87+
8888
Example:
8989
```python
9090
...
@@ -93,9 +93,9 @@ def create_agent_with_advanced_tools(self, get_context: Callable[[AgentDepsT], C
9393
WeatherDeps:
9494
...
9595
restate_context: Context
96-
97-
weather_agent = Agent(..., deps_type=WeatherDeps, ...)
98-
96+
97+
weather_agent = Agent(..., deps_type=WeatherDeps, ...)
98+
9999
@weather_agent.tool
100100
async def get_lat_lng(ctx: RunContext[WeatherDeps], location_description: str) -> LatLng:
101101
restate_context = ctx.deps.restate_context
@@ -104,7 +104,7 @@ async def get_lat_lng(ctx: RunContext[WeatherDeps], location_description: str) -
104104
return LatLng(lat, lng)
105105
106106
agent = RestateAgentProvider(weather_agent).create_agent_from_deps(lambda deps: deps.restate_context)
107-
107+
108108
weather = restate.Service('weather')
109109
110110
@weather.handler()
@@ -113,19 +113,21 @@ async def get_weather(ctx: restate.Context, city: str):
113113
return result.output
114114
...
115115
```
116-
116+
117117
Args:
118118
get_context: A callable that extracts the Restate context from the agent's dependencies.
119+
119120
Returns:
120121
A RestateAgent instance that uses the provided context extractor to obtain the Restate context at runtime.
121-
122+
122123
"""
123124
builder = self
124125
return RestateAgent(builder=builder, get_context=get_context, auto_wrap_tools=False)
125-
126+
126127

127128
class RestateAgent(WrapperAgent[AgentDepsT, OutputDataT]):
128129
"""An agent that integrates with the Restate framework for resilient applications."""
130+
129131
def __init__(
130132
self,
131133
builder: RestateAgentProvider[AgentDepsT, OutputDataT],
@@ -147,7 +149,8 @@ def set_context(toolset: AbstractToolset[AgentDepsT]) -> AbstractToolset[AgentDe
147149
return RestateContextRunToolSet(toolset, context)
148150
try:
149151
from pydantic_ai.mcp import MCPServer
150-
from ._toolset import RestateMCPServer
152+
153+
from ._toolset import RestateMCPServer
151154
except ImportError:
152155
pass
153156
else:
@@ -163,7 +166,7 @@ def set_context(toolset: AbstractToolset[AgentDepsT]) -> AbstractToolset[AgentDe
163166
self.sequential_tool_calls(),
164167
):
165168
yield
166-
169+
167170
@overload
168171
async def run(
169172
self,
@@ -253,8 +256,10 @@ async def main():
253256
The result of the run.
254257
"""
255258
if model is not None:
256-
raise TerminalError('An agent needs to have a `model` in order to be used with Restate, it cannot be set at agent run time.')
257-
context = self._get_context(deps)
259+
raise TerminalError(
260+
'An agent needs to have a `model` in order to be used with Restate, it cannot be set at agent run time.'
261+
)
262+
context = self._get_context(deps)
258263
with self._restate_overrides(context):
259264
return await super(WrapperAgent, self).run(
260265
user_prompt=user_prompt,
@@ -270,9 +275,3 @@ async def main():
270275
toolsets=toolsets,
271276
event_stream_handler=event_stream_handler,
272277
)
273-
274-
275-
276-
277-
278-
Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,18 @@
1-
from typing import Any, Optional
1+
from typing import Any
2+
3+
from restate import Context, RunOptions
24

35
from pydantic_ai.durable_exec.restate._serde import PydanticTypeAdapter
46
from pydantic_ai.messages import ModelResponse
57
from pydantic_ai.models import Model
68
from pydantic_ai.models.wrapper import WrapperModel
79

8-
from restate import Context, RunOptions
910

1011
class RestateModelWrapper(WrapperModel):
11-
12-
def __init__(self,
13-
wrapped: Model,
14-
context: Context,
15-
max_attempts: Optional[int] = None):
12+
def __init__(self, wrapped: Model, context: Context, max_attempts: int | None = None):
1613
super().__init__(wrapped)
1714
self.options = RunOptions[ModelResponse](serde=PydanticTypeAdapter(ModelResponse), max_attempts=max_attempts)
1815
self.context = context
19-
16+
2017
async def request(self, *args: Any, **kwargs: Any) -> ModelResponse:
21-
return await self.context.run_typed("Model call", self.wrapped.request, self.options, *args, **kwargs)
22-
18+
return await self.context.run_typed('Model call', self.wrapped.request, self.options, *args, **kwargs)
Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,24 @@
1-
21
import typing
3-
from restate.serde import Serde
2+
43
from pydantic import TypeAdapter
4+
from restate.serde import Serde
55

66
T = typing.TypeVar('T')
77

8+
89
class PydanticTypeAdapter(Serde[T]):
910
"""A serializer/deserializer for Pydantic models."""
1011

11-
def __init__(self, model_type: typing.Type[T]):
12-
"""
13-
Initializes a new instance of the PydanticTypeAdaptorSerde class.
12+
def __init__(self, model_type: type[T]):
13+
"""Initializes a new instance of the PydanticTypeAdaptorSerde class.
1414
1515
Args:
1616
model_type (typing.Type[T]): The Pydantic model type to serialize/deserialize.
1717
"""
1818
self._model_type = TypeAdapter(model_type)
1919

20-
def deserialize(self, buf: bytes) -> typing.Optional[T]:
21-
"""
22-
Deserializes a bytearray to a Pydantic model.
20+
def deserialize(self, buf: bytes) -> T | None:
21+
"""Deserializes a bytearray to a Pydantic model.
2322
2423
Args:
2524
buf (bytearray): The bytearray to deserialize.
@@ -31,9 +30,8 @@ def deserialize(self, buf: bytes) -> typing.Optional[T]:
3130
return None
3231
return self._model_type.validate_json(buf.decode('utf-8')) # raises if invalid
3332

34-
def serialize(self, obj: typing.Optional[T]) -> bytes:
35-
"""
36-
Serializes a Pydantic model to a bytearray.
33+
def serialize(self, obj: T | None) -> bytes:
34+
"""Serializes a Pydantic model to a bytearray.
3735
3836
Args:
3937
obj (typing.Optional[T]): The Pydantic model to serialize.
@@ -45,5 +43,3 @@ def serialize(self, obj: typing.Optional[T]) -> bytes:
4543
return b''
4644
tpe = TypeAdapter(type(obj))
4745
return tpe.dump_json(obj)
48-
49-

0 commit comments

Comments
 (0)