|
4 | 4 | from contextlib import contextmanager |
5 | 5 | from typing import Any, Never, overload |
6 | 6 | from typing_extensions import Generic |
7 | | -from dataclasses import dataclass |
8 | 7 |
|
9 | 8 | from pydantic_ai import models |
10 | 9 | from pydantic_ai._run_context import AgentDepsT |
@@ -38,6 +37,7 @@ def __init__(self, |
38 | 37 | # here we collect all the configuration that will be passed to the RestateAgent |
39 | 38 | # the actual context will be provided at runtime. |
40 | 39 | self.wrapped = wrapped |
| 40 | + self.model = wrapped.model |
41 | 41 | self.max_attempts = max_attempts |
42 | 42 |
|
43 | 43 | def create_agent(self, context: Context) -> AbstractAgent[AgentDepsT, OutputDataT]: |
@@ -70,7 +70,7 @@ async def get_weather(ctx: restate.Context, city: str): |
70 | 70 | Returns: |
71 | 71 | A RestateAgent instance that uses the provided context for its operations. |
72 | 72 | """ |
73 | | - get_context = lambda _unused: context |
| 73 | + get_context: Callable[[AgentDepsT], Context] = lambda _unused: context |
74 | 74 | builder = self |
75 | 75 | return RestateAgent(builder=builder, get_context=get_context, auto_wrap_tools=True) |
76 | 76 |
|
@@ -139,7 +139,7 @@ def __init__( |
139 | 139 |
|
140 | 140 | @contextmanager |
141 | 141 | def _restate_overrides(self, context: Context) -> Iterator[None]: |
142 | | - model = RestateModelWrapper(self._builder.wrapped.model, context, max_attempts=self._builder.max_attempts) |
| 142 | + model = RestateModelWrapper(self._builder.model, context, max_attempts=self._builder.max_attempts) |
143 | 143 |
|
144 | 144 | def set_context(toolset: AbstractToolset[AgentDepsT]) -> AbstractToolset[AgentDepsT]: |
145 | 145 | """Set the Restate context for the toolset, wrapping tools if needed.""" |
|
0 commit comments