Skip to content

Commit 048fbc5

Browse files
committed
Move plugin hook management to a class
1 parent 0d9062f commit 048fbc5

File tree

3 files changed

+131
-93
lines changed

3 files changed

+131
-93
lines changed

src/lmstudio/plugin/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
# * [DONE] add a global plugin config to control the in-place status update demo
3636
# * handle "Abort" requests from server (including sending "Aborted" responses)
3737
# * [DONE] catch hook invocation failures and send "Error" responses
38-
# * [partial] this includes adding runtime checks for the hook returning the wrong type
38+
# * [DONE] this includes adding runtime checks for the hook returning the wrong type
3939
#
4040
# Token generator hook
4141
# * add an example plugin for this (probably proxying a remote LM Studio instance)

src/lmstudio/plugin/hooks/common.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,15 @@
55
from random import randrange
66
from typing import (
77
Any,
8+
Awaitable,
89
Callable,
910
Generic,
1011
TypeAlias,
1112
TypeVar,
1213
)
1314

1415
from ...async_api import AsyncSession
16+
from ...schemas import DictObject
1517
from ..._sdk_models import (
1618
# TODO: Define aliases at schema generation time
1719
PluginsChannelSetGeneratorToClientPacketGenerate as TokenGenerationRequest,
@@ -43,6 +45,7 @@ class AsyncSessionPlugins(AsyncSession):
4345
TPluginConfigSchema = TypeVar("TPluginConfigSchema", bound=BaseConfigSchema)
4446
TGlobalConfigSchema = TypeVar("TGlobalConfigSchema", bound=BaseConfigSchema)
4547
TConfig = TypeVar("TConfig", bound=BaseConfigSchema)
48+
SendMessageCallback: TypeAlias = Callable[[DictObject], Awaitable[Any]]
4649

4750

4851
class HookController(Generic[TPluginRequest, TPluginConfigSchema, TGlobalConfigSchema]):
@@ -57,6 +60,7 @@ def __init__(
5760
) -> None:
5861
"""Initialize common hook controller settings."""
5962
self.session = session
63+
self.request = request
6064
self.plugin_config = self._parse_config(
6165
request.plugin_config, plugin_config_schema
6266
)

src/lmstudio/plugin/hooks/prompt_preprocessor.py

Lines changed: 126 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
"""Invoking and supporting prompt preprocessor hook implementations."""
22

3+
from dataclasses import dataclass
34
from traceback import format_tb
4-
55
from typing import (
66
Any,
77
Awaitable,
88
Callable,
9+
Generic,
910
Iterable,
1011
TypeAlias,
1112
assert_never,
@@ -39,6 +40,7 @@
3940
from .common import (
4041
AsyncSessionPlugins,
4142
HookController,
43+
SendMessageCallback,
4244
StatusBlockController,
4345
TPluginConfigSchema,
4446
TGlobalConfigSchema,
@@ -125,6 +127,7 @@ def __init__(
125127
) -> None:
126128
"""Initialize prompt preprocessor hook controller."""
127129
super().__init__(session, request, plugin_config_schema, global_config_schema)
130+
self.task_id = request.task_id
128131
self.pci = request.pci
129132
self.token = request.token
130133

@@ -182,6 +185,124 @@ async def notify_done(self, message: str) -> None:
182185
]
183186

184187

188+
# TODO: Define a common "PluginHookHandler" base class
189+
@dataclass()
190+
class PromptPreprocessor(Generic[TPluginConfigSchema, TGlobalConfigSchema]):
191+
"""Handle accepting prompt preprocessing requests."""
192+
193+
plugin_name: str
194+
hook_impl: PromptPreprocessorHook
195+
plugin_config_schema: type[TPluginConfigSchema]
196+
global_config_schema: type[TGlobalConfigSchema]
197+
198+
def __post_init__(self) -> None:
199+
self._logger = logger = new_logger(__name__)
200+
logger.update_context(plugin_name=self.plugin_name)
201+
202+
async def process_requests(
203+
self, session: AsyncSessionPlugins, notify_ready: Callable[[], Any]
204+
) -> None:
205+
logger = self._logger
206+
endpoint = PromptPreprocessingEndpoint()
207+
async with session._create_channel(endpoint) as channel:
208+
notify_ready()
209+
logger.info("Opened channel to receive prompt preprocessing requests...")
210+
send_cb = channel.send_message
211+
async with create_task_group() as tg:
212+
logger.debug("Waiting for prompt preprocessing requests...")
213+
async for contents in channel.rx_stream():
214+
logger.debug(
215+
f"Handling prompt preprocessing channel message: {contents}"
216+
)
217+
for event in endpoint.iter_message_events(contents):
218+
logger.debug("Handling prompt preprocessing channel event")
219+
endpoint.handle_rx_event(event)
220+
match event:
221+
case PromptPreprocessingRequestEvent():
222+
logger.debug(
223+
"Running prompt preprocessing request hook"
224+
)
225+
ctl = PromptPreprocessorController(
226+
session,
227+
event.arg,
228+
self.plugin_config_schema,
229+
self.global_config_schema,
230+
)
231+
tg.start_soon(self._invoke_hook, ctl, send_cb)
232+
if endpoint.is_finished:
233+
break
234+
235+
async def _invoke_hook(
236+
self,
237+
ctl: PromptPreprocessorController[TPluginConfigSchema, TGlobalConfigSchema],
238+
send_response: SendMessageCallback,
239+
) -> None:
240+
logger = self._logger
241+
request = ctl.request
242+
message = request.input
243+
expected_cls = UserMessage
244+
if not isinstance(message, expected_cls):
245+
logger.error(
246+
f"Received {type(message).__name__!r} ({expected_cls.__name__!r} expected)"
247+
)
248+
return
249+
error_details: SerializedLMSExtendedErrorDict | None = None
250+
response_dict: UserMessageDict
251+
try:
252+
response = await self.hook_impl(ctl, message)
253+
except Exception as exc:
254+
err_msg = "Error calling prompt preprocessing hook"
255+
logger.error(err_msg, exc_info=True, exc=repr(exc))
256+
# TODO: Determine if it's worth sending the stack trace to the server
257+
ui_cause = f"{err_msg}\n({type(exc).__name__}: {exc})"
258+
error_details = SerializedLMSExtendedErrorDict(
259+
cause=ui_cause, stack="\n".join(format_tb(exc.__traceback__))
260+
)
261+
else:
262+
if response is None:
263+
# No change to message
264+
response_dict = message.to_dict()
265+
else:
266+
logger.debug(
267+
"Validating prompt preprocessing response", response=response
268+
)
269+
if isinstance(response, dict):
270+
try:
271+
parsed_response = load_struct(response, expected_cls)
272+
except ValidationError as exc:
273+
err_msg = f"Failed to parse prompt preprocessing response as {expected_cls.__name__}\n({exc})"
274+
logger.error(err_msg)
275+
error_details = SerializedLMSExtendedErrorDict(cause=err_msg)
276+
else:
277+
response_dict = parsed_response.to_dict()
278+
elif isinstance(response, UserMessage):
279+
response_dict = response.to_dict()
280+
else:
281+
err_msg = f"Prompt preprocessing hook returned {type(response).__name__!r} ({expected_cls.__name__!r} expected)"
282+
logger.error(err_msg)
283+
error_details = SerializedLMSExtendedErrorDict(cause=err_msg)
284+
channel_message: DictObject
285+
if error_details is not None:
286+
error_title = f"Prompt preprocessing error in plugin {self.plugin_name!r}"
287+
common_error_args: SerializedLMSExtendedErrorDict = {
288+
"title": error_title,
289+
"rootTitle": error_title,
290+
}
291+
error_details.update(common_error_args)
292+
channel_message = PromptPreprocessingErrorDict(
293+
type="error",
294+
taskId=request.task_id,
295+
error=error_details,
296+
)
297+
else:
298+
channel_message = PromptPreprocessingCompleteDict(
299+
type="complete",
300+
taskId=request.task_id,
301+
processed=response_dict,
302+
)
303+
await send_response(channel_message)
304+
305+
185306
async def run_prompt_preprocessor(
186307
plugin_name: str,
187308
hook_impl: PromptPreprocessorHook,
@@ -191,94 +312,7 @@ async def run_prompt_preprocessor(
191312
notify_ready: Callable[[], Any],
192313
) -> None:
193314
"""Accept prompt preprocessing requests."""
194-
logger = new_logger(__name__)
195-
logger.update_context(plugin_name=plugin_name)
196-
endpoint = PromptPreprocessingEndpoint()
197-
async with session._create_channel(endpoint) as channel:
198-
notify_ready()
199-
logger.info("Opened channel to receive prompt preprocessing requests...")
200-
201-
async def _invoke_hook(request: PromptPreprocessingRequest) -> None:
202-
message = request.input
203-
expected_cls = UserMessage
204-
if not isinstance(message, expected_cls):
205-
logger.error(
206-
f"Received {type(message).__name__!r} ({expected_cls.__name__!r} expected)"
207-
)
208-
return
209-
hook_controller = PromptPreprocessorController(
210-
session, request, plugin_config_schema, global_config_schema
211-
)
212-
error_details: SerializedLMSExtendedErrorDict | None = None
213-
response_dict: UserMessageDict
214-
try:
215-
response = await hook_impl(hook_controller, message)
216-
except Exception as exc:
217-
err_msg = "Error calling prompt preprocessing hook"
218-
logger.error(err_msg, exc_info=True, exc=repr(exc))
219-
# TODO: Determine if it's worth sending the stack trace to the server
220-
ui_cause = f"{err_msg}\n({type(exc).__name__}: {exc})"
221-
error_details = SerializedLMSExtendedErrorDict(
222-
cause=ui_cause, stack="\n".join(format_tb(exc.__traceback__))
223-
)
224-
else:
225-
if response is None:
226-
# No change to message
227-
response_dict = message.to_dict()
228-
else:
229-
logger.debug(
230-
"Validating prompt preprocessing response", response=response
231-
)
232-
if isinstance(response, dict):
233-
try:
234-
parsed_response = load_struct(response, expected_cls)
235-
except ValidationError as exc:
236-
err_msg = f"Failed to parse prompt preprocessing response as {expected_cls.__name__}\n({exc})"
237-
logger.error(err_msg)
238-
error_details = SerializedLMSExtendedErrorDict(
239-
cause=err_msg
240-
)
241-
else:
242-
response_dict = parsed_response.to_dict()
243-
elif isinstance(response, UserMessage):
244-
response_dict = response.to_dict()
245-
else:
246-
err_msg = f"Prompt preprocessing hook returned {type(response).__name__!r} ({expected_cls.__name__!r} expected)"
247-
logger.error(err_msg)
248-
error_details = SerializedLMSExtendedErrorDict(cause=err_msg)
249-
channel_message: DictObject
250-
if error_details is not None:
251-
error_title = f"Prompt preprocessing error in plugin {plugin_name!r}"
252-
common_error_args: SerializedLMSExtendedErrorDict = {
253-
"title": error_title,
254-
"rootTitle": error_title,
255-
}
256-
error_details.update(common_error_args)
257-
channel_message = PromptPreprocessingErrorDict(
258-
type="error",
259-
taskId=request.task_id,
260-
error=error_details,
261-
)
262-
else:
263-
channel_message = PromptPreprocessingCompleteDict(
264-
type="complete",
265-
taskId=request.task_id,
266-
processed=response_dict,
267-
)
268-
await channel.send_message(channel_message)
269-
270-
async with create_task_group() as tg:
271-
logger.debug("Waiting for prompt preprocessing requests...")
272-
async for contents in channel.rx_stream():
273-
logger.debug(
274-
f"Handling prompt preprocessing channel message: {contents}"
275-
)
276-
for event in endpoint.iter_message_events(contents):
277-
logger.debug("Handling prompt preprocessing channel event")
278-
endpoint.handle_rx_event(event)
279-
match event:
280-
case PromptPreprocessingRequestEvent():
281-
logger.debug("Running prompt preprocessing request hook")
282-
tg.start_soon(_invoke_hook, event.arg)
283-
if endpoint.is_finished:
284-
break
315+
prompt_preprocessor = PromptPreprocessor(
316+
plugin_name, hook_impl, plugin_config_schema, global_config_schema
317+
)
318+
await prompt_preprocessor.process_requests(session, notify_ready)

0 commit comments

Comments (0)