diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 82fa228..2753dff 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -41,7 +41,8 @@ jobs: max-parallel: 8 matrix: python-version: ["3.10", "3.11", "3.12", "3.13"] - # There's no platform specific SDK code, but explicitly check Windows anyway + # There's no platform specific SDK code, but explicitly check Windows + # to ensure there aren't any inadvertent POSIX-only assumptions os: [ubuntu-22.04, windows-2022] # Check https://github.com/actions/action-versions/tree/main/config/actions diff --git a/examples/plugins/dice-tool/README.md b/examples/plugins/dice-tool/README.md new file mode 100644 index 0000000..2d3e1f2 --- /dev/null +++ b/examples/plugins/dice-tool/README.md @@ -0,0 +1,3 @@ +# `lmstudio/pydice` + +TODO: Example Python tools provider plugin diff --git a/examples/plugins/dice-tool/manifest.json b/examples/plugins/dice-tool/manifest.json new file mode 100644 index 0000000..787fcea --- /dev/null +++ b/examples/plugins/dice-tool/manifest.json @@ -0,0 +1,7 @@ +{ + "type": "plugin", + "runner": "python", + "owner": "lmstudio", + "name": "py-dice-tool", + "revision": 2 +} diff --git a/examples/plugins/dice-tool/src/plugin.py b/examples/plugins/dice-tool/src/plugin.py new file mode 100644 index 0000000..ea2edf2 --- /dev/null +++ b/examples/plugins/dice-tool/src/plugin.py @@ -0,0 +1,3 @@ +"""Example plugin that provide dice rolling tools.""" + +# Not yet implemented, currently used to check plugins with no hooks defined diff --git a/examples/plugins/prompt-prefix/README.md b/examples/plugins/prompt-prefix/README.md new file mode 100644 index 0000000..03fec2f --- /dev/null +++ b/examples/plugins/prompt-prefix/README.md @@ -0,0 +1,6 @@ +# `lmstudio/pyprompt` + +Python prompt preprocessing plugin example + +Note: there's no `python` runner in LM Studio yet, so use +`python -m lmstudio.plugin --dev path/to/plugin` to run a dev instance diff --git 
a/examples/plugins/prompt-prefix/manifest.json b/examples/plugins/prompt-prefix/manifest.json new file mode 100644 index 0000000..fc376f2 --- /dev/null +++ b/examples/plugins/prompt-prefix/manifest.json @@ -0,0 +1,7 @@ +{ + "type": "plugin", + "runner": "python", + "owner": "lmstudio", + "name": "prompt-prefix", + "revision": 1 +} diff --git a/examples/plugins/prompt-prefix/src/plugin.py b/examples/plugins/prompt-prefix/src/plugin.py new file mode 100644 index 0000000..4521ebb --- /dev/null +++ b/examples/plugins/prompt-prefix/src/plugin.py @@ -0,0 +1,76 @@ +"""Example plugin that adds a prefix to all user prompts.""" + +import asyncio + +from lmstudio.plugin import BaseConfigSchema, PromptPreprocessorController, config_field +from lmstudio import UserMessage, UserMessageDict, TextDataDict + + +# Assigning ConfigSchema = SomeOtherSchemaClass also works +class ConfigSchema(BaseConfigSchema): + """The name 'ConfigSchema' implicitly registers this as the per-chat plugin config schema.""" + + prefix: str = config_field( + label="Prefix to insert", + hint="This text will be inserted at the start of all user prompts", + default="And now for something completely different: ", + ) + + +# Assigning GlobalConfigSchema = SomeOtherGlobalSchemaClass also works +class GlobalConfigSchema(BaseConfigSchema): + """The name 'GlobalConfigSchema' implicitly registers this as the global plugin config schema.""" + + enable_inplace_status_demo: bool = config_field( + label="Enable in-place status demo", + hint="The plugin will run an in-place task status updating demo when invoked", + default=True, + ) + inplace_status_duration: float = config_field( + label="In-place status total duration (s)", + hint="The number of seconds to spend displaying the in-place task status update", + default=5.0, + ) + + +# Assigning preprocess_prompt = some_other_callable also works +async def preprocess_prompt( + ctl: PromptPreprocessorController[ConfigSchema, GlobalConfigSchema], + message: UserMessage, +) 
-> UserMessageDict | None: + """Naming the function 'preprocess_prompt' implicitly registers it.""" + if ctl.global_config.enable_inplace_status_demo: + # Run an in-place status prompt update demonstration + status_block = await ctl.notify_start("Starting task (shows a static icon).") + status_updates = ( + (status_block.notify_working, "Task in progress (shows a dynamic icon)."), + (status_block.notify_waiting, "Task is blocked (shows a static icon)."), + (status_block.notify_error, "Reporting an error status."), + (status_block.notify_canceled, "Reporting cancellation."), + ( + status_block.notify_done, + "In-place status update demonstration completed.", + ), + ) + status_duration = ctl.global_config.inplace_status_duration / len( + status_updates + ) + async with status_block.notify_aborted("Task genuinely cancelled."): + for notification, status_text in status_updates: + await asyncio.sleep(status_duration) + await notification(status_text) + + modified_message = message.to_dict() + # Add a prefix to all user messages + prefix_text = ctl.plugin_config.prefix + prefix: TextDataDict = { + "type": "text", + "text": prefix_text, + } + modified_message["content"] = [prefix, *modified_message["content"]] + # Demonstrate simple completion status reporting for non-blocking operations + await ctl.notify_done(f"Added prefix {prefix_text!r} to user message.") + return modified_message + + +print(f"{__name__} initialized from {__file__}") diff --git a/sdk-schema/sync-sdk-schema.py b/sdk-schema/sync-sdk-schema.py index 255483b..51f7a3e 100755 --- a/sdk-schema/sync-sdk-schema.py +++ b/sdk-schema/sync-sdk-schema.py @@ -363,6 +363,18 @@ def _infer_schema_unions() -> None: "LlmChannelPredictCreationParameterDict": "PredictionChannelRequestDict", "RepositoryChannelDownloadModelCreationParameter": "DownloadModelChannelRequest", "RepositoryChannelDownloadModelCreationParameterDict": "DownloadModelChannelRequestDict", + # Prettier plugin channel message names + 
"PluginsChannelSetPromptPreprocessorToClientPacketPreprocess": "PromptPreprocessingRequest", + "PluginsChannelSetPromptPreprocessorToClientPacketPreprocessDict": "PromptPreprocessingRequestDict", + "PluginsChannelSetPromptPreprocessorToServerPacketAborted": "PromptPreprocessingAborted", + "PluginsChannelSetPromptPreprocessorToServerPacketAbortedDict": "PromptPreprocessingAbortedDict", + "PluginsChannelSetPromptPreprocessorToServerPacketComplete": "PromptPreprocessingComplete", + "PluginsChannelSetPromptPreprocessorToServerPacketCompleteDict": "PromptPreprocessingCompleteDict", + "PluginsChannelSetPromptPreprocessorToServerPacketError": "PromptPreprocessingError", + "PluginsChannelSetPromptPreprocessorToServerPacketErrorDict": "PromptPreprocessingErrorDict", + # Prettier config handling type names + "LlmRpcGetLoadConfigReturns": "SerializedKVConfigSettings", + "LlmRpcGetLoadConfigReturnsDict": "SerializedKVConfigSettingsDict", } diff --git a/src/lmstudio/_kv_config.py b/src/lmstudio/_kv_config.py index a55e7b4..f010290 100644 --- a/src/lmstudio/_kv_config.py +++ b/src/lmstudio/_kv_config.py @@ -40,6 +40,7 @@ LlmSplitStrategy, LlmStructuredPredictionSetting, LlmStructuredPredictionSettingDict, + SerializedKVConfigSettings, ) @@ -324,7 +325,7 @@ def _invert_config_keymap(from_server: FromServerKeymap) -> ToServerKeymap: ) -def dict_from_kvconfig(config: KvConfig) -> DictObject: +def dict_from_kvconfig(config: KvConfig | SerializedKVConfigSettings) -> DictObject: return {kv.key: kv.value for kv in config.fields} diff --git a/src/lmstudio/_sdk_models/__init__.py b/src/lmstudio/_sdk_models/__init__.py index bd4e319..5c56e37 100644 --- a/src/lmstudio/_sdk_models/__init__.py +++ b/src/lmstudio/_sdk_models/__init__.py @@ -324,8 +324,6 @@ "LlmRpcCountTokensReturnsDict", "LlmRpcGetLoadConfigParameter", "LlmRpcGetLoadConfigParameterDict", - "LlmRpcGetLoadConfigReturns", - "LlmRpcGetLoadConfigReturnsDict", "LlmRpcGetModelInfoParameter", "LlmRpcGetModelInfoParameterDict", 
"LlmRpcPreloadDraftModelParameter", @@ -422,14 +420,6 @@ "PluginsChannelSetPredictionLoopHandlerToServerPacketErrorDict", "PluginsChannelSetPromptPreprocessorToClientPacketAbort", "PluginsChannelSetPromptPreprocessorToClientPacketAbortDict", - "PluginsChannelSetPromptPreprocessorToClientPacketPreprocess", - "PluginsChannelSetPromptPreprocessorToClientPacketPreprocessDict", - "PluginsChannelSetPromptPreprocessorToServerPacketAborted", - "PluginsChannelSetPromptPreprocessorToServerPacketAbortedDict", - "PluginsChannelSetPromptPreprocessorToServerPacketComplete", - "PluginsChannelSetPromptPreprocessorToServerPacketCompleteDict", - "PluginsChannelSetPromptPreprocessorToServerPacketError", - "PluginsChannelSetPromptPreprocessorToServerPacketErrorDict", "PluginsChannelSetToolsProviderToClientPacketAbortToolCall", "PluginsChannelSetToolsProviderToClientPacketAbortToolCallDict", "PluginsChannelSetToolsProviderToClientPacketCallTool", @@ -530,6 +520,14 @@ "ProcessingUpdateToolStatusCreateDict", "ProcessingUpdateToolStatusUpdate", "ProcessingUpdateToolStatusUpdateDict", + "PromptPreprocessingAborted", + "PromptPreprocessingAbortedDict", + "PromptPreprocessingComplete", + "PromptPreprocessingCompleteDict", + "PromptPreprocessingError", + "PromptPreprocessingErrorDict", + "PromptPreprocessingRequest", + "PromptPreprocessingRequestDict", "PseudoDiagnostics", "PseudoDiagnosticsChannelStreamLogs", "PseudoDiagnosticsChannelStreamLogsDict", @@ -731,6 +729,8 @@ "SerializedKVConfigSchematicsDict", "SerializedKVConfigSchematicsField", "SerializedKVConfigSchematicsFieldDict", + "SerializedKVConfigSettings", + "SerializedKVConfigSettingsDict", "SerializedLMSExtendedError", "SerializedLMSExtendedErrorDict", "StatusStepState", @@ -5061,8 +5061,8 @@ class PluginsChannelSetPromptPreprocessorToClientPacketAbortDict(TypedDict): taskId: str -class PluginsChannelSetPromptPreprocessorToServerPacketAborted( - LMStudioStruct["PluginsChannelSetPromptPreprocessorToServerPacketAbortedDict"], +class 
PromptPreprocessingAborted( + LMStudioStruct["PromptPreprocessingAbortedDict"], kw_only=True, tag_field="type", tag="aborted", @@ -5071,7 +5071,7 @@ class PluginsChannelSetPromptPreprocessorToServerPacketAborted( task_id: str = field(name="taskId") -class PluginsChannelSetPromptPreprocessorToServerPacketAbortedDict(TypedDict): +class PromptPreprocessingAbortedDict(TypedDict): """Corresponding typed dictionary definition for PluginsChannelSetPromptPreprocessorToServerPacketAborted. NOTE: Multi-word keys are defined using their camelCase form, @@ -5082,8 +5082,8 @@ class PluginsChannelSetPromptPreprocessorToServerPacketAbortedDict(TypedDict): taskId: str -class PluginsChannelSetPromptPreprocessorToServerPacketError( - LMStudioStruct["PluginsChannelSetPromptPreprocessorToServerPacketErrorDict"], +class PromptPreprocessingError( + LMStudioStruct["PromptPreprocessingErrorDict"], kw_only=True, tag_field="type", tag="error", @@ -5093,7 +5093,7 @@ class PluginsChannelSetPromptPreprocessorToServerPacketError( error: SerializedLMSExtendedError -class PluginsChannelSetPromptPreprocessorToServerPacketErrorDict(TypedDict): +class PromptPreprocessingErrorDict(TypedDict): """Corresponding typed dictionary definition for PluginsChannelSetPromptPreprocessorToServerPacketError. NOTE: Multi-word keys are defined using their camelCase form, @@ -7069,13 +7069,13 @@ class PseudoLlmRpcListLoadedDict(TypedDict): parameter: NotRequired[LlmRpcListLoadedParameter | None] -class LlmRpcGetLoadConfigReturns( - LMStudioStruct["LlmRpcGetLoadConfigReturnsDict"], kw_only=True +class SerializedKVConfigSettings( + LMStudioStruct["SerializedKVConfigSettingsDict"], kw_only=True ): fields: Fields -class LlmRpcGetLoadConfigReturnsDict(TypedDict): +class SerializedKVConfigSettingsDict(TypedDict): """Corresponding typed dictionary definition for LlmRpcGetLoadConfigReturns. 
NOTE: Multi-word keys are defined using their camelCase form, @@ -7864,8 +7864,8 @@ class LlmChannelPredictToClientPacketSuccess( type: ClassVar[Annotated[Literal["success"], Meta(title="Type")]] = "success" stats: LlmPredictionStats model_info: LlmInstanceInfo = field(name="modelInfo") - load_model_config: LlmRpcGetLoadConfigReturns = field(name="loadModelConfig") - prediction_config: LlmRpcGetLoadConfigReturns = field(name="predictionConfig") + load_model_config: SerializedKVConfigSettings = field(name="loadModelConfig") + prediction_config: SerializedKVConfigSettings = field(name="predictionConfig") class LlmChannelPredictToClientPacketSuccessDict(TypedDict): @@ -7878,8 +7878,8 @@ class LlmChannelPredictToClientPacketSuccessDict(TypedDict): type: Literal["success"] stats: LlmPredictionStatsDict modelInfo: LlmInstanceInfoDict - loadModelConfig: LlmRpcGetLoadConfigReturnsDict - predictionConfig: LlmRpcGetLoadConfigReturnsDict + loadModelConfig: SerializedKVConfigSettingsDict + predictionConfig: SerializedKVConfigSettingsDict class LlmChannelGenerateWithGeneratorToClientPacketToolCallGenerationEnd( @@ -7919,9 +7919,9 @@ class PluginsChannelSetPredictionLoopHandlerToClientPacketHandlePredictionLoop( "handlePredictionLoop" ) task_id: str = field(name="taskId") - config: LlmRpcGetLoadConfigReturns - plugin_config: LlmRpcGetLoadConfigReturns = field(name="pluginConfig") - global_plugin_config: LlmRpcGetLoadConfigReturns = field(name="globalPluginConfig") + config: SerializedKVConfigSettings + plugin_config: SerializedKVConfigSettings = field(name="pluginConfig") + global_plugin_config: SerializedKVConfigSettings = field(name="globalPluginConfig") working_directory_path: str | None = field(name="workingDirectoryPath") pci: str token: str @@ -7938,9 +7938,9 @@ class PluginsChannelSetPredictionLoopHandlerToClientPacketHandlePredictionLoopDi type: Literal["handlePredictionLoop"] taskId: str - config: LlmRpcGetLoadConfigReturnsDict - pluginConfig: 
LlmRpcGetLoadConfigReturnsDict - globalPluginConfig: LlmRpcGetLoadConfigReturnsDict + config: SerializedKVConfigSettingsDict + pluginConfig: SerializedKVConfigSettingsDict + globalPluginConfig: SerializedKVConfigSettingsDict workingDirectoryPath: NotRequired[str | None] pci: str token: str @@ -7955,8 +7955,8 @@ class PluginsChannelSetToolsProviderToClientPacketInitSession( type: ClassVar[Annotated[Literal["initSession"], Meta(title="Type")]] = ( "initSession" ) - plugin_config: LlmRpcGetLoadConfigReturns = field(name="pluginConfig") - global_plugin_config: LlmRpcGetLoadConfigReturns = field(name="globalPluginConfig") + plugin_config: SerializedKVConfigSettings = field(name="pluginConfig") + global_plugin_config: SerializedKVConfigSettings = field(name="globalPluginConfig") working_directory_path: str | None = field(name="workingDirectoryPath") session_id: str = field(name="sessionId") @@ -7969,8 +7969,8 @@ class PluginsChannelSetToolsProviderToClientPacketInitSessionDict(TypedDict): """ type: Literal["initSession"] - pluginConfig: LlmRpcGetLoadConfigReturnsDict - globalPluginConfig: LlmRpcGetLoadConfigReturnsDict + pluginConfig: SerializedKVConfigSettingsDict + globalPluginConfig: SerializedKVConfigSettingsDict workingDirectoryPath: NotRequired[str | None] sessionId: str @@ -8830,7 +8830,7 @@ class PseudoLlmRpcGetLoadConfig( LMStudioStruct["PseudoLlmRpcGetLoadConfigDict"], kw_only=True ): parameter: LlmRpcGetLoadConfigParameter - returns: LlmRpcGetLoadConfigReturns + returns: SerializedKVConfigSettings class PseudoLlmRpcGetLoadConfigDict(TypedDict): @@ -8841,7 +8841,7 @@ class PseudoLlmRpcGetLoadConfigDict(TypedDict): """ parameter: LlmRpcGetLoadConfigParameterDict - returns: LlmRpcGetLoadConfigReturnsDict + returns: SerializedKVConfigSettingsDict class LlmRpcTokenizeParameter( @@ -9525,8 +9525,8 @@ class PseudoPluginsChannelSetToolsProviderDict(TypedDict): toServerPacket: PluginsChannelSetToolsProviderToServerPacketDict -class 
PluginsChannelSetPromptPreprocessorToClientPacketPreprocess( - LMStudioStruct["PluginsChannelSetPromptPreprocessorToClientPacketPreprocessDict"], +class PromptPreprocessingRequest( + LMStudioStruct["PromptPreprocessingRequestDict"], kw_only=True, tag_field="type", tag="preprocess", @@ -9534,15 +9534,15 @@ class PluginsChannelSetPromptPreprocessorToClientPacketPreprocess( type: ClassVar[Annotated[Literal["preprocess"], Meta(title="Type")]] = "preprocess" task_id: str = field(name="taskId") input: AnyChatMessage - config: LlmRpcGetLoadConfigReturns - plugin_config: LlmRpcGetLoadConfigReturns = field(name="pluginConfig") - global_plugin_config: LlmRpcGetLoadConfigReturns = field(name="globalPluginConfig") + config: SerializedKVConfigSettings + plugin_config: SerializedKVConfigSettings = field(name="pluginConfig") + global_plugin_config: SerializedKVConfigSettings = field(name="globalPluginConfig") working_directory_path: str | None = field(name="workingDirectoryPath") pci: str token: str -class PluginsChannelSetPromptPreprocessorToClientPacketPreprocessDict(TypedDict): +class PromptPreprocessingRequestDict(TypedDict): """Corresponding typed dictionary definition for PluginsChannelSetPromptPreprocessorToClientPacketPreprocess. 
NOTE: Multi-word keys are defined using their camelCase form, @@ -9552,16 +9552,16 @@ class PluginsChannelSetPromptPreprocessorToClientPacketPreprocessDict(TypedDict) type: Literal["preprocess"] taskId: str input: AnyChatMessageDict - config: LlmRpcGetLoadConfigReturnsDict - pluginConfig: LlmRpcGetLoadConfigReturnsDict - globalPluginConfig: LlmRpcGetLoadConfigReturnsDict + config: SerializedKVConfigSettingsDict + pluginConfig: SerializedKVConfigSettingsDict + globalPluginConfig: SerializedKVConfigSettingsDict workingDirectoryPath: NotRequired[str | None] pci: str token: str -class PluginsChannelSetPromptPreprocessorToServerPacketComplete( - LMStudioStruct["PluginsChannelSetPromptPreprocessorToServerPacketCompleteDict"], +class PromptPreprocessingComplete( + LMStudioStruct["PromptPreprocessingCompleteDict"], kw_only=True, tag_field="type", tag="complete", @@ -9571,7 +9571,7 @@ class PluginsChannelSetPromptPreprocessorToServerPacketComplete( processed: AnyChatMessage -class PluginsChannelSetPromptPreprocessorToServerPacketCompleteDict(TypedDict): +class PromptPreprocessingCompleteDict(TypedDict): """Corresponding typed dictionary definition for PluginsChannelSetPromptPreprocessorToServerPacketComplete. 
NOTE: Multi-word keys are defined using their camelCase form, @@ -10192,22 +10192,19 @@ class PseudoPluginsChannelRegisterDevelopmentPluginDict(TypedDict): PluginsChannelSetPromptPreprocessorToClientPacket = ( - PluginsChannelSetPromptPreprocessorToClientPacketPreprocess - | PluginsChannelSetPromptPreprocessorToClientPacketAbort + PromptPreprocessingRequest | PluginsChannelSetPromptPreprocessorToClientPacketAbort ) PluginsChannelSetPromptPreprocessorToClientPacketDict = ( - PluginsChannelSetPromptPreprocessorToClientPacketPreprocessDict + PromptPreprocessingRequestDict | PluginsChannelSetPromptPreprocessorToClientPacketAbortDict ) PluginsChannelSetPromptPreprocessorToServerPacket = ( - PluginsChannelSetPromptPreprocessorToServerPacketComplete - | PluginsChannelSetPromptPreprocessorToServerPacketAborted - | PluginsChannelSetPromptPreprocessorToServerPacketError + PromptPreprocessingComplete | PromptPreprocessingAborted | PromptPreprocessingError ) PluginsChannelSetPromptPreprocessorToServerPacketDict = ( - PluginsChannelSetPromptPreprocessorToServerPacketErrorDict - | PluginsChannelSetPromptPreprocessorToServerPacketCompleteDict - | PluginsChannelSetPromptPreprocessorToServerPacketAbortedDict + PromptPreprocessingErrorDict + | PromptPreprocessingCompleteDict + | PromptPreprocessingAbortedDict ) @@ -10242,8 +10239,8 @@ class PluginsChannelSetGeneratorToClientPacketGenerate( type: ClassVar[Annotated[Literal["generate"], Meta(title="Type")]] = "generate" task_id: str = field(name="taskId") input: PluginsRpcProcessingPullHistoryReturns - plugin_config: LlmRpcGetLoadConfigReturns = field(name="pluginConfig") - global_plugin_config: LlmRpcGetLoadConfigReturns = field(name="globalPluginConfig") + plugin_config: SerializedKVConfigSettings = field(name="pluginConfig") + global_plugin_config: SerializedKVConfigSettings = field(name="globalPluginConfig") tool_definitions: Sequence[LlmTool] = field(name="toolDefinitions") working_directory_path: str | None = 
field(name="workingDirectoryPath") @@ -10258,8 +10255,8 @@ class PluginsChannelSetGeneratorToClientPacketGenerateDict(TypedDict): type: Literal["generate"] taskId: str input: PluginsRpcProcessingPullHistoryReturnsDict - pluginConfig: LlmRpcGetLoadConfigReturnsDict - globalPluginConfig: LlmRpcGetLoadConfigReturnsDict + pluginConfig: SerializedKVConfigSettingsDict + globalPluginConfig: SerializedKVConfigSettingsDict toolDefinitions: Sequence[LlmToolFunctionDict] workingDirectoryPath: NotRequired[str | None] diff --git a/src/lmstudio/_ws_impl.py b/src/lmstudio/_ws_impl.py index 8d33ec1..a239251 100644 --- a/src/lmstudio/_ws_impl.py +++ b/src/lmstudio/_ws_impl.py @@ -290,7 +290,7 @@ async def _log_thread_execution(self) -> None: try: # Run the event loop until termination is requested await never_set.wait() - except asyncio.CancelledError: + except (asyncio.CancelledError, GeneratorExit): raise except BaseException: err_msg = "Terminating websocket thread due to exception" @@ -359,7 +359,7 @@ async def _logged_ws_handler(self) -> None: self._logger.info("Websocket handling task started") try: await self._handle_ws() - except asyncio.CancelledError: + except (asyncio.CancelledError, GeneratorExit): raise except BaseException: err_msg = "Terminating websocket task due to exception" diff --git a/src/lmstudio/async_api.py b/src/lmstudio/async_api.py index ad1ce55..775e9ee 100644 --- a/src/lmstudio/async_api.py +++ b/src/lmstudio/async_api.py @@ -148,6 +148,11 @@ def get_creation_message(self) -> DictObject: """Get the message to send to create this channel.""" return self._api_channel.get_creation_message() + async def send_message(self, message: DictObject) -> None: + """Send given message on this channel.""" + wrapped_message = self._api_channel.wrap_message(message) + await self._send_json(wrapped_message) + async def cancel(self) -> None: """Cancel the channel.""" if self._is_finished: @@ -362,7 +367,7 @@ async def _notify_client_termination(self) -> int: for rx_queue 
in self._mux.all_queues(): await rx_queue.put(None) num_clients += 1 - self._logger.info( + self._logger.debug( f"Notified {num_clients} clients of websocket termination", num_clients=num_clients, ) @@ -1406,7 +1411,7 @@ def __init__(self, api_host: str | None = None) -> None: # However, lazy connections also don't work due to structured concurrency. # For now, all sessions are opened eagerly by the client # TODO: provide a way to selectively exclude unnecessary client sessions - _ALL_SESSIONS = ( + _ALL_SESSIONS: tuple[Type[AsyncSession], ...] = ( AsyncSessionEmbedding, _AsyncSessionFiles, AsyncSessionLlm, diff --git a/src/lmstudio/history.py b/src/lmstudio/history.py index 548371d..eb1cdaa 100644 --- a/src/lmstudio/history.py +++ b/src/lmstudio/history.py @@ -59,6 +59,7 @@ ToolCallResultDataDict, ToolResultMessage, UserMessage, + UserMessageDict, ) __all__ = [ @@ -84,6 +85,7 @@ "ToolResultMessage", "UserMessage", "UserMessageContent", + "UserMessageDict", ] # A note on terminology: @@ -543,6 +545,7 @@ def _get_file_details(src: LocalFileInput) -> Tuple[str, bytes]: try: data = src.read() except OSError as exc: + # Note: OSError details remain available via raised_exc.__context__ err_msg = f"Error while reading {src!r} ({exc!r})" raise LMStudioOSError(err_msg) from None name = getattr(src, "name", str(uuid.uuid4())) @@ -555,6 +558,7 @@ def _get_file_details(src: LocalFileInput) -> Tuple[str, bytes]: try: data = src_path.read_bytes() except OSError as exc: + # Note: OSError details remain available via raised_exc.__context__ err_msg = f"Error while reading {str(src_path)!r} ({exc!r})" raise LMStudioOSError(err_msg) from None name = str(src_path.name) diff --git a/src/lmstudio/json_api.py b/src/lmstudio/json_api.py index b95c2fd..487a83a 100644 --- a/src/lmstudio/json_api.py +++ b/src/lmstudio/json_api.py @@ -1613,19 +1613,26 @@ def endpoint(self) -> TEndpoint: def get_creation_message(self) -> DictObject: """Get the message to send to create this channel.""" 
endpoint = self._endpoint - return { + message = { "type": "channelCreate", "endpoint": endpoint.api_endpoint, "channelId": self._channel_id, - "creationParameter": endpoint.creation_params, } + creation_params = endpoint.creation_params + if creation_params: + message["creationParameter"] = creation_params + return message def get_cancel_message(self) -> DictObject: """Get the message to send to cancel this channel.""" + return self.wrap_message({"type": "cancel"}) + + def wrap_message(self, message: DictObject) -> DictObject: + """Wrap a message for sending on this channel.""" return { "type": "channelSend", "channelId": self._channel_id, - "message": {"type": "cancel"}, + "message": message, } # This runs in the context of the background demultiplexing task @@ -1838,24 +1845,30 @@ def __init__(self, api_host: str | None = None) -> None: self._auth_details = self._create_auth_message() @staticmethod - def _create_auth_message() -> DictObject: + def _format_auth_message( + client_id: str | None = None, client_key: str | None = None + ) -> DictObject: """Create an LM Studio websocket authentication message.""" # Note: authentication (in its current form) is primarily a cooperative # resource management mechanism that allows the server to appropriately # manage client-scoped resources (such as temporary file handles). # As such, the client ID and client passkey are currently more a two part # client identifier than they are an adversarial security measure. This is - # sufficient to prevent accidential conflicts and, in combination with secure + # sufficient to prevent accidental conflicts and, in combination with secure # websocket support, would be sufficient to ensure that access to the running # client was required to extract the auth details. 
- client_identifier = str(uuid.uuid4()) - client_passkey = str(uuid.uuid4()) + client_identifier = client_id if client_id is not None else str(uuid.uuid4()) + client_passkey = client_key if client_key is not None else str(uuid.uuid4()) return { "authVersion": 1, "clientIdentifier": client_identifier, "clientPasskey": client_passkey, } + def _create_auth_message(self) -> DictObject: + """Create an LM Studio websocket authentication message.""" + return self._format_auth_message() + TClient = TypeVar("TClient", bound=ClientBase) diff --git a/src/lmstudio/plugin/__init__.py b/src/lmstudio/plugin/__init__.py new file mode 100644 index 0000000..0c75e92 --- /dev/null +++ b/src/lmstudio/plugin/__init__.py @@ -0,0 +1,65 @@ +"""Support for implementing LM Studio plugins in Python.""" + +# Using wildcard imports to export API symbols is acceptable +# ruff: noqa: F403 + +from .sdk_api import * +from .config_schemas import * +from .hooks import * +from .runner import * + +# Initial Python plugin SDK TODO list +# +# General tasks +# * [DONE] refactor hook channel and controller definitions out to a submodule +# * [DONE] refactor hook registration to be data driven instead of hardcoded in the runner +# * refactor to allow "Abort" request handling to be common across hook invocation tasks +# * refactor to allow hook invocation error handling to be common across hook invocation tasks +# * [DONE] gracefully handle app termination while a dev plugin is still running +# * [DONE] gracefully handle using Ctrl-C to terminate a running dev plugin +# +# Controller APIs (may be limited to relevant hook controllers) +# +# * [DONE] status blocks (both simple "done" blocks, and blocks with in-place updates) +# * citation blocks +# * debug info blocks +# * tool status reporting +# * full content block display +# * chat history retrieval +# * model handle retrieval +# * UI block sender name configuration +# * interactive tool call request confirmation +# +# Prompt preprocessing hook +# * [DONE] 
emit a status notification block when the demo plugin fires +# * [DONE] add a global plugin config to control the in-place status update demo +# * [DONE] handle "Abort" requests from server (including sending "Aborted" responses) +# * [DONE] catch hook invocation failures and send "Error" responses +# * [DONE] this includes adding runtime checks for the hook returning the wrong type +# +# Token generator hook +# * add an example plugin for this (probably proxying a remote LM Studio instance) +# * define the channel, hook invocation task and hook invocation controller for this hook +# * main request initiation message is "Generate" +# * handle "Abort" requests from server (including sending "Aborted" responses) +# * add controller API for fragment generation +# * add controller API for tool call generation +# * add controller API to indicate when token generation for a given request is completed (or failed) +# * catch hook invocation failures and send "Error" responses +# +# Tools provider hook +# * add example plugin or plugins for this (probably both dice rolling and Wikipedia lookup) +# * define the channel, hook invocation task and hook invocation controller for this hook +# * main request initiation message is "InitSession" (with Initialized/Failed responses) +# * handle "Abort" requests from server (including sending "Aborted" responses) +# * handle "CallTool" requests from server (including sending "CallComplete"/"CallError" response) +# * handle "DiscardSession" requests from server +# * add controller API for tool call status and warning reporting +# +# Plugin config field definitions +# * define approach for specifying plugin config field constraints and style options (e.g. 
numeric sliders) +# * [usable] numeric: https://github.com/lmstudio-ai/lmstudio-js/blob/main/packages/lms-kv-config/src/valueTypes.ts#L99 +# * [usable] string +# * [usable] boolean +# * select (array of strings, or value/label string pairs) +# * string array diff --git a/src/lmstudio/plugin/__main__.py b/src/lmstudio/plugin/__main__.py new file mode 100644 index 0000000..67b5337 --- /dev/null +++ b/src/lmstudio/plugin/__main__.py @@ -0,0 +1,9 @@ +"""Allow execution of this subpackage as a script.""" + +import sys + +from .cli import main + +# Handle multiprocessing potentially re-running this module with a name other than `__main__` +if __name__ == "__main__": + sys.exit(main()) diff --git a/src/lmstudio/plugin/_dev_runner.py b/src/lmstudio/plugin/_dev_runner.py new file mode 100644 index 0000000..f6c279b --- /dev/null +++ b/src/lmstudio/plugin/_dev_runner.py @@ -0,0 +1,164 @@ +"""Plugin dev client implementation.""" + +import asyncio +import os +import subprocess +import sys + +from contextlib import asynccontextmanager +from pathlib import Path +from functools import partial +from typing import AsyncGenerator, Iterable, TypeAlias + +from typing_extensions import ( + # Native in 3.11+ + assert_never, +) + + +from .runner import ( + ENV_CLIENT_ID, + ENV_CLIENT_KEY, + PluginClient, +) +from ..schemas import DictObject +from ..json_api import ( + ChannelCommonRxEvent, + ChannelEndpoint, + ChannelFinishedEvent, + ChannelRxEvent, + LMStudioChannelClosedError, +) +from .._sdk_models import ( + # TODO: Define aliases at schema generation time + PluginsChannelRegisterDevelopmentPluginCreationParameter as DevPluginRegistrationRequest, + PluginsChannelRegisterDevelopmentPluginCreationParameterDict as DevPluginRegistrationRequestDict, + PluginsChannelRegisterDevelopmentPluginToServerPacketEndDict as DevPluginRegistrationEndDict, +) + + +class DevPluginRegistrationReadyEvent(ChannelRxEvent[None]): + pass + + +DevPluginRegistrationRxEvent: TypeAlias = ( + 
DevPluginRegistrationReadyEvent | ChannelCommonRxEvent +) + + +class DevPluginRegistrationEndpoint( + ChannelEndpoint[ + tuple[str, str], DevPluginRegistrationRxEvent, DevPluginRegistrationRequestDict + ] +): + """API channel endpoint to register a development plugin and receive credentials.""" + + _API_ENDPOINT = "registerDevelopmentPlugin" + _NOTICE_PREFIX = "Register development plugin" + + def __init__(self, owner: str, name: str) -> None: + # TODO: Set "python" as the type once LM Studio supports that + params = DevPluginRegistrationRequest._from_api_dict( + { + "manifest": { + "type": "plugin", + "runner": "node", + "owner": owner, + "name": name, + } + } + ) + super().__init__(params) + + def iter_message_events( + self, contents: DictObject | None + ) -> Iterable[DevPluginRegistrationRxEvent]: + match contents: + case None: + raise LMStudioChannelClosedError( + "Server failed to complete development plugin registration." + ) + case { + "type": "ready", + "clientIdentifier": str(client_id), + "clientPasskey": str(client_key), + }: + yield self._set_result((client_id, client_key)) + case unmatched: + self.report_unknown_message(unmatched) + + def handle_rx_event(self, event: DevPluginRegistrationRxEvent) -> None: + match event: + case DevPluginRegistrationReadyEvent(_): + pass + case ChannelFinishedEvent(_): + pass + case _: + assert_never(event) + + +class DevPluginClient(PluginClient): + def _get_registration_endpoint(self) -> DevPluginRegistrationEndpoint: + return DevPluginRegistrationEndpoint(self.owner, self.name) + + @asynccontextmanager + async def register_dev_plugin(self) -> AsyncGenerator[tuple[str, str], None]: + """Register a dev plugin on entry, deregister it on exit.""" + endpoint = self._get_registration_endpoint() + async with self.plugins._create_channel(endpoint) as channel: + registration_result = await channel.wait_for_result() + try: + yield registration_result + finally: + message: DevPluginRegistrationEndDict = {"type": "end"} + await 
channel.send_message(message) + + async def run_plugin( + self, *, allow_local_imports: bool = True, debug: bool = False + ) -> int: + if not allow_local_imports: + raise ValueError("Local imports are always permitted for dev plugins") + async with self.register_dev_plugin() as (client_id, client_key): + result = await asyncio.to_thread( + partial( + _run_plugin_in_child_process, + self._plugin_path, + client_id, + client_key, + debug, + ) + ) + return result.returncode + + +# TODO: support the same source code change monitoring features as `lms dev` +def _run_plugin_in_child_process( + plugin_path: Path, client_id: str, client_key: str, debug: bool = False +) -> subprocess.CompletedProcess[str]: + env = os.environ.copy() + env[ENV_CLIENT_ID] = client_id + env[ENV_CLIENT_KEY] = client_key + package_name = __spec__.parent + assert package_name is not None + debug_option = ("--debug",) if debug else () + command: list[str] = [ + sys.executable, + "-m", + package_name, + *debug_option, + os.fspath(plugin_path), + ] + return subprocess.run(command, text=True, env=env) + + +async def run_plugin_async( + plugin_dir: str | os.PathLike[str], *, debug: bool = False +) -> int: + """Asynchronously execute a plugin in development mode.""" + async with DevPluginClient(plugin_dir) as dev_client: + return await dev_client.run_plugin(debug=debug) + + +def run_plugin(plugin_dir: str | os.PathLike[str], *, debug: bool = False) -> int: + """Execute a plugin in development mode.""" + return asyncio.run(run_plugin_async(plugin_dir, debug=debug)) diff --git a/src/lmstudio/plugin/cli.py b/src/lmstudio/plugin/cli.py new file mode 100644 index 0000000..a528528 --- /dev/null +++ b/src/lmstudio/plugin/cli.py @@ -0,0 +1,63 @@ +"""Command line interface implementation.""" + +import argparse +import logging +import os.path +import sys +import warnings + +from typing import Sequence + +from ..sdk_api import sdk_public_api + +from . 
import _dev_runner, runner + + +def _parse_args( + argv: Sequence[str] | None = None, +) -> tuple[argparse.ArgumentParser, argparse.Namespace]: + py_name = os.path.basename(sys.executable).removesuffix(".exe") + parser = argparse.ArgumentParser( + prog=f"{py_name} -m {__spec__.parent}", + description="LM Studio plugin runner for Python plugins", + ) + parser.add_argument( + "plugin_path", metavar="PLUGIN_PATH", help="Directory name of plugin to run" + ) + parser.add_argument("--dev", action="store_true", help="Run in development mode") + parser.add_argument("--debug", action="store_true", help="Enable debug logging") + return parser, parser.parse_args(argv) + + +@sdk_public_api() +def main(argv: Sequence[str] | None = None) -> int: + """Run the ``lmstudio.plugin`` CLI. + + If *args* is not given, defaults to using ``sys.argv``. + """ + parser, args = _parse_args(argv) + plugin_path = args.plugin_path + if not os.path.exists(plugin_path): + parser.print_usage() + print(f"ERROR: Failed to find plugin folder at {plugin_path!r}") + return 1 + warnings.filterwarnings( + "ignore", ".*the plugin API is not yet stable", FutureWarning + ) + warnings.filterwarnings( + "ignore", ".*the async API is not yet stable", FutureWarning + ) + log_level = logging.DEBUG if args.debug else logging.INFO + logging.basicConfig(level=log_level) + if not args.dev: + try: + runner.run_plugin(plugin_path, allow_local_imports=True) + except KeyboardInterrupt: + print("Plugin execution terminated with Ctrl-C") + else: + # Retrieve args from API host, spawn plugin in subprocess + try: + _dev_runner.run_plugin(plugin_path, debug=args.debug) + except KeyboardInterrupt: + pass # Subprocess handles reporting the plugin termination + return 0 diff --git a/src/lmstudio/plugin/config_schemas.py b/src/lmstudio/plugin/config_schemas.py new file mode 100644 index 0000000..993278e --- /dev/null +++ b/src/lmstudio/plugin/config_schemas.py @@ -0,0 +1,235 @@ +"""Define plugin config schemas.""" + +from 
dataclasses import Field, dataclass, fields +from typing import ( + Any, + ClassVar, + Generic, + Sequence, + TypeVar, + cast, +) + +from typing_extensions import ( + # Native in 3.11+ + Self, + dataclass_transform, +) + +from .._kv_config import dict_from_kvconfig +from .._sdk_models import ( + SerializedKVConfigSchematics, + SerializedKVConfigSchematicsField, + SerializedKVConfigSettings, +) +from ..sdk_api import sdk_public_api +from .sdk_api import LMStudioPluginInitError, plugin_sdk_type + +# Available as lmstudio.plugin.* +__all__ = [ + "BaseConfigSchema", + "config_field", +] + +_T = TypeVar("_T") + +_CONFIG_FIELDS_KEY = "__kv_config_fields__" + + +@dataclass(frozen=True, slots=True) +class _ConfigField(Generic[_T]): + """Plugin config field specification.""" + + label: str + hint: str + + @property + def default(self) -> _T: + """The default value for this config field.""" + # Defaults must be static values, as the UI isn't directly running any plugin code + raise NotImplementedError + + # This never actually gets called (as __set_name__ switches to the default value at runtime) + # However, it's here so type checkers accept field definitions as matching their value type + def __get__( + self, _obj: "BaseConfigSchema | None", _obj_type: type["BaseConfigSchema"] + ) -> _T: + return self.default + + def __set_name__(self, obj_type: type["BaseConfigSchema"], name: str) -> None: + if not issubclass(obj_type, BaseConfigSchema): + msg = f"Plugin config fields must be defined on {BaseConfigSchema.__name__} instances" + raise LMStudioPluginInitError(msg) + # Append this field to the config fields for this schema, creating the list if necessary + config_fields: list[SerializedKVConfigSchematicsField] + config_fields = obj_type.__dict__.get(_CONFIG_FIELDS_KEY, None) + if config_fields is None: + # First config field defined for this schema, so create the config field list + config_fields = [] + try: + inherited_fields = getattr(obj_type, _CONFIG_FIELDS_KEY) + except 
AttributeError: + pass + else: + # Any inherited fields are included first + config_fields.extend(inherited_fields) + setattr(obj_type, _CONFIG_FIELDS_KEY, config_fields) + config_fields.append(self._to_kv_field(name)) + # Replace the UI config field spec with a regular dataclass default value + setattr(obj_type, name, self.default) + + def _to_kv_field(self, name: str) -> SerializedKVConfigSchematicsField: + raise NotImplementedError + + +@dataclass(frozen=True, slots=True) +class _ConfigBool(_ConfigField[bool]): + """Boolean config field.""" + + default: bool + + def _to_kv_field(self, name: str) -> SerializedKVConfigSchematicsField: + return SerializedKVConfigSchematicsField( + short_key=name, + full_key=name, + type_key="boolean", + type_params={ + "displayName": self.label, + "hint": self.hint, + }, + default_value=self.default, + ) + + +@dataclass(frozen=True, slots=True) +class _ConfigInt(_ConfigField[int]): + """Integer config field.""" + + default: int + + def _to_kv_field(self, name: str) -> SerializedKVConfigSchematicsField: + return SerializedKVConfigSchematicsField( + short_key=name, + full_key=name, + type_key="numeric", + type_params={ + "displayName": self.label, + "hint": self.hint, + "int": True, + }, + default_value=self.default, + ) + + +@dataclass(frozen=True, slots=True) +class _ConfigFloat(_ConfigField[float]): + """Floating point config field.""" + + default: float + + def _to_kv_field(self, name: str) -> SerializedKVConfigSchematicsField: + return SerializedKVConfigSchematicsField( + short_key=name, + full_key=name, + type_key="numeric", + type_params={ + "displayName": self.label, + "hint": self.hint, + "int": False, + }, + default_value=self.default, + ) + + +@dataclass(frozen=True, slots=True) +class _ConfigString(_ConfigField[str]): + """String config field.""" + + default: str + + def _to_kv_field(self, name: str) -> SerializedKVConfigSchematicsField: + return SerializedKVConfigSchematicsField( + short_key=name, + full_key=name, + 
type_key="string", + type_params={ + "displayName": self.label, + "hint": self.hint, + }, + default_value=self.default, + ) + + +@sdk_public_api() +def config_field(*, label: str, hint: str, default: _T) -> _T: + """Define a plugin config field to be displayed and updated via the app UI.""" + # This type hint intentionally doesn't match the actual returned type + # (the relevant ConfigField[_T] subclass). This is to ensure that + # type checkers will accept config field initialisations like + # "attr: int = config_field(...)". + descriptor: _ConfigField[Any] + match default: + case bool(): + descriptor = _ConfigBool(label, hint, default) + case float(): + descriptor = _ConfigFloat(label, hint, default) + case int(): + descriptor = _ConfigInt(label, hint, default) + case str(): + descriptor = _ConfigString(label, hint, default) + case _: + msg = f"Unsupported type for plugin config field: {type(default)!r}" + raise LMStudioPluginInitError(msg) + return cast(_T, descriptor) + + +# TODO: Cover additional config field types +# TODO: Allow optional constraints and UI display features +# (the similarity of the _to_kv_field methods will reduce when this is done) + + +class _ImplicitDataClass(type): + def __new__( + meta_cls, name: str, bases: tuple[type, ...], namespace: dict[str, Any] + ) -> type: + cls: type = super().__new__(meta_cls, name, bases, namespace) + return dataclass()(cls) + + +@dataclass_transform(field_specifiers=(config_field,)) +@plugin_sdk_type +class BaseConfigSchema(metaclass=_ImplicitDataClass): + """Base class for plugin configuration schema definitions.""" + + # This uses the custom metaclass to automatically make subclasses data classes + # Declare that behaviour in a way that mypy will fully accept + __dataclass_fields__: ClassVar[dict[str, Field[Any]]] + + # ConfigField.__set_name__ lazily creates this class variable + __kv_config_fields__: ClassVar[Sequence[SerializedKVConfigSchematicsField]] + + @classmethod + def _to_kv_config_schematics(cls) 
-> SerializedKVConfigSchematics | None: + """Convert to wire format for transmission to the app server.""" + try: + config_fields = cls.__kv_config_fields__ + except AttributeError: + # No config fields have been defined on this config schema + # This is fine (as it allows for placeholders in code skeletons) + return None + return SerializedKVConfigSchematics(fields=config_fields) + + @classmethod + def _default_config(cls) -> dict[str, Any]: + default_config: dict[str, Any] = {} + config_spec = cls() + for field in fields(config_spec): + attr = field.name + default_config[attr] = getattr(config_spec, attr) + return default_config + + @classmethod + def _parse(cls, dynamic_config: SerializedKVConfigSettings) -> Self: + config = cls._default_config() + config.update(dict_from_kvconfig(dynamic_config)) + return cls(**config) diff --git a/src/lmstudio/plugin/hooks/__init__.py b/src/lmstudio/plugin/hooks/__init__.py new file mode 100644 index 0000000..e2cae58 --- /dev/null +++ b/src/lmstudio/plugin/hooks/__init__.py @@ -0,0 +1,15 @@ +"""Invoking and supporting plugin hook implementations.""" +# Using wildcard imports to export API symbols is acceptable +# ruff: noqa: F403, F405 + +from .common import * +from .prompt_preprocessor import * +from .token_generator import * +from .tools_provider import * + +# Available as lmstudio.plugin.* +__all__ = [ + "PromptPreprocessorController", + "TokenGeneratorController", + "ToolsProviderController", +] diff --git a/src/lmstudio/plugin/hooks/common.py b/src/lmstudio/plugin/hooks/common.py new file mode 100644 index 0000000..8ed4d35 --- /dev/null +++ b/src/lmstudio/plugin/hooks/common.py @@ -0,0 +1,140 @@ +"""Common utilities to invoke and support plugin hook implementations.""" + +import asyncio + +from contextlib import asynccontextmanager +from datetime import datetime, timezone +from pathlib import Path +from random import randrange +from typing import ( + Any, + AsyncIterator, + Awaitable, + Callable, + Generic, + TypeAlias, + 
TypeVar, +) + +from anyio import move_on_after + +from ...async_api import AsyncSession +from ...schemas import DictObject +from ..._sdk_models import ( + # TODO: Define aliases at schema generation time + PluginsChannelSetGeneratorToClientPacketGenerate as TokenGenerationRequest, + PluginsChannelSetToolsProviderToClientPacketInitSession as ProvideToolsInitSession, + PromptPreprocessingRequest, + SerializedKVConfigSettings, + StatusStepStatus, +) +from ..config_schemas import BaseConfigSchema + +# Available as lmstudio.plugin.hooks.* +__all__ = [ + "AsyncSessionPlugins", + "TPluginConfigSchema", + "TGlobalConfigSchema", +] + + +class AsyncSessionPlugins(AsyncSession): + """Async client session for the plugins namespace.""" + + API_NAMESPACE = "plugins" + + +PluginRequest: TypeAlias = ( + PromptPreprocessingRequest | TokenGenerationRequest | ProvideToolsInitSession +) +TPluginRequest = TypeVar("TPluginRequest", bound=PluginRequest) +TPluginConfigSchema = TypeVar("TPluginConfigSchema", bound=BaseConfigSchema) +TGlobalConfigSchema = TypeVar("TGlobalConfigSchema", bound=BaseConfigSchema) +TConfig = TypeVar("TConfig", bound=BaseConfigSchema) +SendMessageCallback: TypeAlias = Callable[[DictObject], Awaitable[Any]] + + +class ServerRequestError(RuntimeError): + """Plugin received an invalid request from the API server.""" + + +class HookController(Generic[TPluginRequest, TPluginConfigSchema, TGlobalConfigSchema]): + """Common base class for plugin hook API access controllers.""" + + def __init__( + self, + session: AsyncSessionPlugins, + request: TPluginRequest, + plugin_config_schema: type[TPluginConfigSchema], + global_config_schema: type[TGlobalConfigSchema], + ) -> None: + """Initialize common hook controller settings.""" + self.session = session + self.request = request + self.plugin_config = self._parse_config( + request.plugin_config, plugin_config_schema + ) + self.global_config = self._parse_config( + request.global_plugin_config, global_config_schema + ) + 
work_dir = request.working_directory_path + self.working_path = Path(work_dir) if work_dir else None + + @classmethod + def _parse_config( + cls, config: SerializedKVConfigSettings, schema: type[TConfig] + ) -> TConfig: + if schema is None: + schema = BaseConfigSchema + return schema._parse(config) + + @classmethod + def _create_ui_block_id(self) -> str: + return f"{datetime.now(timezone.utc)}-{randrange(0, 2**32):08x}" + + +StatusUpdateCallback: TypeAlias = Callable[[str, StatusStepStatus, str], Any] + + +class StatusBlockController: + """API access to update a UI status block in-place.""" + + def __init__( + self, + block_id: str, + update_ui: StatusUpdateCallback, + ) -> None: + """Initialize status block controller.""" + self._id = block_id + self._update_ui = update_ui + + async def notify_waiting(self, message: str) -> None: + """Report task is waiting (static icon) in the status block.""" + await self._update_ui(self._id, "waiting", message) + + async def notify_working(self, message: str) -> None: + """Report task is working (dynamic icon) in the status block.""" + await self._update_ui(self._id, "loading", message) + + async def notify_error(self, message: str) -> None: + """Report task error in the status block.""" + await self._update_ui(self._id, "error", message) + + async def notify_canceled(self, message: str) -> None: + """Report task cancellation in the status block.""" + await self._update_ui(self._id, "canceled", message) + + async def notify_done(self, message: str) -> None: + """Report task completion in the status block.""" + await self._update_ui(self._id, "done", message) + + @asynccontextmanager + async def notify_aborted(self, message: str) -> AsyncIterator[None]: + """Report asyncio.CancelledError as cancellation in the status block.""" + try: + yield + except asyncio.CancelledError: + # Allow the notification to be sent, but don't necessarily wait for the reply + with move_on_after(0.2, shield=True): + await self.notify_canceled(message) 
+ raise diff --git a/src/lmstudio/plugin/hooks/prompt_preprocessor.py b/src/lmstudio/plugin/hooks/prompt_preprocessor.py new file mode 100644 index 0000000..1432194 --- /dev/null +++ b/src/lmstudio/plugin/hooks/prompt_preprocessor.py @@ -0,0 +1,384 @@ +"""Invoking and supporting prompt preprocessor hook implementations.""" + +import asyncio + +from contextlib import asynccontextmanager +from dataclasses import dataclass +from traceback import format_tb +from typing import ( + Any, + AsyncIterator, + Awaitable, + Callable, + Generic, + Iterable, + TypeAlias, +) +from typing_extensions import ( + # Native in 3.11+ + assert_never, +) + +from anyio import create_task_group +from anyio.abc import TaskGroup + +from ..._logging import new_logger +from ...schemas import DictObject, EmptyDict, ValidationError +from ...history import UserMessage, UserMessageDict +from ...json_api import ( + ChannelCommonRxEvent, + ChannelEndpoint, + ChannelFinishedEvent, + ChannelRxEvent, + load_struct, +) +from ..._sdk_models import ( + PluginsRpcProcessingHandleUpdateParameter, + ProcessingUpdate, + ProcessingUpdateStatusCreate, + ProcessingUpdateStatusUpdate, + PromptPreprocessingAbortedDict, + PromptPreprocessingCompleteDict, + PromptPreprocessingErrorDict, + PromptPreprocessingRequest, + SerializedLMSExtendedErrorDict, + StatusStepState, + StatusStepStatus, +) +from ..config_schemas import BaseConfigSchema +from .common import ( + AsyncSessionPlugins, + HookController, + SendMessageCallback, + ServerRequestError, + StatusBlockController, + TPluginConfigSchema, + TGlobalConfigSchema, +) + +# Available as lmstudio.plugin.hooks.* +__all__ = [ + "PromptPreprocessorController", + "PromptPreprocessorHook", + "run_prompt_preprocessor", +] + + +class PromptPreprocessingAbortEvent(ChannelRxEvent[str]): + pass + + +class PromptPreprocessingRequestEvent(ChannelRxEvent[PromptPreprocessingRequest]): + pass + + +PromptPreprocessingRxEvent: TypeAlias = ( + PromptPreprocessingAbortEvent + | 
PromptPreprocessingRequestEvent + | ChannelCommonRxEvent +) + + +class PromptPreprocessingEndpoint( + ChannelEndpoint[tuple[str, str], PromptPreprocessingRxEvent, EmptyDict] +): + """API channel endpoint to accept prompt preprocessing requests.""" + + _API_ENDPOINT = "setPromptPreprocessor" + _NOTICE_PREFIX = "Prompt preprocessing" + + def __init__(self) -> None: + super().__init__({}) + + def iter_message_events( + self, contents: DictObject | None + ) -> Iterable[PromptPreprocessingRxEvent]: + match contents: + case None: + # Server can only terminate the link by closing the websocket + pass + case {"type": "abort", "taskId": str(task_id)}: + yield PromptPreprocessingAbortEvent(task_id) + case {"type": "preprocess"} as request_dict: + parsed_request = PromptPreprocessingRequest._from_any_api_dict( + request_dict + ) + yield PromptPreprocessingRequestEvent(parsed_request) + case unmatched: + self.report_unknown_message(unmatched) + + def handle_rx_event(self, event: PromptPreprocessingRxEvent) -> None: + match event: + case PromptPreprocessingAbortEvent(task_id): + self._logger.debug(f"Aborting {task_id}", task_id=task_id) + case PromptPreprocessingRequestEvent(request): + task_id = request.task_id + self._logger.debug( + "Received prompt preprocessing request", task_id=task_id + ) + case ChannelFinishedEvent(_): + pass + case _: + assert_never(event) + + +class PromptPreprocessorController( + HookController[PromptPreprocessingRequest, TPluginConfigSchema, TGlobalConfigSchema] +): + """API access for prompt preprocessor hook implementations.""" + + def __init__( + self, + session: AsyncSessionPlugins, + request: PromptPreprocessingRequest, + plugin_config_schema: type[TPluginConfigSchema], + global_config_schema: type[TGlobalConfigSchema], + ) -> None: + """Initialize prompt preprocessor hook controller.""" + super().__init__(session, request, plugin_config_schema, global_config_schema) + self.task_id = request.task_id + self.pci = request.pci + self.token = 
request.token + + async def _send_handle_update(self, update: ProcessingUpdate) -> Any: + handle_update = PluginsRpcProcessingHandleUpdateParameter( + pci=self.pci, + token=self.token, + update=update, + ) + return await self.session.remote_call("processingHandleUpdate", handle_update) + + async def _create_status_block( + self, block_id: str, status: StatusStepStatus, message: str + ) -> None: + await self._send_handle_update( + ProcessingUpdateStatusCreate( + id=block_id, + state=StatusStepState( + status=status, + text=message, + ), + ), + ) + + async def _send_status_update( + self, block_id: str, status: StatusStepStatus, message: str + ) -> None: + await self._send_handle_update( + ProcessingUpdateStatusUpdate( + id=block_id, + state=StatusStepState( + status=status, + text=message, + ), + ), + ) + + async def notify_start(self, message: str) -> StatusBlockController: + """Report task initiation in a new UI status block, return controller for updates.""" + status_block = StatusBlockController( + self._create_ui_block_id(), + self._send_status_update, + ) + await self._create_status_block(status_block._id, "waiting", message) + return status_block + + async def notify_done(self, message: str) -> None: + """Report task completion in a new UI status block.""" + await self._create_status_block(self._create_ui_block_id(), "done", message) + + +PromptPreprocessorHook = Callable[ + [PromptPreprocessorController[Any, Any], UserMessage], + Awaitable[UserMessage | UserMessageDict | None], +] + + +# TODO: Define a common "PluginHookHandler" base class +@dataclass() +class PromptPreprocessor(Generic[TPluginConfigSchema, TGlobalConfigSchema]): + """Handle accepting prompt preprocessing requests.""" + + plugin_name: str + hook_impl: PromptPreprocessorHook + plugin_config_schema: type[TPluginConfigSchema] + global_config_schema: type[TGlobalConfigSchema] + + def __post_init__(self) -> None: + self._logger = logger = new_logger(__name__) + 
logger.update_context(plugin_name=self.plugin_name) + self._abort_events: dict[str, asyncio.Event] = {} + + async def process_requests( + self, session: AsyncSessionPlugins, notify_ready: Callable[[], Any] + ) -> None: + """Create plugin channel and wait for server requests.""" + logger = self._logger + endpoint = PromptPreprocessingEndpoint() + async with session._create_channel(endpoint) as channel: + notify_ready() + logger.info("Opened channel to receive prompt preprocessing requests...") + send_message = channel.send_message + async with create_task_group() as tg: + logger.debug("Waiting for prompt preprocessing requests...") + async for contents in channel.rx_stream(): + logger.debug( + f"Handling prompt preprocessing channel message: {contents}" + ) + for event in endpoint.iter_message_events(contents): + logger.debug("Handling prompt preprocessing channel event") + endpoint.handle_rx_event(event) + match event: + case PromptPreprocessingAbortEvent(): + await self._abort_hook_invocation( + event.arg, send_message + ) + case PromptPreprocessingRequestEvent(): + logger.debug( + "Running prompt preprocessing request hook" + ) + ctl = PromptPreprocessorController( + session, + event.arg, + self.plugin_config_schema, + self.global_config_schema, + ) + tg.start_soon(self._invoke_hook, ctl, send_message) + if endpoint.is_finished: + break + + async def _abort_hook_invocation( + self, task_id: str, send_response: SendMessageCallback + ) -> None: + """Abort the specified hook invocation (if it is still running).""" + abort_event = self._abort_events.get(task_id, None) + if abort_event is not None: + abort_event.set() + response = PromptPreprocessingAbortedDict( + type="aborted", + taskId=task_id, + ) + await send_response(response) + + async def _cancel_on_event( + self, tg: TaskGroup, event: asyncio.Event, message: str + ) -> None: + await event.wait() + self._logger.info(message) + tg.cancel_scope.cancel() + + @asynccontextmanager + async def 
_registered_hook_invocation( + self, task_id: str + ) -> AsyncIterator[asyncio.Event]: + logger = self._logger + abort_events = self._abort_events + if task_id in abort_events: + err_msg = f"Hook invocation already in progress for {task_id}" + raise ServerRequestError(err_msg) + abort_events[task_id] = abort_event = asyncio.Event() + try: + async with create_task_group() as tg: + tg.start_soon( + self._cancel_on_event, + tg, + abort_event, + f"Aborting request {task_id}", + ) + logger.info(f"Processing request {task_id}") + yield abort_event + tg.cancel_scope.cancel() + finally: + abort_events.pop(task_id, None) + if abort_event.is_set(): + completion_message = f"Aborted request {task_id}" + else: + completion_message = f"Processed request {task_id}" + logger.info(completion_message) + + async def _invoke_hook( + self, + ctl: PromptPreprocessorController[TPluginConfigSchema, TGlobalConfigSchema], + send_response: SendMessageCallback, + ) -> None: + logger = self._logger + task_id = ctl.task_id + message = ctl.request.input + error_details: SerializedLMSExtendedErrorDict | None = None + response_dict: UserMessageDict + expected_cls = UserMessage + try: + if not isinstance(message, expected_cls): + err_msg = f"Received {type(message).__name__!r} ({expected_cls.__name__!r} expected)" + raise ServerRequestError(err_msg) + async with self._registered_hook_invocation(task_id) as abort_event: + response = await self.hook_impl(ctl, message) + except Exception as exc: + err_msg = "Error calling prompt preprocessing hook" + logger.error(err_msg, exc_info=True, exc=repr(exc)) + # TODO: Determine if it's worth sending the stack trace to the server + ui_cause = f"{err_msg}\n({type(exc).__name__}: {exc})" + error_details = SerializedLMSExtendedErrorDict( + cause=ui_cause, stack="\n".join(format_tb(exc.__traceback__)) + ) + else: + if abort_event.is_set(): + # Processing was aborted by the server, skip sending a response + return + if response is None: + logger.debug("No changes 
made to preprocessed prompt") + response_dict = message.to_dict() + else: + logger.debug( + "Validating prompt preprocessing response", response=response + ) + if isinstance(response, dict): + try: + parsed_response = load_struct(response, expected_cls) + except ValidationError as exc: + err_msg = f"Failed to parse prompt preprocessing response as {expected_cls.__name__}\n({exc})" + logger.error(err_msg) + error_details = SerializedLMSExtendedErrorDict(cause=err_msg) + else: + response_dict = parsed_response.to_dict() + elif isinstance(response, UserMessage): + response_dict = response.to_dict() + else: + err_msg = f"Prompt preprocessing hook returned {type(response).__name__!r} ({expected_cls.__name__!r} expected)" + logger.error(err_msg) + error_details = SerializedLMSExtendedErrorDict(cause=err_msg) + channel_message: DictObject + if error_details is not None: + error_title = f"Prompt preprocessing error in plugin {self.plugin_name!r}" + common_error_args: SerializedLMSExtendedErrorDict = { + "title": error_title, + "rootTitle": error_title, + } + error_details.update(common_error_args) + channel_message = PromptPreprocessingErrorDict( + type="error", + taskId=task_id, + error=error_details, + ) + else: + channel_message = PromptPreprocessingCompleteDict( + type="complete", + taskId=task_id, + processed=response_dict, + ) + await send_response(channel_message) + + +async def run_prompt_preprocessor( + plugin_name: str, + hook_impl: PromptPreprocessorHook, + plugin_config_schema: type[BaseConfigSchema], + global_config_schema: type[BaseConfigSchema], + session: AsyncSessionPlugins, + notify_ready: Callable[[], Any], +) -> None: + """Accept prompt preprocessing requests.""" + prompt_preprocessor = PromptPreprocessor( + plugin_name, hook_impl, plugin_config_schema, global_config_schema + ) + await prompt_preprocessor.process_requests(session, notify_ready) diff --git a/src/lmstudio/plugin/hooks/token_generator.py b/src/lmstudio/plugin/hooks/token_generator.py new 
file mode 100644 index 0000000..51bb5ce --- /dev/null +++ b/src/lmstudio/plugin/hooks/token_generator.py @@ -0,0 +1,43 @@
"""Invoking and supporting token generator hook implementations."""

from typing import Any, Awaitable, Callable
from ..._sdk_models import (
    # TODO: Define aliases at schema generation time
    PluginsChannelSetGeneratorToClientPacketGenerate as TokenGenerationRequest,
)

from ..config_schemas import BaseConfigSchema
from .common import (
    AsyncSessionPlugins,
    HookController,
    TPluginConfigSchema,
    TGlobalConfigSchema,
)

# Available as lmstudio.plugin.hooks.*
__all__ = [
    "TokenGeneratorController",
    "TokenGeneratorHook",
    "run_token_generator",
]


class TokenGeneratorController(
    HookController[TokenGenerationRequest, TPluginConfigSchema, TGlobalConfigSchema]
):
    """API access for token generator hook implementations."""


# Hook signature: an async callable that receives only the controller
# (token output presumably flows back through the controller — TODO confirm
# once the channel handling below is implemented)
TokenGeneratorHook = Callable[[TokenGeneratorController[Any, Any]], Awaitable[None]]


async def run_token_generator(
    plugin_name: str,
    hook_impl: TokenGeneratorHook,
    plugin_config_schema: type[BaseConfigSchema],
    global_config_schema: type[BaseConfigSchema],
    session: AsyncSessionPlugins,
    notify_ready: Callable[[], Any],
) -> None:
    """Accept token generation requests.

    NOTE(review): placeholder implementation — raises NotImplementedError
    without invoking *notify_ready*. The signature deliberately mirrors
    ``run_prompt_preprocessor`` so hook dispatch stays uniform.
    """
    raise NotImplementedError diff --git a/src/lmstudio/plugin/hooks/tools_provider.py b/src/lmstudio/plugin/hooks/tools_provider.py new file mode 100644 index 0000000..e43b096 --- /dev/null +++ b/src/lmstudio/plugin/hooks/tools_provider.py @@ -0,0 +1,50 @@
"""Invoking and supporting tools provider hook implementations."""

from typing import Any, Awaitable, Callable, Iterable

from ...json_api import (
    ToolDefinition,
)

from ..._sdk_models import (
    # TODO: Define aliases at schema generation time
    PluginsChannelSetToolsProviderToClientPacketInitSession as ProvideToolsInitSession,
)

from ..config_schemas import BaseConfigSchema
from .common import (
    AsyncSessionPlugins,
    HookController,
TPluginConfigSchema, + TGlobalConfigSchema, +) + +# Available as lmstudio.plugin.hooks.* +__all__ = [ + "ToolsProviderController", + "ToolsProviderHook", + "run_tools_provider", +] + + +class ToolsProviderController( + HookController[ProvideToolsInitSession, TPluginConfigSchema, TGlobalConfigSchema] +): + """API access for tools provider hook implementations.""" + + +ToolsProviderHook = Callable[ + [ToolsProviderController[Any, Any]], Awaitable[Iterable[ToolDefinition]] +] + + +async def run_tools_provider( + plugin_name: str, + hook_impl: ToolsProviderHook, + plugin_config_schema: type[BaseConfigSchema], + global_config_schema: type[BaseConfigSchema], + session: AsyncSessionPlugins, + notify_ready: Callable[[], Any], +) -> None: + """Accept tools provider session requests.""" + raise NotImplementedError diff --git a/src/lmstudio/plugin/runner.py b/src/lmstudio/plugin/runner.py new file mode 100644 index 0000000..fa81ba7 --- /dev/null +++ b/src/lmstudio/plugin/runner.py @@ -0,0 +1,263 @@ +"""Plugin API client implementation.""" + +# Plugins are expected to maintain multiple concurrently open channels and handle +# multiple concurrent server requests, so plugin implementations are always async + +import asyncio +import json +import os +import runpy +import sys +import warnings + +from functools import partial +from pathlib import Path +from typing import Any, Awaitable, Callable, TypeAlias, TypeVar + +from anyio import create_task_group + +from .._logging import new_logger +from ..sdk_api import LMStudioFileNotFoundError, sdk_public_api +from ..schemas import DictObject +from ..async_api import AsyncClient +from .._sdk_models import ( + PluginsRpcSetConfigSchematicsParameter as SetConfigSchematicsParam, + PluginsRpcSetGlobalConfigSchematicsParameter as SetGlobalConfigSchematicsParam, +) +from .sdk_api import LMStudioPluginInitError, LMStudioPluginRuntimeError +from .config_schemas import BaseConfigSchema +from .hooks import ( + AsyncSessionPlugins, + 
TPluginConfigSchema, + TGlobalConfigSchema, + run_prompt_preprocessor, + run_token_generator, + run_tools_provider, +) + +# Available as lmstudio.plugin.* +__all__ = [ + "run_plugin", + "run_plugin_async", +] + +# Warn about the plugin API stability, since it is still experimental +_PLUGIN_API_STABILITY_WARNING = """\ +Note the plugin API is not yet stable and may change without notice in future releases +""" + +AnyHookImpl: TypeAlias = Callable[..., Awaitable[Any]] +THookImpl = TypeVar("THookImpl", bound=AnyHookImpl) +ReadyCallback: TypeAlias = Callable[[], Any] +HookRunner: TypeAlias = Callable[ + [ + str, # Plugin name + THookImpl, + type[TPluginConfigSchema], + type[TGlobalConfigSchema], + AsyncSessionPlugins, + ReadyCallback, + ], + Awaitable[Any], +] + +_HOOK_RUNNERS: dict[str, HookRunner[Any, Any, Any]] = { + "preprocess_prompt": run_prompt_preprocessor, + "generate_tokens": run_token_generator, + "list_provided_tools": run_tools_provider, +} + + +class PluginClient(AsyncClient): + def __init__( + self, + plugin_dir: str | os.PathLike[str], + client_id: str | None = None, + client_key: str | None = None, + ) -> None: + warnings.warn(_PLUGIN_API_STABILITY_WARNING, FutureWarning) + self._client_id = client_id + self._client_key = client_key + super().__init__() + # TODO: Consider moving file reading to class method and make this a data class + self._plugin_path = plugin_path = Path(plugin_dir) + manifest_path = plugin_path / "manifest.json" + if not manifest_path.exists(): + raise LMStudioFileNotFoundError(manifest_path) + manifest = json.loads(manifest_path.read_text(encoding="utf-8")) + if manifest["type"] != "plugin": + raise LMStudioPluginInitError(f"Invalid manifest type: {manifest['type']}") + if manifest["runner"] != "python": + # This works (even though the app doesn't natively support Python plugins yet), + # as LM Studio doesn't check the runner type when requesting dev credentials. 
+ raise LMStudioPluginInitError( + f'Invalid manifest runner: {manifest["runner"]} (expected "python")' + ) + self.owner = manifest["owner"] + self.name = name = manifest["name"] + self._logger = logger = new_logger(__name__) + logger.update_context(plugin=name) + + _ALL_SESSIONS = ( + # Plugin controller access all runs through a dedicated endpoint + AsyncSessionPlugins, + ) + + def _create_auth_message(self) -> DictObject: + """Create an LM Studio websocket authentication message.""" + if self._client_id is None or self._client_key is None: + return super()._create_auth_message() + # Use plugin credentials to unlock the full plugin client API + return self._format_auth_message(self._client_id, self._client_key) + + @property + def plugins(self) -> AsyncSessionPlugins: + """Return the plugins API client session.""" + return self._get_session(AsyncSessionPlugins) + + async def _run_hook_impl( + self, + hook_runner: HookRunner[THookImpl, TPluginConfigSchema, TGlobalConfigSchema], + hook_impl: THookImpl, + plugin_config_schema: type[TPluginConfigSchema], + global_config_schema: type[TGlobalConfigSchema], + notify_ready: ReadyCallback, + ) -> None: + """Run the given hook implementation.""" + await hook_runner( + self.name, + hook_impl, + plugin_config_schema, + global_config_schema, + self.plugins, + notify_ready, + ) + + _CONFIG_SCHEMA_SCOPES = { + "plugin": ("ConfigSchema", "setConfigSchematics", SetConfigSchematicsParam), + "global": ( + "GlobalConfigSchema", + "setGlobalConfigSchematics", + SetGlobalConfigSchematicsParam, + ), + } + + async def _load_config_schema( + self, ns: DictObject, scope: str + ) -> type[BaseConfigSchema]: + logger = self._logger + config_name, endpoint, param_type = self._CONFIG_SCHEMA_SCOPES[scope] + maybe_config_schema = ns.get(config_name, None) + if maybe_config_schema is None: + # Use an empty config in the client, don't register any schema with the server + logger.debug(f"Plugin does not define {config_name!r}") + return 
BaseConfigSchema + if not issubclass(maybe_config_schema, BaseConfigSchema): + raise LMStudioPluginInitError( + f"{config_name}: Expected {BaseConfigSchema!r} subclass definition, not {maybe_config_schema!r}" + ) + config_schema: type[BaseConfigSchema] = maybe_config_schema + kv_config_schematics = config_schema._to_kv_config_schematics() + if kv_config_schematics is None: + # No fields to configure, no need to register schema with the server + logger.info(f"Plugin defines an empty {config_name!r}") + else: + # Only notify the server if there is at least one config field defined + logger.info(f"Plugin defines {config_name!r}, sending to server...") + await self.plugins.remote_call( + endpoint, + param_type( + schematics=kv_config_schematics, + ), + ) + return config_schema + + async def run_plugin(self, *, allow_local_imports: bool = False) -> int: + # TODO: Nicer error handling + plugin_path = self._plugin_path + source_dir_path = plugin_path / "src" + source_path = source_dir_path / "plugin.py" + if not source_path.exists(): + raise LMStudioFileNotFoundError(source_path) + # TODO: Consider passing this logger to hook runners (instead of each creating their own) + logger = self._logger + logger.update_context(plugin_name=self.name) + logger.info(f"Running {source_path}") + if allow_local_imports: + # We don't try to revert the path change, as that can have odd side-effects + sys.path.insert(0, str(source_dir_path)) + plugin_ns = runpy.run_path(str(source_path), run_name="__lms_plugin__") + # Look up config schemas in the namespace + plugin_schema = await self._load_config_schema(plugin_ns, "plugin") + global_schema = await self._load_config_schema(plugin_ns, "global") + # Look up hook implementations in the namespace + implemented_hooks: list[Callable[[], Awaitable[Any]]] = [] + hook_ready_events: list[asyncio.Event] = [] + for hook_name, hook_runner in _HOOK_RUNNERS.items(): + hook_impl = plugin_ns.get(hook_name, None) + if hook_impl is None: + 
logger.debug(f"Plugin does not define the {hook_name!r} hook") + continue + logger.info(f"Plugin defines the {hook_name!r} hook") + hook_ready_event = asyncio.Event() + hook_ready_events.append(hook_ready_event) + implemented_hooks.append( + partial( + self._run_hook_impl, + hook_runner, + hook_impl, + plugin_schema, + global_schema, + hook_ready_event.set, + ) + ) + plugin = self.name + if not implemented_hooks: + hook_list = "\n - ".join(("", *sorted(_HOOK_RUNNERS))) + print( + f"No plugin hooks defined in {plugin!r}, " + f"expected at least one of:{hook_list}" + ) + return 1 + # Use anyio and exceptiongroup to handle the lack of native task + # and exception groups prior to Python 3.11 + async with create_task_group() as tg: + for implemented_hook in implemented_hooks: + tg.start_soon(implemented_hook) + # Should this have a time limit set to guard against SDK bugs? + await asyncio.gather(*(e.wait() for e in hook_ready_events)) + await self.plugins.remote_call("pluginInitCompleted") + # Indicate that prompt processing is ready + print(f"Plugin {plugin!r} running, press Ctrl-C to terminate...") + # Task group will wait for the plugins to run + return 0 + + +ENV_CLIENT_ID = "LMS_PLUGIN_CLIENT_IDENTIFIER" +ENV_CLIENT_KEY = "LMS_PLUGIN_CLIENT_PASSKEY" + + +def get_plugin_credentials_from_env() -> tuple[str, str]: + return os.environ[ENV_CLIENT_ID], os.environ[ENV_CLIENT_KEY] + + +@sdk_public_api() +async def run_plugin_async( + plugin_dir: str | os.PathLike[str], *, allow_local_imports: bool = False +) -> None: + """Asynchronously execute a plugin in development mode.""" + try: + client_id, client_key = get_plugin_credentials_from_env() + except KeyError: + err_msg = f"ERROR: {ENV_CLIENT_ID} and {ENV_CLIENT_KEY} must both be set in the environment" + raise LMStudioPluginRuntimeError(err_msg) + async with PluginClient(plugin_dir, client_id, client_key) as plugin_client: + await plugin_client.run_plugin(allow_local_imports=allow_local_imports) + + +@sdk_public_api() 
+def run_plugin( + plugin_dir: str | os.PathLike[str], *, allow_local_imports: bool = False +) -> None: + """Execute a plugin in development mode.""" + asyncio.run(run_plugin_async(plugin_dir, allow_local_imports=allow_local_imports)) diff --git a/src/lmstudio/plugin/sdk_api.py b/src/lmstudio/plugin/sdk_api.py new file mode 100644 index 0000000..b1dd267 --- /dev/null +++ b/src/lmstudio/plugin/sdk_api.py @@ -0,0 +1,37 @@ +"""Common definitions for defining the plugin SDK interfaces.""" + +from typing import TypeVar + +from ..sdk_api import LMStudioRuntimeError, LMStudioValueError + +_PLUGIN_SDK_SUBMODULE = ".".join(__name__.split(".")[:-1]) + +_C = TypeVar("_C", bound=type) + +# Available as lmstudio.plugin.* +__all__ = [ + "LMStudioPluginInitError", + "LMStudioPluginRuntimeError", +] + + +def plugin_sdk_type(cls: _C) -> _C: + """Indicates a class forms part of the public plugin SDK boundary. + + Sets `__module__` to the plugin SDK submodule import rather than + leaving it set to the implementation module. 
+ + Note: methods are *not* implicitly decorated as public SDK APIs + """ + cls.__module__ = _PLUGIN_SDK_SUBMODULE + return cls + + +@plugin_sdk_type +class LMStudioPluginRuntimeError(LMStudioRuntimeError): + """Plugin runtime behaviour was not as expected.""" + + +@plugin_sdk_type +class LMStudioPluginInitError(LMStudioValueError): + """Plugin initialization value was not as expected.""" diff --git a/src/lmstudio/schemas.py b/src/lmstudio/schemas.py index 4dcefab..92d4a11 100644 --- a/src/lmstudio/schemas.py +++ b/src/lmstudio/schemas.py @@ -11,6 +11,7 @@ Protocol, Sequence, TypeAlias, + TypedDict, TypeVar, cast, runtime_checkable, @@ -20,16 +21,18 @@ Self, ) -from msgspec import Struct, convert, to_builtins +from msgspec import Struct, ValidationError, convert, to_builtins from msgspec.json import schema from .sdk_api import LMStudioValueError, sdk_public_api, sdk_public_type __all__ = [ + "AnyLMStudioStruct", "BaseModel", "DictObject", "DictSchema", "ModelSchema", + "ValidationError", ] DictObject: TypeAlias = Mapping[str, Any] # Any JSON-compatible string-keyed dict @@ -224,3 +227,15 @@ def __str__(self) -> str: AnyLMStudioStruct = LMStudioStruct[Any] + + +class EmptyStruct(LMStudioStruct["EmptyDict"]): + """LM Studio struct with no defined fields.""" + + pass + + +class EmptyDict(TypedDict): + """Wire format with no defined fields.""" + + pass diff --git a/src/lmstudio/sdk_api.py b/src/lmstudio/sdk_api.py index e559983..149c0e0 100644 --- a/src/lmstudio/sdk_api.py +++ b/src/lmstudio/sdk_api.py @@ -49,6 +49,11 @@ class LMStudioOSError(OSError, LMStudioError): """The SDK received an error while accessing the local operating system.""" +@sdk_public_type +class LMStudioFileNotFoundError(FileNotFoundError, LMStudioError): + """The SDK failed to find the specified file on the local file system.""" + + @sdk_public_type class LMStudioRuntimeError(RuntimeError, LMStudioError): """User requested an invalid sequence of operations from the SDK.""" @@ -69,7 +74,7 @@ def 
_truncate_traceback(exc: BaseException | None) -> None: return if isinstance(exc, LMStudioError) or not isinstance(exc, Exception): # Truncate the traceback for SDK exceptions at the SDK boundary. - # Also truncate asychronous exceptions like KeyboardInterrupt. + # Also truncate asynchronous exceptions like KeyboardInterrupt. # Other unwrapped exceptions indicate SDK bugs and keep a full traceback. exc.__traceback__ = None diff --git a/src/lmstudio/sync_api.py b/src/lmstudio/sync_api.py index bf15038..8a368ef 100644 --- a/src/lmstudio/sync_api.py +++ b/src/lmstudio/sync_api.py @@ -173,6 +173,11 @@ def get_creation_message(self) -> DictObject: """Get the message to send to create this channel.""" return self._api_channel.get_creation_message() + def send_message(self, message: DictObject) -> None: + """Send given message on this channel.""" + wrapped_message = self._api_channel.wrap_message(message) + self._send_json(wrapped_message) + def cancel(self) -> None: """Cancel the channel.""" if self._is_finished: @@ -316,7 +321,7 @@ def _notify_client_termination(self) -> int: for rx_queue in self._mux.all_queues(): rx_queue.put(None) num_clients += 1 - self._logger.info( + self._logger.debug( f"Notified {num_clients} clients of websocket termination", num_clients=num_clients, ) diff --git a/tests/invalid_plugins/hook_exception/manifest.json b/tests/invalid_plugins/hook_exception/manifest.json new file mode 100644 index 0000000..b1c27f4 --- /dev/null +++ b/tests/invalid_plugins/hook_exception/manifest.json @@ -0,0 +1,7 @@ +{ + "type": "plugin", + "runner": "python", + "owner": "lmstudio", + "name": "hook-exception", + "revision": 1 +} diff --git a/tests/invalid_plugins/hook_exception/src/plugin.py b/tests/invalid_plugins/hook_exception/src/plugin.py new file mode 100644 index 0000000..4904a87 --- /dev/null +++ b/tests/invalid_plugins/hook_exception/src/plugin.py @@ -0,0 +1,8 @@ +# Plugin runner doesn't actually care these type hints are incorrect +# It *does* care that 
an exception is raised +class ExampleCustomException(Exception): + pass + + +async def preprocess_prompt(_ctl: None, _message: None) -> dict[None, None]: + raise ExampleCustomException("Example plugin hook failure") diff --git a/tests/invalid_plugins/malformed_prompt/manifest.json b/tests/invalid_plugins/malformed_prompt/manifest.json new file mode 100644 index 0000000..7cc8fee --- /dev/null +++ b/tests/invalid_plugins/malformed_prompt/manifest.json @@ -0,0 +1,7 @@ +{ + "type": "plugin", + "runner": "python", + "owner": "lmstudio", + "name": "malformed-prompt", + "revision": 1 +} diff --git a/tests/invalid_plugins/malformed_prompt/src/plugin.py b/tests/invalid_plugins/malformed_prompt/src/plugin.py new file mode 100644 index 0000000..936519e --- /dev/null +++ b/tests/invalid_plugins/malformed_prompt/src/plugin.py @@ -0,0 +1,4 @@ +# Plugin runner doesn't actually care these type hints are incorrect +# It *does* care that the returned "user message" has no content +async def preprocess_prompt(_ctl: None, _message: None) -> dict[None, None]: + return {} diff --git a/tests/invalid_plugins/no_hooks/manifest.json b/tests/invalid_plugins/no_hooks/manifest.json new file mode 100644 index 0000000..61b879f --- /dev/null +++ b/tests/invalid_plugins/no_hooks/manifest.json @@ -0,0 +1,7 @@ +{ + "type": "plugin", + "runner": "python", + "owner": "lmstudio", + "name": "no-hooks", + "revision": 1 +} diff --git a/tests/invalid_plugins/no_hooks/src/plugin.py b/tests/invalid_plugins/no_hooks/src/plugin.py new file mode 100644 index 0000000..c16b7c2 --- /dev/null +++ b/tests/invalid_plugins/no_hooks/src/plugin.py @@ -0,0 +1 @@ +# No hooks defined, runner will refuse to execute this plugin diff --git a/tests/test_plugin_config.py b/tests/test_plugin_config.py new file mode 100644 index 0000000..7beaf5b --- /dev/null +++ b/tests/test_plugin_config.py @@ -0,0 +1,112 @@ +"""Test plugin config schema definitions.""" + +from lmstudio.plugin import BaseConfigSchema, config_field + + +def 
test_empty_config() -> None: + class ConfigSchema(BaseConfigSchema): + pass + + assert ConfigSchema._to_kv_config_schematics() is None + + +def test_config_field_bool() -> None: + class ConfigSchema(BaseConfigSchema): + setting: bool = config_field(label="UI label", hint="UI tooltip", default=True) + + assert ConfigSchema.setting is True + assert ConfigSchema().setting is True + kv_config_schematics = ConfigSchema._to_kv_config_schematics() + assert kv_config_schematics is not None + expected_kv_config_schematics = { + "fields": [ + { + "defaultValue": True, + "fullKey": "setting", + "shortKey": "setting", + "typeKey": "boolean", + "typeParams": { + "displayName": "UI label", + "hint": "UI tooltip", + }, + } + ] + } + assert kv_config_schematics.to_dict() == expected_kv_config_schematics + + +def test_config_field_int() -> None: + class ConfigSchema(BaseConfigSchema): + setting: int = config_field(label="UI label", hint="UI tooltip", default=42) + + assert ConfigSchema.setting == 42 + assert ConfigSchema().setting == 42 + kv_config_schematics = ConfigSchema._to_kv_config_schematics() + assert kv_config_schematics is not None + expected_kv_config_schematics = { + "fields": [ + { + "defaultValue": 42, + "fullKey": "setting", + "shortKey": "setting", + "typeKey": "numeric", + "typeParams": { + "displayName": "UI label", + "hint": "UI tooltip", + "int": True, + }, + } + ] + } + assert kv_config_schematics.to_dict() == expected_kv_config_schematics + + +def test_config_field_float() -> None: + class ConfigSchema(BaseConfigSchema): + setting: float = config_field(label="UI label", hint="UI tooltip", default=4.2) + + assert ConfigSchema.setting == 4.2 + assert ConfigSchema().setting == 4.2 + kv_config_schematics = ConfigSchema._to_kv_config_schematics() + assert kv_config_schematics is not None + expected_kv_config_schematics = { + "fields": [ + { + "defaultValue": 4.2, + "fullKey": "setting", + "shortKey": "setting", + "typeKey": "numeric", + "typeParams": { + 
"displayName": "UI label", + "hint": "UI tooltip", + "int": False, + }, + } + ] + } + assert kv_config_schematics.to_dict() == expected_kv_config_schematics + + +def test_config_field_str() -> None: + class ConfigSchema(BaseConfigSchema): + setting: str = config_field(label="UI label", hint="UI tooltip", default="text") + + assert ConfigSchema.setting == "text" + assert ConfigSchema().setting == "text" + kv_config_schematics = ConfigSchema._to_kv_config_schematics() + assert kv_config_schematics is not None + expected_kv_config_schematics = { + "fields": [ + { + "defaultValue": "text", + "fullKey": "setting", + "shortKey": "setting", + "typeKey": "string", + "typeParams": { + "displayName": "UI label", + "hint": "UI tooltip", + }, + } + ] + } + assert kv_config_schematics.to_dict() == expected_kv_config_schematics diff --git a/tests/test_sessions.py b/tests/test_sessions.py index 495f4d1..a592437 100644 --- a/tests/test_sessions.py +++ b/tests/test_sessions.py @@ -152,7 +152,7 @@ def test_implicit_reconnection_sync(caplog: LogCap) -> None: @pytest.mark.lmstudio async def test_websocket_cm_async(caplog: LogCap) -> None: caplog.set_level(logging.DEBUG) - auth_details = AsyncClient._create_auth_message() + auth_details = AsyncClient._format_auth_message() lmsws = AsyncLMStudioWebsocket(f"http://{LOCAL_API_HOST}/system", auth_details) # SDK client websockets start out disconnected assert not lmsws.connected @@ -190,7 +190,7 @@ def ws_thread() -> Generator[AsyncWebsocketThread, None, None]: @pytest.mark.lmstudio def test_websocket_cm_sync(ws_thread: AsyncWebsocketThread, caplog: LogCap) -> None: caplog.set_level(logging.DEBUG) - auth_details = Client._create_auth_message() + auth_details = Client._format_auth_message() lmsws = SyncLMStudioWebsocket( ws_thread, f"http://{LOCAL_API_HOST}/system", auth_details ) diff --git a/tox.ini b/tox.ini index 0bacd71..642c9cb 100644 --- a/tox.ini +++ b/tox.ini @@ -43,18 +43,21 @@ commands = allowlist_externals = ruff skip_install = 
true commands = - ruff format {posargs} src/ tests/ sdk-schema/sync-sdk-schema.py + ruff format {posargs} src/ tests/ examples/plugins sdk-schema/sync-sdk-schema.py [testenv:lint] allowlist_externals = ruff skip_install = true commands = - ruff check {posargs} src/ tests/ + ruff check {posargs} src/ tests/ examples/plugins [testenv:typecheck] allowlist_externals = mypy commands = mypy --strict {posargs} src/ tests/ + # Examples folder is checked separately as a named package + # so mypy doesn't complain about multiple plugin.py files + mypy --strict {posargs} -p examples [testenv:sync-sdk-schema] allowlist_externals = python