Skip to content

Commit 46823ab

Browse files
authored
Always run event_stream_handler inside Temporal activity (#2806)
1 parent 76ce1bc commit 46823ab

File tree

2 files changed

+154
-32
lines changed

2 files changed

+154
-32
lines changed

pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_agent.py

Lines changed: 71 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
from __future__ import annotations
22

3-
from collections.abc import AsyncIterator, Callable, Iterator, Sequence
3+
from collections.abc import AsyncIterable, AsyncIterator, Callable, Iterator, Sequence
44
from contextlib import AbstractAsyncContextManager, asynccontextmanager, contextmanager
55
from contextvars import ContextVar
6+
from dataclasses import dataclass
67
from datetime import timedelta
78
from typing import Any, Literal, overload
89

10+
from pydantic import ConfigDict, with_config
911
from pydantic.errors import PydanticUserError
1012
from pydantic_core import PydanticSerializationError
11-
from temporalio import workflow
13+
from temporalio import activity, workflow
1214
from temporalio.common import RetryPolicy
1315
from temporalio.workflow import ActivityConfig
1416
from typing_extensions import Never
@@ -21,23 +23,31 @@
2123
)
2224
from pydantic_ai._run_context import AgentDepsT
2325
from pydantic_ai.agent import AbstractAgent, AgentRun, AgentRunResult, EventStreamHandler, RunOutputDataT, WrapperAgent
24-
from pydantic_ai.durable_exec.temporal._run_context import TemporalRunContext
2526
from pydantic_ai.exceptions import UserError
2627
from pydantic_ai.models import Model
2728
from pydantic_ai.output import OutputDataT, OutputSpec
2829
from pydantic_ai.result import StreamedRunResult
2930
from pydantic_ai.settings import ModelSettings
3031
from pydantic_ai.tools import (
3132
DeferredToolResults,
33+
RunContext,
3234
Tool,
3335
ToolFuncEither,
3436
)
3537
from pydantic_ai.toolsets import AbstractToolset
3638

3739
from ._model import TemporalModel
40+
from ._run_context import TemporalRunContext
3841
from ._toolset import TemporalWrapperToolset, temporalize_toolset
3942

4043

