1
1
from __future__ import annotations
2
2
3
- from collections .abc import AsyncIterator , Callable , Iterator , Sequence
3
+ from collections .abc import AsyncIterable , AsyncIterator , Callable , Iterator , Sequence
4
4
from contextlib import AbstractAsyncContextManager , asynccontextmanager , contextmanager
5
5
from contextvars import ContextVar
6
+ from dataclasses import dataclass
6
7
from datetime import timedelta
7
8
from typing import Any , Literal , overload
8
9
10
+ from pydantic import ConfigDict , with_config
9
11
from pydantic .errors import PydanticUserError
10
12
from pydantic_core import PydanticSerializationError
11
- from temporalio import workflow
13
+ from temporalio import activity , workflow
12
14
from temporalio .common import RetryPolicy
13
15
from temporalio .workflow import ActivityConfig
14
16
from typing_extensions import Never
21
23
)
22
24
from pydantic_ai ._run_context import AgentDepsT
23
25
from pydantic_ai .agent import AbstractAgent , AgentRun , AgentRunResult , EventStreamHandler , RunOutputDataT , WrapperAgent
24
- from pydantic_ai .durable_exec .temporal ._run_context import TemporalRunContext
25
26
from pydantic_ai .exceptions import UserError
26
27
from pydantic_ai .models import Model
27
28
from pydantic_ai .output import OutputDataT , OutputSpec
28
29
from pydantic_ai .result import StreamedRunResult
29
30
from pydantic_ai .settings import ModelSettings
30
31
from pydantic_ai .tools import (
31
32
DeferredToolResults ,
33
+ RunContext ,
32
34
Tool ,
33
35
ToolFuncEither ,
34
36
)
35
37
from pydantic_ai .toolsets import AbstractToolset
36
38
37
39
from ._model import TemporalModel
40
+ from ._run_context import TemporalRunContext
38
41
from ._toolset import TemporalWrapperToolset , temporalize_toolset
39
42
40
43
44
+ @dataclass
45
+ @with_config (ConfigDict (arbitrary_types_allowed = True ))
46
+ class _EventStreamHandlerParams :
47
+ event : _messages .AgentStreamEvent
48
+ serialized_run_context : Any
49
+
50
+
41
51
class TemporalAgent (WrapperAgent [AgentDepsT , OutputDataT ]):
42
52
def __init__ (
43
53
self ,
@@ -86,6 +96,10 @@ def __init__(
86
96
"""
87
97
super ().__init__ (wrapped )
88
98
99
+ self ._name = name
100
+ self ._event_stream_handler = event_stream_handler
101
+ self .run_context_type = run_context_type
102
+
89
103
# start_to_close_timeout is required
90
104
activity_config = activity_config or ActivityConfig (start_to_close_timeout = timedelta (seconds = 60 ))
91
105
@@ -97,13 +111,13 @@ def __init__(
97
111
PydanticUserError .__name__ ,
98
112
]
99
113
activity_config ['retry_policy' ] = retry_policy
114
+ self .activity_config = activity_config
100
115
101
116
model_activity_config = model_activity_config or {}
102
117
toolset_activity_config = toolset_activity_config or {}
103
118
tool_activity_config = tool_activity_config or {}
104
119
105
- self ._name = name or wrapped .name
106
- if self ._name is None :
120
+ if self .name is None :
107
121
raise UserError (
108
122
"An agent needs to have a unique `name` in order to be used with Temporal. The name will be used to identify the agent's activities within the workflow."
109
123
)
@@ -116,13 +130,33 @@ def __init__(
116
130
'An agent needs to have a `model` in order to be used with Temporal, it cannot be set at agent run time.'
117
131
)
118
132
133
+ async def event_stream_handler_activity (params : _EventStreamHandlerParams , deps : AgentDepsT ) -> None :
134
+ # We can never get here without an `event_stream_handler`, as `TemporalAgent.run_stream` and `TemporalAgent.iter` raise an error saying to use `TemporalAgent.run` instead,
135
+ # and that only ends up calling `event_stream_handler` if it is set.
136
+ assert self .event_stream_handler is not None
137
+
138
+ run_context = self .run_context_type .deserialize_run_context (params .serialized_run_context , deps = deps )
139
+
140
+ async def streamed_response ():
141
+ yield params .event
142
+
143
+ await self .event_stream_handler (run_context , streamed_response ())
144
+
145
+ # Set type hint explicitly so that Temporal can take care of serialization and deserialization
146
+ event_stream_handler_activity .__annotations__ ['deps' ] = self .deps_type
147
+
148
+ self .event_stream_handler_activity = activity .defn (name = f'{ activity_name_prefix } __event_stream_handler' )(
149
+ event_stream_handler_activity
150
+ )
151
+ activities .append (self .event_stream_handler_activity )
152
+
119
153
temporal_model = TemporalModel (
120
154
wrapped .model ,
121
155
activity_name_prefix = activity_name_prefix ,
122
156
activity_config = activity_config | model_activity_config ,
123
157
deps_type = self .deps_type ,
124
- run_context_type = run_context_type ,
125
- event_stream_handler = event_stream_handler or wrapped .event_stream_handler ,
158
+ run_context_type = self . run_context_type ,
159
+ event_stream_handler = self .event_stream_handler ,
126
160
)
127
161
activities .extend (temporal_model .temporal_activities )
128
162
@@ -139,7 +173,7 @@ def temporalize_toolset(toolset: AbstractToolset[AgentDepsT]) -> AbstractToolset
139
173
activity_config | toolset_activity_config .get (id , {}),
140
174
tool_activity_config .get (id , {}),
141
175
self .deps_type ,
142
- run_context_type ,
176
+ self . run_context_type ,
143
177
)
144
178
if isinstance (toolset , TemporalWrapperToolset ):
145
179
activities .extend (toolset .temporal_activities )
@@ -155,7 +189,7 @@ def temporalize_toolset(toolset: AbstractToolset[AgentDepsT]) -> AbstractToolset
155
189
156
190
@property
157
191
def name (self ) -> str | None :
158
- return self ._name
192
+ return self ._name or super (). name
159
193
160
194
@name .setter
161
195
def name (self , value : str | None ) -> None : # pragma: no cover
@@ -167,6 +201,33 @@ def name(self, value: str | None) -> None: # pragma: no cover
167
201
def model (self ) -> Model :
168
202
return self ._model
169
203
204
+ @property
205
+ def event_stream_handler (self ) -> EventStreamHandler [AgentDepsT ] | None :
206
+ handler = self ._event_stream_handler or super ().event_stream_handler
207
+ if handler is None :
208
+ return None
209
+ elif workflow .in_workflow ():
210
+ return self ._call_event_stream_handler_activity
211
+ else :
212
+ return handler
213
+
214
+ async def _call_event_stream_handler_activity (
215
+ self , ctx : RunContext [AgentDepsT ], stream : AsyncIterable [_messages .AgentStreamEvent ]
216
+ ) -> None :
217
+ serialized_run_context = self .run_context_type .serialize_run_context (ctx )
218
+ async for event in stream :
219
+ await workflow .execute_activity ( # pyright: ignore[reportUnknownMemberType]
220
+ activity = self .event_stream_handler_activity ,
221
+ args = [
222
+ _EventStreamHandlerParams (
223
+ event = event ,
224
+ serialized_run_context = serialized_run_context ,
225
+ ),
226
+ ctx .deps ,
227
+ ],
228
+ ** self .activity_config ,
229
+ )
230
+
170
231
@property
171
232
def toolsets (self ) -> Sequence [AbstractToolset [AgentDepsT ]]:
172
233
with self ._temporal_overrides ():
@@ -296,7 +357,7 @@ async def main():
296
357
usage = usage ,
297
358
infer_name = infer_name ,
298
359
toolsets = toolsets ,
299
- event_stream_handler = event_stream_handler ,
360
+ event_stream_handler = event_stream_handler or self . event_stream_handler ,
300
361
** _deprecated_kwargs ,
301
362
)
302
363
0 commit comments