11from __future__ import annotations
22
33import warnings
4- from collections .abc import Sequence
4+ from collections .abc import AsyncIterator , Sequence
5+ from contextlib import AbstractAsyncContextManager
56from dataclasses import replace
67from typing import Any , Callable
78
89from pydantic .errors import PydanticUserError
9- from temporalio .client import ClientConfig , Plugin as ClientPlugin
10+ from temporalio .client import ClientConfig , Plugin as ClientPlugin , WorkflowHistory
1011from temporalio .contrib .pydantic import PydanticPayloadConverter , pydantic_data_converter
11- from temporalio .converter import DefaultPayloadConverter
12- from temporalio .worker import Plugin as WorkerPlugin , WorkerConfig
12+ from temporalio .converter import DataConverter , DefaultPayloadConverter
13+ from temporalio .service import ConnectConfig , ServiceClient
14+ from temporalio .worker import (
15+ Plugin as WorkerPlugin ,
16+ Replayer ,
17+ ReplayerConfig ,
18+ Worker ,
19+ WorkerConfig ,
20+ WorkflowReplayResult ,
21+ )
1322from temporalio .worker .workflow_sandbox import SandboxedWorkflowRunner
1423
1524from ...exceptions import UserError
3140class PydanticAIPlugin (ClientPlugin , WorkerPlugin ):
3241 """Temporal client and worker plugin for Pydantic AI."""
3342
34- def configure_client (self , config : ClientConfig ) -> ClientConfig :
35- if (data_converter := config .get ('data_converter' )) and data_converter .payload_converter_class not in (
36- DefaultPayloadConverter ,
37- PydanticPayloadConverter ,
38- ):
39- warnings .warn ( # pragma: no cover
40- 'A non-default Temporal data converter was used which has been replaced with the Pydantic data converter.'
41- )
43+ def init_client_plugin (self , next : ClientPlugin ) -> None :
44+ self .next_client_plugin = next
4245
43- config ['data_converter' ] = pydantic_data_converter
44- return super ().configure_client (config )
46+ def init_worker_plugin (self , next : WorkerPlugin ) -> None :
47+ self .next_worker_plugin = next
48+
49+ def configure_client (self , config : ClientConfig ) -> ClientConfig :
50+ config ['data_converter' ] = self ._get_new_data_converter (config .get ('data_converter' ))
51+ return self .next_client_plugin .configure_client (config )
4552
4653 def configure_worker (self , config : WorkerConfig ) -> WorkerConfig :
4754 runner = config .get ('workflow_runner' ) # pyright: ignore[reportUnknownMemberType]
@@ -67,7 +74,35 @@ def configure_worker(self, config: WorkerConfig) -> WorkerConfig:
6774 PydanticUserError ,
6875 ]
6976
70- return super ().configure_worker (config )
77+ return self .next_worker_plugin .configure_worker (config )
78+
79+ async def connect_service_client (self , config : ConnectConfig ) -> ServiceClient :
80+ return await self .next_client_plugin .connect_service_client (config )
81+
82+ async def run_worker (self , worker : Worker ) -> None :
83+ await self .next_worker_plugin .run_worker (worker )
84+
85+ def configure_replayer (self , config : ReplayerConfig ) -> ReplayerConfig : # pragma: no cover
86+ config ['data_converter' ] = self ._get_new_data_converter (config .get ('data_converter' )) # pyright: ignore[reportUnknownMemberType]
87+ return self .next_worker_plugin .configure_replayer (config )
88+
89+ def run_replayer (
90+ self ,
91+ replayer : Replayer ,
92+ histories : AsyncIterator [WorkflowHistory ],
93+ ) -> AbstractAsyncContextManager [AsyncIterator [WorkflowReplayResult ]]: # pragma: no cover
94+ return self .next_worker_plugin .run_replayer (replayer , histories )
95+
96+ def _get_new_data_converter (self , converter : DataConverter | None ) -> DataConverter :
97+ if converter and converter .payload_converter_class not in (
98+ DefaultPayloadConverter ,
99+ PydanticPayloadConverter ,
100+ ):
101+ warnings .warn ( # pragma: no cover
102+ 'A non-default Temporal data converter was used which has been replaced with the Pydantic data converter.'
103+ )
104+
105+ return pydantic_data_converter
71106
72107
73108class AgentPlugin (WorkerPlugin ):
@@ -76,8 +111,24 @@ class AgentPlugin(WorkerPlugin):
76111 def __init__ (self , agent : TemporalAgent [Any , Any ]):
77112 self .agent = agent
78113
114+ def init_worker_plugin (self , next : WorkerPlugin ) -> None :
115+ self .next_worker_plugin = next
116+
79117 def configure_worker (self , config : WorkerConfig ) -> WorkerConfig :
80118 activities : Sequence [Callable [..., Any ]] = config .get ('activities' , []) # pyright: ignore[reportUnknownMemberType]
81119 # Activities are checked for name conflicts by Temporal.
82120 config ['activities' ] = [* activities , * self .agent .temporal_activities ]
83- return super ().configure_worker (config )
121+ return self .next_worker_plugin .configure_worker (config )
122+
123+ async def run_worker (self , worker : Worker ) -> None :
124+ await self .next_worker_plugin .run_worker (worker )
125+
126+ def configure_replayer (self , config : ReplayerConfig ) -> ReplayerConfig : # pragma: no cover
127+ return self .next_worker_plugin .configure_replayer (config )
128+
129+ def run_replayer (
130+ self ,
131+ replayer : Replayer ,
132+ histories : AsyncIterator [WorkflowHistory ],
133+ ) -> AbstractAsyncContextManager [AsyncIterator [WorkflowReplayResult ]]: # pragma: no cover
134+ return self .next_worker_plugin .run_replayer (replayer , histories )
0 commit comments