1
1
from __future__ import annotations
2
2
3
3
import warnings
4
- from collections .abc import Sequence
4
+ from collections .abc import AsyncIterator , Sequence
5
+ from contextlib import AbstractAsyncContextManager
5
6
from dataclasses import replace
6
7
from typing import Any , Callable
7
8
8
9
from pydantic .errors import PydanticUserError
9
- from temporalio .client import ClientConfig , Plugin as ClientPlugin
10
+ from temporalio .client import ClientConfig , Plugin as ClientPlugin , WorkflowHistory
10
11
from temporalio .contrib .pydantic import PydanticPayloadConverter , pydantic_data_converter
11
- from temporalio .converter import DefaultPayloadConverter
12
- from temporalio .worker import Plugin as WorkerPlugin , WorkerConfig
12
+ from temporalio .converter import DataConverter , DefaultPayloadConverter
13
+ from temporalio .service import ConnectConfig , ServiceClient
14
+ from temporalio .worker import (
15
+ Plugin as WorkerPlugin ,
16
+ Replayer ,
17
+ ReplayerConfig ,
18
+ Worker ,
19
+ WorkerConfig ,
20
+ WorkflowReplayResult ,
21
+ )
13
22
from temporalio .worker .workflow_sandbox import SandboxedWorkflowRunner
14
23
15
24
from ...exceptions import UserError
31
40
class PydanticAIPlugin (ClientPlugin , WorkerPlugin ):
32
41
"""Temporal client and worker plugin for Pydantic AI."""
33
42
34
- def configure_client (self , config : ClientConfig ) -> ClientConfig :
35
- if (data_converter := config .get ('data_converter' )) and data_converter .payload_converter_class not in (
36
- DefaultPayloadConverter ,
37
- PydanticPayloadConverter ,
38
- ):
39
- warnings .warn ( # pragma: no cover
40
- 'A non-default Temporal data converter was used which has been replaced with the Pydantic data converter.'
41
- )
43
+ def init_client_plugin (self , next : ClientPlugin ) -> None :
44
+ self .next_client_plugin = next
42
45
43
- config ['data_converter' ] = pydantic_data_converter
44
- return super ().configure_client (config )
46
+ def init_worker_plugin (self , next : WorkerPlugin ) -> None :
47
+ self .next_worker_plugin = next
48
+
49
+ def configure_client (self , config : ClientConfig ) -> ClientConfig :
50
+ config ['data_converter' ] = self ._get_new_data_converter (config .get ('data_converter' ))
51
+ return self .next_client_plugin .configure_client (config )
45
52
46
53
def configure_worker (self , config : WorkerConfig ) -> WorkerConfig :
47
54
runner = config .get ('workflow_runner' ) # pyright: ignore[reportUnknownMemberType]
@@ -67,7 +74,35 @@ def configure_worker(self, config: WorkerConfig) -> WorkerConfig:
67
74
PydanticUserError ,
68
75
]
69
76
70
- return super ().configure_worker (config )
77
+ return self .next_worker_plugin .configure_worker (config )
78
+
79
+ async def connect_service_client (self , config : ConnectConfig ) -> ServiceClient :
80
+ return await self .next_client_plugin .connect_service_client (config )
81
+
82
+ async def run_worker (self , worker : Worker ) -> None :
83
+ await self .next_worker_plugin .run_worker (worker )
84
+
85
+ def configure_replayer (self , config : ReplayerConfig ) -> ReplayerConfig : # pragma: no cover
86
+ config ['data_converter' ] = self ._get_new_data_converter (config .get ('data_converter' )) # pyright: ignore[reportUnknownMemberType]
87
+ return self .next_worker_plugin .configure_replayer (config )
88
+
89
+ def run_replayer (
90
+ self ,
91
+ replayer : Replayer ,
92
+ histories : AsyncIterator [WorkflowHistory ],
93
+ ) -> AbstractAsyncContextManager [AsyncIterator [WorkflowReplayResult ]]: # pragma: no cover
94
+ return self .next_worker_plugin .run_replayer (replayer , histories )
95
+
96
+ def _get_new_data_converter (self , converter : DataConverter | None ) -> DataConverter :
97
+ if converter and converter .payload_converter_class not in (
98
+ DefaultPayloadConverter ,
99
+ PydanticPayloadConverter ,
100
+ ):
101
+ warnings .warn ( # pragma: no cover
102
+ 'A non-default Temporal data converter was used which has been replaced with the Pydantic data converter.'
103
+ )
104
+
105
+ return pydantic_data_converter
71
106
72
107
73
108
class AgentPlugin (WorkerPlugin ):
@@ -76,8 +111,24 @@ class AgentPlugin(WorkerPlugin):
76
111
def __init__ (self , agent : TemporalAgent [Any , Any ]):
77
112
self .agent = agent
78
113
114
+ def init_worker_plugin (self , next : WorkerPlugin ) -> None :
115
+ self .next_worker_plugin = next
116
+
79
117
def configure_worker (self , config : WorkerConfig ) -> WorkerConfig :
80
118
activities : Sequence [Callable [..., Any ]] = config .get ('activities' , []) # pyright: ignore[reportUnknownMemberType]
81
119
# Activities are checked for name conflicts by Temporal.
82
120
config ['activities' ] = [* activities , * self .agent .temporal_activities ]
83
- return super ().configure_worker (config )
121
+ return self .next_worker_plugin .configure_worker (config )
122
+
123
+ async def run_worker (self , worker : Worker ) -> None :
124
+ await self .next_worker_plugin .run_worker (worker )
125
+
126
+ def configure_replayer (self , config : ReplayerConfig ) -> ReplayerConfig : # pragma: no cover
127
+ return self .next_worker_plugin .configure_replayer (config )
128
+
129
+ def run_replayer (
130
+ self ,
131
+ replayer : Replayer ,
132
+ histories : AsyncIterator [WorkflowHistory ],
133
+ ) -> AbstractAsyncContextManager [AsyncIterator [WorkflowReplayResult ]]: # pragma: no cover
134
+ return self .next_worker_plugin .run_replayer (replayer , histories )
0 commit comments