Skip to content

Commit e538f2a

Browse files
committed
Add abort request processing
1 parent 048fbc5 commit e538f2a

File tree

7 files changed

+121
-35
lines changed

7 files changed

+121
-35
lines changed

examples/plugins/prompt-prefix/src/plugin.py

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -39,12 +39,12 @@ async def preprocess_prompt(
3939
message: UserMessage,
4040
) -> UserMessageDict | None:
4141
"""Naming the function 'preprocess_prompt' implicitly registers it."""
42-
print(f"Running prompt preprocessor hook from {__file__} with {ctl.plugin_config}")
4342
if ctl.global_config.enable_inplace_status_demo:
4443
# Run an in-place status prompt update demonstration
4544
status_block = await ctl.notify_start("Starting task (shows a static icon).")
4645
status_updates = (
4746
(status_block.notify_working, "Task in progress (shows a dynamic icon)."),
47+
(status_block.notify_waiting, "Task is blocked (shows a static icon)."),
4848
(status_block.notify_error, "Reporting an error status."),
4949
(status_block.notify_canceled, "Reporting cancellation."),
5050
(
@@ -55,9 +55,11 @@ async def preprocess_prompt(
5555
status_duration = ctl.global_config.inplace_status_duration / len(
5656
status_updates
5757
)
58-
for notification, status_text in status_updates:
59-
await asyncio.sleep(status_duration)
60-
await notification(status_text)
58+
async with status_block.notify_aborted("Task genuinely cancelled."):
59+
for notification, status_text in status_updates:
60+
await asyncio.sleep(status_duration)
61+
await notification(status_text)
62+
6163
modified_message = message.to_dict()
6264
# Add a prefix to all user messages
6365
prefix_text = ctl.plugin_config.prefix

sdk-schema/sync-sdk-schema.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -363,9 +363,11 @@ def _infer_schema_unions() -> None:
363363
"LlmChannelPredictCreationParameterDict": "PredictionChannelRequestDict",
364364
"RepositoryChannelDownloadModelCreationParameter": "DownloadModelChannelRequest",
365365
"RepositoryChannelDownloadModelCreationParameterDict": "DownloadModelChannelRequestDict",
366-
# Prettier plugin channel type names
366+
# Prettier plugin channel message names
367367
"PluginsChannelSetPromptPreprocessorToClientPacketPreprocess": "PromptPreprocessingRequest",
368368
"PluginsChannelSetPromptPreprocessorToClientPacketPreprocessDict": "PromptPreprocessingRequestDict",
369+
"PluginsChannelSetPromptPreprocessorToServerPacketAborted": "PromptPreprocessingAborted",
370+
"PluginsChannelSetPromptPreprocessorToServerPacketAbortedDict": "PromptPreprocessingAbortedDict",
369371
"PluginsChannelSetPromptPreprocessorToServerPacketComplete": "PromptPreprocessingComplete",
370372
"PluginsChannelSetPromptPreprocessorToServerPacketCompleteDict": "PromptPreprocessingCompleteDict",
371373
"PluginsChannelSetPromptPreprocessorToServerPacketError": "PromptPreprocessingError",

src/lmstudio/_sdk_models/__init__.py

Lines changed: 7 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -420,8 +420,6 @@
420420
"PluginsChannelSetPredictionLoopHandlerToServerPacketErrorDict",
421421
"PluginsChannelSetPromptPreprocessorToClientPacketAbort",
422422
"PluginsChannelSetPromptPreprocessorToClientPacketAbortDict",
423-
"PluginsChannelSetPromptPreprocessorToServerPacketAborted",
424-
"PluginsChannelSetPromptPreprocessorToServerPacketAbortedDict",
425423
"PluginsChannelSetToolsProviderToClientPacketAbortToolCall",
426424
"PluginsChannelSetToolsProviderToClientPacketAbortToolCallDict",
427425
"PluginsChannelSetToolsProviderToClientPacketCallTool",
@@ -522,6 +520,8 @@
522520
"ProcessingUpdateToolStatusCreateDict",
523521
"ProcessingUpdateToolStatusUpdate",
524522
"ProcessingUpdateToolStatusUpdateDict",
523+
"PromptPreprocessingAborted",
524+
"PromptPreprocessingAbortedDict",
525525
"PromptPreprocessingComplete",
526526
"PromptPreprocessingCompleteDict",
527527
"PromptPreprocessingError",
@@ -5061,8 +5061,8 @@ class PluginsChannelSetPromptPreprocessorToClientPacketAbortDict(TypedDict):
50615061
taskId: str
50625062

50635063

5064-
class PluginsChannelSetPromptPreprocessorToServerPacketAborted(
5065-
LMStudioStruct["PluginsChannelSetPromptPreprocessorToServerPacketAbortedDict"],
5064+
class PromptPreprocessingAborted(
5065+
LMStudioStruct["PromptPreprocessingAbortedDict"],
50665066
kw_only=True,
50675067
tag_field="type",
50685068
tag="aborted",
@@ -5071,7 +5071,7 @@ class PluginsChannelSetPromptPreprocessorToServerPacketAborted(
50715071
task_id: str = field(name="taskId")
50725072

50735073

5074-
class PluginsChannelSetPromptPreprocessorToServerPacketAbortedDict(TypedDict):
5074+
class PromptPreprocessingAbortedDict(TypedDict):
50755075
"""Corresponding typed dictionary definition for PluginsChannelSetPromptPreprocessorToServerPacketAborted.
50765076
50775077
NOTE: Multi-word keys are defined using their camelCase form,
@@ -10199,14 +10199,12 @@ class PseudoPluginsChannelRegisterDevelopmentPluginDict(TypedDict):
1019910199
| PluginsChannelSetPromptPreprocessorToClientPacketAbortDict
1020010200
)
1020110201
PluginsChannelSetPromptPreprocessorToServerPacket = (
10202-
PromptPreprocessingComplete
10203-
| PluginsChannelSetPromptPreprocessorToServerPacketAborted
10204-
| PromptPreprocessingError
10202+
PromptPreprocessingComplete | PromptPreprocessingAborted | PromptPreprocessingError
1020510203
)
1020610204
PluginsChannelSetPromptPreprocessorToServerPacketDict = (
1020710205
PromptPreprocessingErrorDict
1020810206
| PromptPreprocessingCompleteDict
10209-
| PluginsChannelSetPromptPreprocessorToServerPacketAbortedDict
10207+
| PromptPreprocessingAbortedDict
1021010208
)
1021110209

1021210210

src/lmstudio/_ws_impl.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -290,7 +290,7 @@ async def _log_thread_execution(self) -> None:
290290
try:
291291
# Run the event loop until termination is requested
292292
await never_set.wait()
293-
except asyncio.CancelledError:
293+
except (asyncio.CancelledError, GeneratorExit):
294294
raise
295295
except BaseException:
296296
err_msg = "Terminating websocket thread due to exception"
@@ -359,7 +359,7 @@ async def _logged_ws_handler(self) -> None:
359359
self._logger.info("Websocket handling task started")
360360
try:
361361
await self._handle_ws()
362-
except asyncio.CancelledError:
362+
except (asyncio.CancelledError, GeneratorExit):
363363
raise
364364
except BaseException:
365365
err_msg = "Terminating websocket task due to exception"

src/lmstudio/plugin/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -33,7 +33,7 @@
3333
# Prompt preprocessing hook
3434
# * [DONE] emit a status notification block when the demo plugin fires
3535
# * [DONE] add a global plugin config to control the in-place status update demo
36-
# * handle "Abort" requests from server (including sending "Aborted" responses)
36+
# * [DONE] handle "Abort" requests from server (including sending "Aborted" responses)
3737
# * [DONE] catch hook invocation failures and send "Error" responses
3838
# * [DONE] this includes adding runtime checks for the hook returning the wrong type
3939
#

src/lmstudio/plugin/hooks/common.py

Lines changed: 21 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,17 +1,23 @@
11
"""Common utilities to invoke and support plugin hook implementations."""
22

3+
import asyncio
4+
5+
from contextlib import asynccontextmanager
36
from datetime import datetime, UTC
47
from pathlib import Path
58
from random import randrange
69
from typing import (
710
Any,
11+
AsyncIterator,
812
Awaitable,
913
Callable,
1014
Generic,
1115
TypeAlias,
1216
TypeVar,
1317
)
1418

19+
from anyio import move_on_after
20+
1521
from ...async_api import AsyncSession
1622
from ...schemas import DictObject
1723
from ..._sdk_models import (
@@ -48,6 +54,10 @@ class AsyncSessionPlugins(AsyncSession):
4854
SendMessageCallback: TypeAlias = Callable[[DictObject], Awaitable[Any]]
4955

5056

57+
# RuntimeError subclass used to flag requests from the server that the plugin
# cannot honour. In this commit it is raised by the prompt preprocessor when a
# task ID already has a hook invocation in progress, and when the request input
# is not an instance of the expected message class.
class ServerRequestError(RuntimeError):
58+
"""Plugin received an invalid request from the API server."""
59+
60+
5161
class HookController(Generic[TPluginRequest, TPluginConfigSchema, TGlobalConfigSchema]):
5262
"""Common base class for plugin hook API access controllers."""
5363

@@ -117,3 +127,14 @@ async def notify_canceled(self, message: str) -> None:
117127
async def notify_done(self, message: str) -> None:
118128
"""Report task completion in the status block."""
119129
await self._update_ui(self._id, "done", message)
130+
131+
# Async context manager: the caller's `async with` body runs at the `yield`.
# Only asyncio.CancelledError is intercepted; `message` is forwarded to
# notify_canceled and the cancellation is then propagated again.
# NOTE(review): indentation is lost in this rendering, so the nesting described
# in the comments below is inferred from the statement order — confirm against
# the source file.
@asynccontextmanager
132+
async def notify_aborted(self, message: str) -> AsyncIterator[None]:
133+
"""Report asyncio.CancelledError as cancellation in the status block."""
134+
try:
135+
yield
136+
except asyncio.CancelledError:
137+
# Allow the notification to be sent, but don't necessarily wait for the reply
138+
# 0.2 is the delay passed to anyio's move_on_after (seconds per the anyio
# docs — confirm); shield=True is requested so the cancellation already in
# flight does not interrupt the notification call itself.
with move_on_after(0.2, shield=True):
139+
await self.notify_canceled(message)
140+
# Bare raise: presumably sits in the `except` block after the `with`, so the
# CancelledError still reaches the caller once the notification was attempted.
raise

src/lmstudio/plugin/hooks/prompt_preprocessor.py

Lines changed: 81 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -1,9 +1,13 @@
11
"""Invoking and supporting prompt preprocessor hook implementations."""
22

3+
import asyncio
4+
5+
from contextlib import asynccontextmanager
36
from dataclasses import dataclass
47
from traceback import format_tb
58
from typing import (
69
Any,
10+
AsyncIterator,
711
Awaitable,
812
Callable,
913
Generic,
@@ -13,6 +17,7 @@
1317
)
1418

1519
from anyio import create_task_group
20+
from anyio.abc import TaskGroup
1621

1722
from ..._logging import new_logger
1823
from ...schemas import DictObject, EmptyDict, ValidationError
@@ -29,6 +34,7 @@
2934
ProcessingUpdate,
3035
ProcessingUpdateStatusCreate,
3136
ProcessingUpdateStatusUpdate,
37+
PromptPreprocessingAbortedDict,
3238
PromptPreprocessingCompleteDict,
3339
PromptPreprocessingErrorDict,
3440
PromptPreprocessingRequest,
@@ -41,12 +47,12 @@
4147
AsyncSessionPlugins,
4248
HookController,
4349
SendMessageCallback,
50+
ServerRequestError,
4451
StatusBlockController,
4552
TPluginConfigSchema,
4653
TGlobalConfigSchema,
4754
)
4855

49-
5056
# Available as lmstudio.plugin.hooks.*
5157
__all__ = [
5258
"PromptPreprocessorController",
@@ -88,7 +94,7 @@ def iter_message_events(
8894
case None:
8995
# Server can only terminate the link by closing the websocket
9096
pass
91-
case {"type": "abort", "task_id": str(task_id)}:
97+
case {"type": "abort", "taskId": str(task_id)}:
9298
yield PromptPreprocessingAbortEvent(task_id)
9399
case {"type": "preprocess"} as request_dict:
94100
parsed_request = PromptPreprocessingRequest._from_any_api_dict(
@@ -101,10 +107,10 @@ def iter_message_events(
101107
def handle_rx_event(self, event: PromptPreprocessingRxEvent) -> None:
102108
match event:
103109
case PromptPreprocessingAbortEvent(task_id):
104-
self._logger.info(f"Aborting {task_id}", task_id=task_id)
110+
self._logger.debug(f"Aborting {task_id}", task_id=task_id)
105111
case PromptPreprocessingRequestEvent(request):
106112
task_id = request.task_id
107-
self._logger.info(
113+
self._logger.debug(
108114
"Received prompt preprocessing request", task_id=task_id
109115
)
110116
case ChannelFinishedEvent(_):
@@ -198,16 +204,18 @@ class PromptPreprocessor(Generic[TPluginConfigSchema, TGlobalConfigSchema]):
198204
# Dataclass post-init hook: sets up per-instance state that is not a field.
def __post_init__(self) -> None:
199205
# Logger is created for this module and tagged with the plugin name.
self._logger = logger = new_logger(__name__)
200206
logger.update_context(plugin_name=self.plugin_name)
207+
# Maps task ID -> asyncio.Event; entries are added and removed by
# _registered_hook_invocation and set by _abort_hook_invocation.
self._abort_events: dict[str, asyncio.Event] = {}
201208

202209
async def process_requests(
203210
self, session: AsyncSessionPlugins, notify_ready: Callable[[], Any]
204211
) -> None:
212+
"""Create plugin channel and wait for server requests."""
205213
logger = self._logger
206214
endpoint = PromptPreprocessingEndpoint()
207215
async with session._create_channel(endpoint) as channel:
208216
notify_ready()
209217
logger.info("Opened channel to receive prompt preprocessing requests...")
210-
send_cb = channel.send_message
218+
send_message = channel.send_message
211219
async with create_task_group() as tg:
212220
logger.debug("Waiting for prompt preprocessing requests...")
213221
async for contents in channel.rx_stream():
@@ -218,6 +226,10 @@ async def process_requests(
218226
logger.debug("Handling prompt preprocessing channel event")
219227
endpoint.handle_rx_event(event)
220228
match event:
229+
case PromptPreprocessingAbortEvent():
230+
await self._abort_hook_invocation(
231+
event.arg, send_message
232+
)
221233
case PromptPreprocessingRequestEvent():
222234
logger.debug(
223235
"Running prompt preprocessing request hook"
@@ -228,28 +240,76 @@ async def process_requests(
228240
self.plugin_config_schema,
229241
self.global_config_schema,
230242
)
231-
tg.start_soon(self._invoke_hook, ctl, send_cb)
243+
tg.start_soon(self._invoke_hook, ctl, send_message)
232244
if endpoint.is_finished:
233245
break
234246

247+
# Handles an abort request for `task_id`: signals the matching abort event (if
# one is registered) and sends an "aborted" message through `send_response`.
# NOTE(review): indentation is lost in this rendering, so it is not visible
# whether the response is built and sent only when an event was found, or
# unconditionally — confirm against the source file.
async def _abort_hook_invocation(
248+
self, task_id: str, send_response: SendMessageCallback
249+
) -> None:
250+
"""Abort the specified hook invocation (if it is still running)."""
251+
# A missing entry yields None rather than raising.
abort_event = self._abort_events.get(task_id, None)
252+
if abort_event is not None:
253+
# Setting the event wakes the _cancel_on_event task waiting on it.
abort_event.set()
254+
# The dict uses the camelCase key "taskId".
response = PromptPreprocessingAbortedDict(
255+
type="aborted",
256+
taskId=task_id,
257+
)
258+
await send_response(response)
259+
260+
# Watcher coroutine: blocks until `event` is set, logs `message` at info level,
# then cancels the cancel scope of the task group `tg` passed in.
# Started via tg.start_soon from _registered_hook_invocation, which passes its
# own task group, the per-task abort event and an "Aborting request ..." message.
async def _cancel_on_event(
261+
self, tg: TaskGroup, event: asyncio.Event, message: str
262+
) -> None:
263+
await event.wait()
264+
self._logger.info(message)
265+
tg.cancel_scope.cancel()
266+
267+
# Async context manager that registers an abort event for `task_id` in
# self._abort_events while the caller's body runs, and yields that event.
# Raises ServerRequestError if `task_id` already has a registered event.
# NOTE(review): indentation is lost in this rendering; the nesting described in
# the comments below is inferred from statement order — confirm against the
# source file (in particular which statements belong to the `finally` block).
@asynccontextmanager
268+
async def _registered_hook_invocation(
269+
self, task_id: str
270+
) -> AsyncIterator[asyncio.Event]:
271+
logger = self._logger
272+
abort_events = self._abort_events
273+
# Refuse a second concurrent invocation under the same task ID.
if task_id in abort_events:
274+
err_msg = f"Hook invocation already in progress for {task_id}"
275+
raise ServerRequestError(err_msg)
276+
abort_events[task_id] = abort_event = asyncio.Event()
277+
try:
278+
async with create_task_group() as tg:
279+
# Background watcher: cancels this task group's scope once the abort
# event is set (see _cancel_on_event).
tg.start_soon(
280+
self._cancel_on_event,
281+
tg,
282+
abort_event,
283+
f"Aborting request {task_id}",
284+
)
285+
logger.info(f"Processing request {task_id}")
286+
# The caller's `async with` body runs here.
yield abort_event
287+
# Reached when the body finishes without raising; presumably cancels the
# scope so the still-waiting watcher task stops and the group can exit.
tg.cancel_scope.cancel()
288+
finally:
289+
# Always drop the registration; pop with a default tolerates a missing key.
abort_events.pop(task_id, None)
290+
# The final log line reports whether the abort event was set.
if abort_event.is_set():
291+
completion_message = f"Aborted request {task_id}"
292+
else:
293+
completion_message = f"Processed request {task_id}"
294+
logger.info(completion_message)
295+
235296
async def _invoke_hook(
236297
self,
237298
ctl: PromptPreprocessorController[TPluginConfigSchema, TGlobalConfigSchema],
238299
send_response: SendMessageCallback,
239300
) -> None:
240301
logger = self._logger
241-
request = ctl.request
242-
message = request.input
243-
expected_cls = UserMessage
244-
if not isinstance(message, expected_cls):
245-
logger.error(
246-
f"Received {type(message).__name__!r} ({expected_cls.__name__!r} expected)"
247-
)
248-
return
302+
task_id = ctl.task_id
303+
message = ctl.request.input
249304
error_details: SerializedLMSExtendedErrorDict | None = None
250305
response_dict: UserMessageDict
306+
expected_cls = UserMessage
251307
try:
252-
response = await self.hook_impl(ctl, message)
308+
if not isinstance(message, expected_cls):
309+
err_msg = f"Received {type(message).__name__!r} ({expected_cls.__name__!r} expected)"
310+
raise ServerRequestError(err_msg)
311+
async with self._registered_hook_invocation(task_id) as abort_event:
312+
response = await self.hook_impl(ctl, message)
253313
except Exception as exc:
254314
err_msg = "Error calling prompt preprocessing hook"
255315
logger.error(err_msg, exc_info=True, exc=repr(exc))
@@ -259,8 +319,11 @@ async def _invoke_hook(
259319
cause=ui_cause, stack="\n".join(format_tb(exc.__traceback__))
260320
)
261321
else:
322+
if abort_event.is_set():
323+
# Processing was aborted by the server, skip sending a response
324+
return
262325
if response is None:
263-
# No change to message
326+
logger.debug("No changes made to preprocessed prompt")
264327
response_dict = message.to_dict()
265328
else:
266329
logger.debug(
@@ -291,13 +354,13 @@ async def _invoke_hook(
291354
error_details.update(common_error_args)
292355
channel_message = PromptPreprocessingErrorDict(
293356
type="error",
294-
taskId=request.task_id,
357+
taskId=task_id,
295358
error=error_details,
296359
)
297360
else:
298361
channel_message = PromptPreprocessingCompleteDict(
299362
type="complete",
300-
taskId=request.task_id,
363+
taskId=task_id,
301364
processed=response_dict,
302365
)
303366
await send_response(channel_message)

0 commit comments

Comments
 (0)