44+
@dataclass
45+
@with_config(ConfigDict(arbitrary_types_allowed=True))
46+
class _EventStreamHandlerParams:
47+
event: _messages.AgentStreamEvent
48+
serialized_run_context: Any
49+
50+
4151
class TemporalAgent(WrapperAgent[AgentDepsT, OutputDataT]):
4252
def __init__(
4353
self,
@@ -86,6 +96,10 @@ def __init__(
8696
"""
8797
super().__init__(wrapped)
8898

99+
self._name = name
100+
self._event_stream_handler = event_stream_handler
101+
self.run_context_type = run_context_type
102+
89103
# start_to_close_timeout is required
90104
activity_config = activity_config or ActivityConfig(start_to_close_timeout=timedelta(seconds=60))
91105

@@ -97,13 +111,13 @@ def __init__(
97111
PydanticUserError.__name__,
98112
]
99113
activity_config['retry_policy'] = retry_policy
114+
self.activity_config = activity_config
100115

101116
model_activity_config = model_activity_config or {}
102117
toolset_activity_config = toolset_activity_config or {}
103118
tool_activity_config = tool_activity_config or {}
104119

105-
self._name = name or wrapped.name
106-
if self._name is None:
120+
if self.name is None:
107121
raise UserError(
108122
"An agent needs to have a unique `name` in order to be used with Temporal. The name will be used to identify the agent's activities within the workflow."
109123
)
@@ -116,13 +130,33 @@ def __init__(
116130
'An agent needs to have a `model` in order to be used with Temporal, it cannot be set at agent run time.'
117131
)
118132

133+
async def event_stream_handler_activity(params: _EventStreamHandlerParams, deps: AgentDepsT) -> None:
134+
# We can never get here without an `event_stream_handler`, as `TemporalAgent.run_stream` and `TemporalAgent.iter` raise an error saying to use `TemporalAgent.run` instead,
135+
# and that only ends up calling `event_stream_handler` if it is set.
136+
assert self.event_stream_handler is not None
137+
138+
run_context = self.run_context_type.deserialize_run_context(params.serialized_run_context, deps=deps)
139+
140+
async def streamed_response():
141+
yield params.event
142+
143+
await self.event_stream_handler(run_context, streamed_response())
144+
145+
# Set type hint explicitly so that Temporal can take care of serialization and deserialization
146+
event_stream_handler_activity.__annotations__['deps'] = self.deps_type
147+
148+
self.event_stream_handler_activity = activity.defn(name=f'{activity_name_prefix}__event_stream_handler')(
149+
event_stream_handler_activity
150+
)
151+
activities.append(self.event_stream_handler_activity)
152+
119153
temporal_model = TemporalModel(
120154
wrapped.model,
121155
activity_name_prefix=activity_name_prefix,
122156
activity_config=activity_config | model_activity_config,
123157
deps_type=self.deps_type,
124-
run_context_type=run_context_type,
125-
event_stream_handler=event_stream_handler or wrapped.event_stream_handler,
158+
run_context_type=self.run_context_type,
159+
event_stream_handler=self.event_stream_handler,
126160
)
127161
activities.extend(temporal_model.temporal_activities)
128162

@@ -139,7 +173,7 @@ def temporalize_toolset(toolset: AbstractToolset[AgentDepsT]) -> AbstractToolset
139173
activity_config | toolset_activity_config.get(id, {}),
140174
tool_activity_config.get(id, {}),
141175
self.deps_type,
142-
run_context_type,
176+
self.run_context_type,
143177
)
144178
if isinstance(toolset, TemporalWrapperToolset):
145179
activities.extend(toolset.temporal_activities)
@@ -155,7 +189,7 @@ def temporalize_toolset(toolset: AbstractToolset[AgentDepsT]) -> AbstractToolset
155189

156190
@property
157191
def name(self) -> str | None:
158-
return self._name
192+
return self._name or super().name
159193

160194
@name.setter
161195
def name(self, value: str | None) -> None: # pragma: no cover
@@ -167,6 +201,33 @@ def name(self, value: str | None) -> None: # pragma: no cover
167201
def model(self) -> Model:
168202
return self._model
169203

204+
@property
205+
def event_stream_handler(self) -> EventStreamHandler[AgentDepsT] | None:
206+
handler = self._event_stream_handler or super().event_stream_handler
207+
if handler is None:
208+
return None
209+
elif workflow.in_workflow():
210+
return self._call_event_stream_handler_activity
211+
else:
212+
return handler
213+
214+
async def _call_event_stream_handler_activity(
215+
self, ctx: RunContext[AgentDepsT], stream: AsyncIterable[_messages.AgentStreamEvent]
216+
) -> None:
217+
serialized_run_context = self.run_context_type.serialize_run_context(ctx)
218+
async for event in stream:
219+
await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType]
220+
activity=self.event_stream_handler_activity,
221+
args=[
222+
_EventStreamHandlerParams(
223+
event=event,
224+
serialized_run_context=serialized_run_context,
225+
),
226+
ctx.deps,
227+
],
228+
**self.activity_config,
229+
)
230+
170231
@property
171232
def toolsets(self) -> Sequence[AbstractToolset[AgentDepsT]]:
172233
with self._temporal_overrides():
@@ -296,7 +357,7 @@ async def main():
296357
usage=usage,
297358
infer_name=infer_name,
298359
toolsets=toolsets,
299-
event_stream_handler=event_stream_handler,
360+
event_stream_handler=event_stream_handler or self.event_stream_handler,
300361
**_deprecated_kwargs,
301362
)
302363

tests/test_temporal.py

Lines changed: 83 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -375,25 +375,56 @@ async def test_complex_agent_run_in_workflow(
375375
],
376376
)
377377
],
378-
),
379-
BasicSpan(content='ctx.run_step=1'),
378+
)
380379
],
381380
),
382-
BasicSpan(content='ctx.run_step=1'),
383381
BasicSpan(
384-
content='{"part":{"tool_name":"get_country","args":"{}","tool_call_id":"call_3rqTYrA6H21AYUaRGP4F66oq","part_kind":"tool-call"},"event_kind":"function_tool_call"}'
382+
content='StartActivity:agent__complex_agent__event_stream_handler',
383+
children=[
384+
BasicSpan(
385+
content='RunActivity:agent__complex_agent__event_stream_handler',
386+
children=[
387+
BasicSpan(content='ctx.run_step=1'),
388+
BasicSpan(
389+
content='{"part":{"tool_name":"get_country","args":"{}","tool_call_id":"call_3rqTYrA6H21AYUaRGP4F66oq","part_kind":"tool-call"},"event_kind":"function_tool_call"}'
390+
),
391+
],
392+
)
393+
],
385394
),
386395
BasicSpan(
387-
content='{"part":{"tool_name":"get_product_name","args":"{}","tool_call_id":"call_Xw9XMKBJU48kAAd78WgIswDx","part_kind":"tool-call"},"event_kind":"function_tool_call"}'
396+
content='StartActivity:agent__complex_agent__event_stream_handler',
397+
children=[
398+
BasicSpan(
399+
content='RunActivity:agent__complex_agent__event_stream_handler',
400+
children=[
401+
BasicSpan(content='ctx.run_step=1'),
402+
BasicSpan(
403+
content='{"part":{"tool_name":"get_product_name","args":"{}","tool_call_id":"call_Xw9XMKBJU48kAAd78WgIswDx","part_kind":"tool-call"},"event_kind":"function_tool_call"}'
404+
),
405+
],
406+
)
407+
],
388408
),
389409
BasicSpan(
390410
content='running 2 tools',
391411
children=[
392412
BasicSpan(content='running tool: get_country'),
393413
BasicSpan(
394-
content=IsStr(
395-
regex=r'{"result":{"tool_name":"get_country","content":"Mexico","tool_call_id":"call_3rqTYrA6H21AYUaRGP4F66oq","metadata":null,"timestamp":".+?","part_kind":"tool-return"},"event_kind":"function_tool_result"}'
396-
)
414+
content='StartActivity:agent__complex_agent__event_stream_handler',
415+
children=[
416+
BasicSpan(
417+
content='RunActivity:agent__complex_agent__event_stream_handler',
418+
children=[
419+
BasicSpan(content='ctx.run_step=1'),
420+
BasicSpan(
421+
content=IsStr(
422+
regex=r'{"result":{"tool_name":"get_country","content":"Mexico","tool_call_id":"call_3rqTYrA6H21AYUaRGP4F66oq","metadata":null,"timestamp":".+?","part_kind":"tool-return"},"event_kind":"function_tool_result"}'
423+
)
424+
),
425+
],
426+
)
427+
],
397428
),
398429
BasicSpan(
399430
content='running tool: get_product_name',
@@ -409,9 +440,20 @@ async def test_complex_agent_run_in_workflow(
409440
],
410441
),
411442
BasicSpan(
412-
content=IsStr(
413-
regex=r'{"result":{"tool_name":"get_product_name","content":"Pydantic AI","tool_call_id":"call_Xw9XMKBJU48kAAd78WgIswDx","metadata":null,"timestamp":".+?","part_kind":"tool-return"},"event_kind":"function_tool_result"}'
414-
)
443+
content='StartActivity:agent__complex_agent__event_stream_handler',
444+
children=[
445+
BasicSpan(
446+
content='RunActivity:agent__complex_agent__event_stream_handler',
447+
children=[
448+
BasicSpan(content='ctx.run_step=1'),
449+
BasicSpan(
450+
content=IsStr(
451+
regex=r'{"result":{"tool_name":"get_product_name","content":"Pydantic AI","tool_call_id":"call_Xw9XMKBJU48kAAd78WgIswDx","metadata":null,"timestamp":".+?","part_kind":"tool-return"},"event_kind":"function_tool_result"}'
452+
)
453+
),
454+
],
455+
)
456+
],
415457
),
416458
],
417459
),
@@ -455,13 +497,22 @@ async def test_complex_agent_run_in_workflow(
455497
],
456498
)
457499
],
458-
),
459-
BasicSpan(content='ctx.run_step=2'),
500+
)
460501
],
461502
),
462-
BasicSpan(content='ctx.run_step=2'),
463503
BasicSpan(
464-
content='{"part":{"tool_name":"get_weather","args":"{\\"city\\":\\"Mexico City\\"}","tool_call_id":"call_Vz0Sie91Ap56nH0ThKGrZXT7","part_kind":"tool-call"},"event_kind":"function_tool_call"}'
504+
content='StartActivity:agent__complex_agent__event_stream_handler',
505+
children=[
506+
BasicSpan(
507+
content='RunActivity:agent__complex_agent__event_stream_handler',
508+
children=[
509+
BasicSpan(content='ctx.run_step=2'),
510+
BasicSpan(
511+
content='{"part":{"tool_name":"get_weather","args":"{\\"city\\":\\"Mexico City\\"}","tool_call_id":"call_Vz0Sie91Ap56nH0ThKGrZXT7","part_kind":"tool-call"},"event_kind":"function_tool_call"}'
512+
),
513+
],
514+
)
515+
],
465516
),
466517
BasicSpan(
467518
content='running 1 tool',
@@ -480,9 +531,20 @@ async def test_complex_agent_run_in_workflow(
480531
],
481532
),
482533
BasicSpan(
483-
content=IsStr(
484-
regex=r'{"result":{"tool_name":"get_weather","content":"sunny","tool_call_id":"call_Vz0Sie91Ap56nH0ThKGrZXT7","metadata":null,"timestamp":".+?","part_kind":"tool-return"},"event_kind":"function_tool_result"}'
485-
)
534+
content='StartActivity:agent__complex_agent__event_stream_handler',
535+
children=[
536+
BasicSpan(
537+
content='RunActivity:agent__complex_agent__event_stream_handler',
538+
children=[
539+
BasicSpan(content='ctx.run_step=2'),
540+
BasicSpan(
541+
content=IsStr(
542+
regex=r'{"result":{"tool_name":"get_weather","content":"sunny","tool_call_id":"call_Vz0Sie91Ap56nH0ThKGrZXT7","metadata":null,"timestamp":".+?","part_kind":"tool-return"},"event_kind":"function_tool_result"}'
543+
)
544+
),
545+
],
546+
)
547+
],
486548
),
487549
],
488550
),
@@ -631,11 +693,9 @@ async def test_complex_agent_run_in_workflow(
631693
],
632694
)
633695
],
634-
),
635-
BasicSpan(content='ctx.run_step=3'),
696+
)
636697
],
637698
),
638-
BasicSpan(content='ctx.run_step=3'),
639699
],
640700
),
641701
BasicSpan(content='CompleteWorkflow:ComplexAgentWorkflow'),
@@ -946,7 +1006,7 @@ async def test_multiple_agents(allow_model_requests: None, client: Client):
9461006

9471007

9481008
async def test_agent_name_collision(allow_model_requests: None, client: Client):
949-
with pytest.raises(ValueError, match='More than one activity named agent__simple_agent__model_request'):
1009+
with pytest.raises(ValueError, match='More than one activity named agent__simple_agent__event_stream_handler'):
9501010
async with Worker(
9511011
client,
9521012
task_queue=TASK_QUEUE,
@@ -1022,6 +1082,7 @@ async def test_temporal_agent():
10221082
for activity in complex_temporal_agent.temporal_activities
10231083
] == snapshot(
10241084
[
1085+
'agent__complex_agent__event_stream_handler',
10251086
'agent__complex_agent__model_request',
10261087
'agent__complex_agent__model_request_stream',
10271088
'agent__complex_agent__toolset__<agent>__call_tool',

0 commit comments

Comments
 (0)