1111 Iterable ,
1212 TypeAlias ,
1313 assert_never ,
14- get_args as get_type_args ,
1514)
1615
1716from anyio import create_task_group
1817
1918from ..._logging import get_logger
2019from ...schemas import DictObject , EmptyDict , ValidationError
21- from ...history import AnyChatMessage , AnyChatMessageDict
20+ from ...history import UserMessage , UserMessageDict
2221from ...json_api import (
2322 ChannelCommonRxEvent ,
2423 ChannelEndpoint ,
@@ -182,8 +181,8 @@ async def notify_done(self, message: str) -> None:
182181
183182
184183PromptPreprocessorHook = Callable [
185- [PromptPreprocessorController [Any , Any ], AnyChatMessage ],
186- Awaitable [AnyChatMessage | AnyChatMessageDict | None ],
184+ [PromptPreprocessorController [Any , Any ], UserMessage ],
185+ Awaitable [UserMessage | UserMessageDict | None ],
187186]
188187
189188
@@ -203,11 +202,17 @@ async def run_prompt_preprocessor(
203202
204203 async def _invoke_hook (request : PromptPreprocessingRequest ) -> None :
205204 message = request .input
205+ expected_cls = UserMessage
206+ if not isinstance (message , expected_cls ):
207+ logger .error (
208+ f"Received { type (message ).__name__ !r} ({ expected_cls .__name__ !r} expected)"
209+ )
210+ return
206211 hook_controller = PromptPreprocessorController (
207212 session , request , plugin_config_schema , global_config_schema
208213 )
209214 error_details : SerializedLMSExtendedErrorDict | None = None
210- response_dict : AnyChatMessageDict
215+ response_dict : UserMessageDict
211216 try :
212217 response = await hook_impl (hook_controller , message )
213218 except asyncio .CancelledError :
@@ -227,27 +232,23 @@ async def _invoke_hook(request: PromptPreprocessingRequest) -> None:
227232 response_dict = message .to_dict ()
228233 else :
229234 logger .debug (
230- f "Validating prompt preprocessing response: { response !r } "
235+ "Validating prompt preprocessing response" , response = response
231236 )
232- response_cls = type (message )
233237 if isinstance (response , dict ):
234- # Parse the response to ensure validity client side,
235- # otherwise serialising the message may fail and crash the plugin
236238 try :
237- # Response should have the same role as the received message
238- parsed_response = load_struct (response , response_cls )
239+ parsed_response = load_struct (response , expected_cls )
239240 except ValidationError as exc :
240- err_msg = f"Failed to parse prompt preprocessing response as { response_cls .__name__ } "
241+ err_msg = f"Failed to parse prompt preprocessing response as { expected_cls .__name__ } "
241242 logger .error (err_msg , exc_info = True , exc = repr (exc ))
242243 error_details = SerializedLMSExtendedErrorDict (
243244 cause = err_msg
244245 )
245246 else :
246247 response_dict = parsed_response .to_dict ()
247- elif isinstance (response , get_type_args ( AnyChatMessage ) ):
248+ elif isinstance (response , UserMessage ):
248249 response_dict = response .to_dict ()
249250 else :
250- err_msg = f"Prompt preprocessing hook returned { type (response ).__name__ !r} ({ response_cls .__name__ !r} expected)"
251+ err_msg = f"Prompt preprocessing hook returned { type (response ).__name__ !r} ({ expected_cls .__name__ !r} expected)"
251252 logger .error (err_msg )
252253 error_details = SerializedLMSExtendedErrorDict (cause = err_msg )
253254 channel_message : DictObject
0 commit comments