11"""Invoking and supporting prompt preprocessor hook implementations."""
22
3+ import asyncio
4+
5+ from contextlib import asynccontextmanager
36from dataclasses import dataclass
47from traceback import format_tb
58from typing import (
69 Any ,
10+ AsyncIterator ,
711 Awaitable ,
812 Callable ,
913 Generic ,
1317)
1418
1519from anyio import create_task_group
20+ from anyio .abc import TaskGroup
1621
1722from ..._logging import new_logger
1823from ...schemas import DictObject , EmptyDict , ValidationError
2934 ProcessingUpdate ,
3035 ProcessingUpdateStatusCreate ,
3136 ProcessingUpdateStatusUpdate ,
37+ PromptPreprocessingAbortedDict ,
3238 PromptPreprocessingCompleteDict ,
3339 PromptPreprocessingErrorDict ,
3440 PromptPreprocessingRequest ,
4147 AsyncSessionPlugins ,
4248 HookController ,
4349 SendMessageCallback ,
50+ ServerRequestError ,
4451 StatusBlockController ,
4552 TPluginConfigSchema ,
4653 TGlobalConfigSchema ,
4754)
4855
49-
5056# Available as lmstudio.plugin.hooks.*
5157__all__ = [
5258 "PromptPreprocessorController" ,
@@ -88,7 +94,7 @@ def iter_message_events(
8894 case None :
8995 # Server can only terminate the link by closing the websocket
9096 pass
91- case {"type" : "abort" , "task_id " : str (task_id )}:
97+ case {"type" : "abort" , "taskId " : str (task_id )}:
9298 yield PromptPreprocessingAbortEvent (task_id )
9399 case {"type" : "preprocess" } as request_dict :
94100 parsed_request = PromptPreprocessingRequest ._from_any_api_dict (
@@ -101,10 +107,10 @@ def iter_message_events(
101107 def handle_rx_event (self , event : PromptPreprocessingRxEvent ) -> None :
102108 match event :
103109 case PromptPreprocessingAbortEvent (task_id ):
104- self ._logger .info (f"Aborting { task_id } " , task_id = task_id )
110+ self ._logger .debug (f"Aborting { task_id } " , task_id = task_id )
105111 case PromptPreprocessingRequestEvent (request ):
106112 task_id = request .task_id
107- self ._logger .info (
113+ self ._logger .debug (
108114 "Received prompt preprocessing request" , task_id = task_id
109115 )
110116 case ChannelFinishedEvent (_):
@@ -198,16 +204,18 @@ class PromptPreprocessor(Generic[TPluginConfigSchema, TGlobalConfigSchema]):
198204 def __post_init__ (self ) -> None :
199205 self ._logger = logger = new_logger (__name__ )
200206 logger .update_context (plugin_name = self .plugin_name )
207+ self ._abort_events : dict [str , asyncio .Event ] = {}
201208
202209 async def process_requests (
203210 self , session : AsyncSessionPlugins , notify_ready : Callable [[], Any ]
204211 ) -> None :
212+ """Create plugin channel and wait for server requests."""
205213 logger = self ._logger
206214 endpoint = PromptPreprocessingEndpoint ()
207215 async with session ._create_channel (endpoint ) as channel :
208216 notify_ready ()
209217 logger .info ("Opened channel to receive prompt preprocessing requests..." )
210- send_cb = channel .send_message
218+ send_message = channel .send_message
211219 async with create_task_group () as tg :
212220 logger .debug ("Waiting for prompt preprocessing requests..." )
213221 async for contents in channel .rx_stream ():
@@ -218,6 +226,10 @@ async def process_requests(
218226 logger .debug ("Handling prompt preprocessing channel event" )
219227 endpoint .handle_rx_event (event )
220228 match event :
229+ case PromptPreprocessingAbortEvent ():
230+ await self ._abort_hook_invocation (
231+ event .arg , send_message
232+ )
221233 case PromptPreprocessingRequestEvent ():
222234 logger .debug (
223235 "Running prompt preprocessing request hook"
@@ -228,28 +240,76 @@ async def process_requests(
228240 self .plugin_config_schema ,
229241 self .global_config_schema ,
230242 )
231- tg .start_soon (self ._invoke_hook , ctl , send_cb )
243+ tg .start_soon (self ._invoke_hook , ctl , send_message )
232244 if endpoint .is_finished :
233245 break
234246
247+ async def _abort_hook_invocation (
248+ self , task_id : str , send_response : SendMessageCallback
249+ ) -> None :
250+ """Abort the specified hook invocation (if it is still running)."""
251+ abort_event = self ._abort_events .get (task_id , None )
252+ if abort_event is not None :
253+ abort_event .set ()
254+ response = PromptPreprocessingAbortedDict (
255+ type = "aborted" ,
256+ taskId = task_id ,
257+ )
258+ await send_response (response )
259+
260+ async def _cancel_on_event (
261+ self , tg : TaskGroup , event : asyncio .Event , message : str
262+ ) -> None :
263+ await event .wait ()
264+ self ._logger .info (message )
265+ tg .cancel_scope .cancel ()
266+
267+ @asynccontextmanager
268+ async def _registered_hook_invocation (
269+ self , task_id : str
270+ ) -> AsyncIterator [asyncio .Event ]:
271+ logger = self ._logger
272+ abort_events = self ._abort_events
273+ if task_id in abort_events :
274+ err_msg = f"Hook invocation already in progress for { task_id } "
275+ raise ServerRequestError (err_msg )
276+ abort_events [task_id ] = abort_event = asyncio .Event ()
277+ try :
278+ async with create_task_group () as tg :
279+ tg .start_soon (
280+ self ._cancel_on_event ,
281+ tg ,
282+ abort_event ,
283+ f"Aborting request { task_id } " ,
284+ )
285+ logger .info (f"Processing request { task_id } " )
286+ yield abort_event
287+ tg .cancel_scope .cancel ()
288+ finally :
289+ abort_events .pop (task_id , None )
290+ if abort_event .is_set ():
291+ completion_message = f"Aborted request { task_id } "
292+ else :
293+ completion_message = f"Processed request { task_id } "
294+ logger .info (completion_message )
295+
235296 async def _invoke_hook (
236297 self ,
237298 ctl : PromptPreprocessorController [TPluginConfigSchema , TGlobalConfigSchema ],
238299 send_response : SendMessageCallback ,
239300 ) -> None :
240301 logger = self ._logger
241- request = ctl .request
242- message = request .input
243- expected_cls = UserMessage
244- if not isinstance (message , expected_cls ):
245- logger .error (
246- f"Received { type (message ).__name__ !r} ({ expected_cls .__name__ !r} expected)"
247- )
248- return
302+ task_id = ctl .task_id
303+ message = ctl .request .input
249304 error_details : SerializedLMSExtendedErrorDict | None = None
250305 response_dict : UserMessageDict
306+ expected_cls = UserMessage
251307 try :
252- response = await self .hook_impl (ctl , message )
308+ if not isinstance (message , expected_cls ):
309+ err_msg = f"Received { type (message ).__name__ !r} ({ expected_cls .__name__ !r} expected)"
310+ raise ServerRequestError (err_msg )
311+ async with self ._registered_hook_invocation (task_id ) as abort_event :
312+ response = await self .hook_impl (ctl , message )
253313 except Exception as exc :
254314 err_msg = "Error calling prompt preprocessing hook"
255315 logger .error (err_msg , exc_info = True , exc = repr (exc ))
@@ -259,8 +319,11 @@ async def _invoke_hook(
259319 cause = ui_cause , stack = "\n " .join (format_tb (exc .__traceback__ ))
260320 )
261321 else :
322+ if abort_event .is_set ():
323+ # Processing was aborted by the server, skip sending a response
324+ return
262325 if response is None :
263- # No change to message
326+ logger . debug ( " No changes made to preprocessed prompt" )
264327 response_dict = message .to_dict ()
265328 else :
266329 logger .debug (
@@ -291,13 +354,13 @@ async def _invoke_hook(
291354 error_details .update (common_error_args )
292355 channel_message = PromptPreprocessingErrorDict (
293356 type = "error" ,
294- taskId = request . task_id ,
357+ taskId = task_id ,
295358 error = error_details ,
296359 )
297360 else :
298361 channel_message = PromptPreprocessingCompleteDict (
299362 type = "complete" ,
300- taskId = request . task_id ,
363+ taskId = task_id ,
301364 processed = response_dict ,
302365 )
303366 await send_response (channel_message )
0 commit comments