11"""Invoking and supporting prompt preprocessor hook implementations."""
22
3+ from dataclasses import dataclass
34from traceback import format_tb
4-
55from typing import (
66 Any ,
77 Awaitable ,
88 Callable ,
9+ Generic ,
910 Iterable ,
1011 TypeAlias ,
1112 assert_never ,
3940from .common import (
4041 AsyncSessionPlugins ,
4142 HookController ,
43+ SendMessageCallback ,
4244 StatusBlockController ,
4345 TPluginConfigSchema ,
4446 TGlobalConfigSchema ,
@@ -125,6 +127,7 @@ def __init__(
125127 ) -> None :
126128 """Initialize prompt preprocessor hook controller."""
127129 super ().__init__ (session , request , plugin_config_schema , global_config_schema )
130+ self .task_id = request .task_id
128131 self .pci = request .pci
129132 self .token = request .token
130133
@@ -182,6 +185,124 @@ async def notify_done(self, message: str) -> None:
182185]
183186
184187
188+ # TODO: Define a common "PluginHookHandler" base class
189+ @dataclass ()
190+ class PromptPreprocessor (Generic [TPluginConfigSchema , TGlobalConfigSchema ]):
191+ """Handle accepting prompt preprocessing requests."""
192+
193+ plugin_name : str
194+ hook_impl : PromptPreprocessorHook
195+ plugin_config_schema : type [TPluginConfigSchema ]
196+ global_config_schema : type [TGlobalConfigSchema ]
197+
198+ def __post_init__ (self ) -> None :
199+ self ._logger = logger = new_logger (__name__ )
200+ logger .update_context (plugin_name = self .plugin_name )
201+
202+ async def process_requests (
203+ self , session : AsyncSessionPlugins , notify_ready : Callable [[], Any ]
204+ ) -> None :
205+ logger = self ._logger
206+ endpoint = PromptPreprocessingEndpoint ()
207+ async with session ._create_channel (endpoint ) as channel :
208+ notify_ready ()
209+ logger .info ("Opened channel to receive prompt preprocessing requests..." )
210+ send_cb = channel .send_message
211+ async with create_task_group () as tg :
212+ logger .debug ("Waiting for prompt preprocessing requests..." )
213+ async for contents in channel .rx_stream ():
214+ logger .debug (
215+ f"Handling prompt preprocessing channel message: { contents } "
216+ )
217+ for event in endpoint .iter_message_events (contents ):
218+ logger .debug ("Handling prompt preprocessing channel event" )
219+ endpoint .handle_rx_event (event )
220+ match event :
221+ case PromptPreprocessingRequestEvent ():
222+ logger .debug (
223+ "Running prompt preprocessing request hook"
224+ )
225+ ctl = PromptPreprocessorController (
226+ session ,
227+ event .arg ,
228+ self .plugin_config_schema ,
229+ self .global_config_schema ,
230+ )
231+ tg .start_soon (self ._invoke_hook , ctl , send_cb )
232+ if endpoint .is_finished :
233+ break
234+
235+ async def _invoke_hook (
236+ self ,
237+ ctl : PromptPreprocessorController [TPluginConfigSchema , TGlobalConfigSchema ],
238+ send_response : SendMessageCallback ,
239+ ) -> None :
240+ logger = self ._logger
241+ request = ctl .request
242+ message = request .input
243+ expected_cls = UserMessage
244+ if not isinstance (message , expected_cls ):
245+ logger .error (
246+ f"Received { type (message ).__name__ !r} ({ expected_cls .__name__ !r} expected)"
247+ )
248+ return
249+ error_details : SerializedLMSExtendedErrorDict | None = None
250+ response_dict : UserMessageDict
251+ try :
252+ response = await self .hook_impl (ctl , message )
253+ except Exception as exc :
254+ err_msg = "Error calling prompt preprocessing hook"
255+ logger .error (err_msg , exc_info = True , exc = repr (exc ))
256+ # TODO: Determine if it's worth sending the stack trace to the server
257+ ui_cause = f"{ err_msg } \n ({ type (exc ).__name__ } : { exc } )"
258+ error_details = SerializedLMSExtendedErrorDict (
259+ cause = ui_cause , stack = "\n " .join (format_tb (exc .__traceback__ ))
260+ )
261+ else :
262+ if response is None :
263+ # No change to message
264+ response_dict = message .to_dict ()
265+ else :
266+ logger .debug (
267+ "Validating prompt preprocessing response" , response = response
268+ )
269+ if isinstance (response , dict ):
270+ try :
271+ parsed_response = load_struct (response , expected_cls )
272+ except ValidationError as exc :
273+ err_msg = f"Failed to parse prompt preprocessing response as { expected_cls .__name__ } \n ({ exc } )"
274+ logger .error (err_msg )
275+ error_details = SerializedLMSExtendedErrorDict (cause = err_msg )
276+ else :
277+ response_dict = parsed_response .to_dict ()
278+ elif isinstance (response , UserMessage ):
279+ response_dict = response .to_dict ()
280+ else :
281+ err_msg = f"Prompt preprocessing hook returned { type (response ).__name__ !r} ({ expected_cls .__name__ !r} expected)"
282+ logger .error (err_msg )
283+ error_details = SerializedLMSExtendedErrorDict (cause = err_msg )
284+ channel_message : DictObject
285+ if error_details is not None :
286+ error_title = f"Prompt preprocessing error in plugin { self .plugin_name !r} "
287+ common_error_args : SerializedLMSExtendedErrorDict = {
288+ "title" : error_title ,
289+ "rootTitle" : error_title ,
290+ }
291+ error_details .update (common_error_args )
292+ channel_message = PromptPreprocessingErrorDict (
293+ type = "error" ,
294+ taskId = request .task_id ,
295+ error = error_details ,
296+ )
297+ else :
298+ channel_message = PromptPreprocessingCompleteDict (
299+ type = "complete" ,
300+ taskId = request .task_id ,
301+ processed = response_dict ,
302+ )
303+ await send_response (channel_message )
304+
305+
185306async def run_prompt_preprocessor (
186307 plugin_name : str ,
187308 hook_impl : PromptPreprocessorHook ,
@@ -191,94 +312,7 @@ async def run_prompt_preprocessor(
191312 notify_ready : Callable [[], Any ],
192313) -> None :
193314 """Accept prompt preprocessing requests."""
194- logger = new_logger (__name__ )
195- logger .update_context (plugin_name = plugin_name )
196- endpoint = PromptPreprocessingEndpoint ()
197- async with session ._create_channel (endpoint ) as channel :
198- notify_ready ()
199- logger .info ("Opened channel to receive prompt preprocessing requests..." )
200-
201- async def _invoke_hook (request : PromptPreprocessingRequest ) -> None :
202- message = request .input
203- expected_cls = UserMessage
204- if not isinstance (message , expected_cls ):
205- logger .error (
206- f"Received { type (message ).__name__ !r} ({ expected_cls .__name__ !r} expected)"
207- )
208- return
209- hook_controller = PromptPreprocessorController (
210- session , request , plugin_config_schema , global_config_schema
211- )
212- error_details : SerializedLMSExtendedErrorDict | None = None
213- response_dict : UserMessageDict
214- try :
215- response = await hook_impl (hook_controller , message )
216- except Exception as exc :
217- err_msg = "Error calling prompt preprocessing hook"
218- logger .error (err_msg , exc_info = True , exc = repr (exc ))
219- # TODO: Determine if it's worth sending the stack trace to the server
220- ui_cause = f"{ err_msg } \n ({ type (exc ).__name__ } : { exc } )"
221- error_details = SerializedLMSExtendedErrorDict (
222- cause = ui_cause , stack = "\n " .join (format_tb (exc .__traceback__ ))
223- )
224- else :
225- if response is None :
226- # No change to message
227- response_dict = message .to_dict ()
228- else :
229- logger .debug (
230- "Validating prompt preprocessing response" , response = response
231- )
232- if isinstance (response , dict ):
233- try :
234- parsed_response = load_struct (response , expected_cls )
235- except ValidationError as exc :
236- err_msg = f"Failed to parse prompt preprocessing response as { expected_cls .__name__ } \n ({ exc } )"
237- logger .error (err_msg )
238- error_details = SerializedLMSExtendedErrorDict (
239- cause = err_msg
240- )
241- else :
242- response_dict = parsed_response .to_dict ()
243- elif isinstance (response , UserMessage ):
244- response_dict = response .to_dict ()
245- else :
246- err_msg = f"Prompt preprocessing hook returned { type (response ).__name__ !r} ({ expected_cls .__name__ !r} expected)"
247- logger .error (err_msg )
248- error_details = SerializedLMSExtendedErrorDict (cause = err_msg )
249- channel_message : DictObject
250- if error_details is not None :
251- error_title = f"Prompt preprocessing error in plugin { plugin_name !r} "
252- common_error_args : SerializedLMSExtendedErrorDict = {
253- "title" : error_title ,
254- "rootTitle" : error_title ,
255- }
256- error_details .update (common_error_args )
257- channel_message = PromptPreprocessingErrorDict (
258- type = "error" ,
259- taskId = request .task_id ,
260- error = error_details ,
261- )
262- else :
263- channel_message = PromptPreprocessingCompleteDict (
264- type = "complete" ,
265- taskId = request .task_id ,
266- processed = response_dict ,
267- )
268- await channel .send_message (channel_message )
269-
270- async with create_task_group () as tg :
271- logger .debug ("Waiting for prompt preprocessing requests..." )
272- async for contents in channel .rx_stream ():
273- logger .debug (
274- f"Handling prompt preprocessing channel message: { contents } "
275- )
276- for event in endpoint .iter_message_events (contents ):
277- logger .debug ("Handling prompt preprocessing channel event" )
278- endpoint .handle_rx_event (event )
279- match event :
280- case PromptPreprocessingRequestEvent ():
281- logger .debug ("Running prompt preprocessing request hook" )
282- tg .start_soon (_invoke_hook , event .arg )
283- if endpoint .is_finished :
284- break
315+ prompt_preprocessor = PromptPreprocessor (
316+ plugin_name , hook_impl , plugin_config_schema , global_config_schema
317+ )
318+ await prompt_preprocessor .process_requests (session , notify_ready )
0 commit comments