Commit e9755f6: Prompt preprocessing is user-messages only

1 parent e7782ac

3 files changed (+20, -19 lines)


examples/plugins/prompt-prefix/src/plugin.py

Lines changed: 3 additions & 5 deletions
```diff
@@ -3,7 +3,7 @@
 import asyncio
 
 from lmstudio.plugin import BaseConfigSchema, PromptPreprocessorController, config_field
-from lmstudio import AnyChatMessage, AnyChatMessageDict, TextDataDict
+from lmstudio import UserMessage, UserMessageDict, TextDataDict
 
 
 # Assigning ConfigSchema = SomeOtherSchemaClass also works
@@ -36,11 +36,9 @@ class GlobalConfigSchema(BaseConfigSchema):
 # Assigning preprocess_prompt = some_other_callable also works
 async def preprocess_prompt(
     ctl: PromptPreprocessorController[ConfigSchema, GlobalConfigSchema],
-    message: AnyChatMessage,
-) -> AnyChatMessageDict | None:
+    message: UserMessage,
+) -> UserMessageDict | None:
     """Naming the function 'preprocess_prompt' implicitly registers it."""
-    if message.role != "user":
-        return None
     print(f"Running prompt preprocessor hook from {__file__} with {ctl.plugin_config}")
     if ctl.global_config.enable_inplace_status_demo:
         # Run an in-place status prompt update demonstration
```
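With this change a hook no longer needs to filter on `message.role`, since only user messages are delivered to it. For reference, a minimal hook under the new signature might look like the sketch below. This is illustrative only and not part of this commit: the prefix text and the assumed dict layout (a "content" list of TextDataDict parts) are assumptions.

```python
# Illustrative sketch (not part of this commit): a prompt-prefix style hook
# written against the new user-messages-only contract.
from lmstudio.plugin import BaseConfigSchema, PromptPreprocessorController
from lmstudio import UserMessage, UserMessageDict, TextDataDict


class ConfigSchema(BaseConfigSchema):
    pass


class GlobalConfigSchema(BaseConfigSchema):
    pass


async def preprocess_prompt(
    ctl: PromptPreprocessorController[ConfigSchema, GlobalConfigSchema],
    message: UserMessage,
) -> UserMessageDict | None:
    """Only user messages reach the hook now, so no role check is required."""
    # Assumed layout: UserMessageDict exposes a "content" list whose text
    # parts follow TextDataDict ({"type": "text", "text": ...}).
    prefix: TextDataDict = {"type": "text", "text": "[prefixed by plugin] "}
    updated: UserMessageDict = message.to_dict()
    updated["content"] = [prefix, *updated["content"]]
    return updated
```

Returning `None` instead would leave the prompt unchanged, as shown in the hook runner diff further below.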

src/lmstudio/history.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -59,6 +59,7 @@
     ToolCallResultDataDict,
     ToolResultMessage,
     UserMessage,
+    UserMessageDict,
 )
 
 __all__ = [
@@ -84,6 +85,7 @@
     "ToolResultMessage",
     "UserMessage",
     "UserMessageContent",
+    "UserMessageDict",
 ]
 
 # A note on terminology:
```
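The only change here is re-exporting UserMessageDict, so plugin code (like the example above) can annotate hook return values from the public package. A minimal illustration, assuming the rest of the public API is unchanged:

```python
# UserMessageDict is now importable alongside UserMessage from the top-level
# package, so annotations don't need to reach into lmstudio.history internals.
from lmstudio import UserMessage, UserMessageDict


def echo_as_dict(message: UserMessage) -> UserMessageDict:
    # to_dict() is the serialisation helper used throughout the diff below
    return message.to_dict()
```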

src/lmstudio/plugin/hooks/prompt_preprocessor.py

Lines changed: 15 additions & 14 deletions
```diff
@@ -11,14 +11,13 @@
     Iterable,
     TypeAlias,
     assert_never,
-    get_args as get_type_args,
 )
 
 from anyio import create_task_group
 
 from ..._logging import get_logger
 from ...schemas import DictObject, EmptyDict, ValidationError
-from ...history import AnyChatMessage, AnyChatMessageDict
+from ...history import UserMessage, UserMessageDict
 from ...json_api import (
     ChannelCommonRxEvent,
     ChannelEndpoint,
@@ -182,8 +181,8 @@ async def notify_done(self, message: str) -> None:
 
 
 PromptPreprocessorHook = Callable[
-    [PromptPreprocessorController[Any, Any], AnyChatMessage],
-    Awaitable[AnyChatMessage | AnyChatMessageDict | None],
+    [PromptPreprocessorController[Any, Any], UserMessage],
+    Awaitable[UserMessage | UserMessageDict | None],
 ]
 
 
@@ -203,11 +202,17 @@ async def run_prompt_preprocessor(
 
     async def _invoke_hook(request: PromptPreprocessingRequest) -> None:
         message = request.input
+        expected_cls = UserMessage
+        if not isinstance(message, expected_cls):
+            logger.error(
+                f"Received {type(message).__name__!r} ({expected_cls.__name__!r} expected)"
+            )
+            return
         hook_controller = PromptPreprocessorController(
             session, request, plugin_config_schema, global_config_schema
         )
         error_details: SerializedLMSExtendedErrorDict | None = None
-        response_dict: AnyChatMessageDict
+        response_dict: UserMessageDict
         try:
             response = await hook_impl(hook_controller, message)
         except asyncio.CancelledError:
@@ -227,27 +232,23 @@ async def _invoke_hook(request: PromptPreprocessingRequest) -> None:
                 response_dict = message.to_dict()
             else:
                 logger.debug(
-                    f"Validating prompt preprocessing response: {response!r}"
+                    "Validating prompt preprocessing response", response=response
                 )
-                response_cls = type(message)
                 if isinstance(response, dict):
-                    # Parse the response to ensure validity client side,
-                    # otherwise serialising the message may fail and crash the plugin
                     try:
-                        # Response should have the same role as the received message
-                        parsed_response = load_struct(response, response_cls)
+                        parsed_response = load_struct(response, expected_cls)
                     except ValidationError as exc:
-                        err_msg = f"Failed to parse prompt preprocessing response as {response_cls.__name__}"
+                        err_msg = f"Failed to parse prompt preprocessing response as {expected_cls.__name__}"
                         logger.error(err_msg, exc_info=True, exc=repr(exc))
                         error_details = SerializedLMSExtendedErrorDict(
                             cause=err_msg
                         )
                     else:
                         response_dict = parsed_response.to_dict()
-                elif isinstance(response, get_type_args(AnyChatMessage)):
+                elif isinstance(response, UserMessage):
                     response_dict = response.to_dict()
                 else:
-                    err_msg = f"Prompt preprocessing hook returned {type(response).__name__!r} ({response_cls.__name__!r} expected)"
+                    err_msg = f"Prompt preprocessing hook returned {type(response).__name__!r} ({expected_cls.__name__!r} expected)"
                     logger.error(err_msg)
                     error_details = SerializedLMSExtendedErrorDict(cause=err_msg)
         channel_message: DictObject
```
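For readers skimming the diff: the hook runner now normalises every response to a UserMessageDict, validating dicts against UserMessage, serialising UserMessage instances directly, and falling back to the original message on None. The sketch below is a simplified, standalone approximation; load_struct is the SDK-internal parser used above (its import path is not part of this commit, so it is passed as a parameter here), and error reporting is reduced to a plain exception instead of SerializedLMSExtendedErrorDict.

```python
# Simplified approximation of the response handling in _invoke_hook (illustrative only).
from typing import Callable

from lmstudio import UserMessage, UserMessageDict


def normalize_hook_response(
    original: UserMessage,
    response: UserMessage | UserMessageDict | None,
    load_struct: Callable[[dict, type[UserMessage]], UserMessage],
) -> UserMessageDict:
    if response is None:
        # None means "leave the prompt unchanged": echo the original message back
        return original.to_dict()
    if isinstance(response, dict):
        # Validate dict responses client side so that serialising an invalid
        # message cannot crash the plugin (invalid dicts raise upstream).
        return load_struct(response, UserMessage).to_dict()
    if isinstance(response, UserMessage):
        return response.to_dict()
    raise TypeError(
        f"Prompt preprocessing hook returned {type(response).__name__!r} "
        "('UserMessage' expected)"
    )
```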
