diff --git a/examples/plugins/dice-tool/README.md b/examples/plugins/dice-tool/README.md index 2d3e1f2..5330e3f 100644 --- a/examples/plugins/dice-tool/README.md +++ b/examples/plugins/dice-tool/README.md @@ -1,3 +1,7 @@ -# `lmstudio/pydice` +# `lmstudio/dice-tool` -TODO: Example Python tools provider plugin +Python tools provider plugin example + +Running a local dev instance: + + pdm run python -m lmstudio.plugin --dev examples/plugins/dice-tool diff --git a/examples/plugins/dice-tool/src/plugin.py b/examples/plugins/dice-tool/src/plugin.py index ea2edf2..b878ab3 100644 --- a/examples/plugins/dice-tool/src/plugin.py +++ b/examples/plugins/dice-tool/src/plugin.py @@ -1,3 +1,100 @@ """Example plugin that provide dice rolling tools.""" -# Not yet implemented, currently used to check plugins with no hooks defined +import time + +from random import randint +from typing import TypedDict + +from lmstudio.plugin import ( + BaseConfigSchema, + ToolsProviderController, + config_field, + get_tool_call_context, +) +from lmstudio import ToolDefinition + + +# Assigning ConfigSchema = SomeOtherSchemaClass also works +class ConfigSchema(BaseConfigSchema): + """The name 'ConfigSchema' implicitly registers this as the per-chat plugin config schema.""" + + enable_inplace_status_demo: bool = config_field( + label="Enable in-place status demo", + hint="The plugin will run an in-place task status updating demo when invoked", + default=True, + ) + inplace_status_duration: float = config_field( + label="In-place status total duration (s)", + hint="The number of seconds to spend displaying the in-place task status update", + default=5.0, + ) + restrict_die_types: bool = config_field( + label="Require polyhedral dice", + hint="Require conventional polyhedral dice (4, 6, 8, 10, 12, 20, or 100 sides)", + default=True, + ) + + +# This example plugin has no global configuration settings defined. +# For a type hinted plugin with no configuration settings of a given type, +# BaseConfigSchema may be used in the hook controller type hint. +# Defining a config schema subclass with no fields is also a valid approach. + + +# When reporting multiple values from a tool call, dictionaries +# are the preferred format, as the field names allow the LLM +# to potentially interpret the result correctly. +# Unlike parameter details, no return value schema is sent to the server, +# so relevant information needs to be part of the JSON serialisation. +class DiceRollResult(TypedDict): + """The result of a dice rolling request.""" + + rolls: list[int] + total: int + + +# Assigning list_provided_tools = some_other_callable also works +async def list_provided_tools( + ctl: ToolsProviderController[ConfigSchema, BaseConfigSchema], +) -> list[ToolDefinition]: + """Naming the function 'list_provided_tools' implicitly registers it.""" + config = ctl.plugin_config + if config.enable_inplace_status_demo: + inplace_status_duration = config.inplace_status_duration + else: + inplace_status_duration = 0 + if config.restrict_die_types: + permitted_sides = {4, 6, 8, 10, 12, 20, 100} + else: + permitted_sides = None + + # Tool definitions may use any of the formats described in + # https://lmstudio.ai/docs/python/agent/tools + def roll_dice(count: int, sides: int) -> DiceRollResult: + """Roll a specified number of dice with specified number of faces. + + For example, to roll 2 six-sided dice (i.e. 2d6), you should call the function + `roll_dice` with the parameters { count: 2, sides: 6 }. + """ + if inplace_status_duration: + tcc = get_tool_call_context() + status_updates = ( + (tcc.notify_status, "Display status update in UI."), + (tcc.notify_warning, "Display task warning in UI."), + (tcc.notify_status, "Post-warning status update in UI."), + ) + status_duration = inplace_status_duration / len(status_updates) + for send_notification, status_text in status_updates: + time.sleep(status_duration) + send_notification(status_text) + if permitted_sides and sides not in permitted_sides: + expected_die_types = ",".join(map(str, sorted(permitted_sides))) + err_msg = f"{sides} is not a conventional polyhedral die type ({expected_die_types})" + raise ValueError(err_msg) + rolls = [randint(1, sides) for _ in range(count)] + return DiceRollResult(rolls=rolls, total=sum(rolls)) + + return [roll_dice] + + +print(f"{__name__} initialized from {__file__}") diff --git a/examples/plugins/prompt-prefix/README.md b/examples/plugins/prompt-prefix/README.md index 03fec2f..eecba27 100644 --- a/examples/plugins/prompt-prefix/README.md +++ b/examples/plugins/prompt-prefix/README.md @@ -1,6 +1,7 @@ -# `lmstudio/pyprompt` +# `lmstudio/prompt-prefix` Python prompt preprocessing plugin example -Note: there's no `python` runner in LM Studio yet, so use -`python -m lmstudio.plugin --dev path/to/plugin` to run a dev instance +Running a local dev instance: + + pdm run python -m lmstudio.plugin --dev examples/plugins/prompt-prefix diff --git a/examples/plugins/prompt-prefix/src/plugin.py b/examples/plugins/prompt-prefix/src/plugin.py index 4521ebb..22ae2b4 100644 --- a/examples/plugins/prompt-prefix/src/plugin.py +++ b/examples/plugins/prompt-prefix/src/plugin.py @@ -56,9 +56,9 @@ async def preprocess_prompt( status_updates ) async with status_block.notify_aborted("Task genuinely cancelled."): - for notification, status_text in status_updates: + for send_notification, status_text in status_updates: await asyncio.sleep(status_duration) - await notification(status_text) + await send_notification(status_text) modified_message = message.to_dict() # Add a prefix to all user messages diff --git a/sdk-schema/sync-sdk-schema.py b/sdk-schema/sync-sdk-schema.py index 51f7a3e..5208b22 100755 --- a/sdk-schema/sync-sdk-schema.py +++ b/sdk-schema/sync-sdk-schema.py @@ -363,7 +363,7 @@ def _infer_schema_unions() -> None: "LlmChannelPredictCreationParameterDict": "PredictionChannelRequestDict", "RepositoryChannelDownloadModelCreationParameter": "DownloadModelChannelRequest", "RepositoryChannelDownloadModelCreationParameterDict": "DownloadModelChannelRequestDict", - # Prettier plugin channel message names + # Prettier prompt preprocessing plugin channel message names "PluginsChannelSetPromptPreprocessorToClientPacketPreprocess": "PromptPreprocessingRequest", "PluginsChannelSetPromptPreprocessorToClientPacketPreprocessDict": "PromptPreprocessingRequestDict", "PluginsChannelSetPromptPreprocessorToServerPacketAborted": "PromptPreprocessingAborted", @@ -372,6 +372,25 @@ def _infer_schema_unions() -> None: "PluginsChannelSetPromptPreprocessorToServerPacketCompleteDict": "PromptPreprocessingCompleteDict", "PluginsChannelSetPromptPreprocessorToServerPacketError": "PromptPreprocessingError", "PluginsChannelSetPromptPreprocessorToServerPacketErrorDict": "PromptPreprocessingErrorDict", + # Prettier tools provider plugin channel message names + "PluginsChannelSetToolsProviderToClientPacketInitSession": "ProvideToolsInitSession", + "PluginsChannelSetToolsProviderToClientPacketInitSessionDict": "ProvideToolsInitSessionDict", + "PluginsChannelSetToolsProviderToClientPacketAbortToolCall": "ProvideToolsAbortCall", + "PluginsChannelSetToolsProviderToClientPacketAbortToolCallDict": "ProvideToolsAbortCallDict", + "PluginsChannelSetToolsProviderToClientPacketCallTool": "ProvideToolsCallTool", + "PluginsChannelSetToolsProviderToClientPacketCallToolDict": "ProvideToolsCallToolDict", + "PluginsChannelSetToolsProviderToServerPacketSessionInitializationFailed": "ProvideToolsInitFailed", + "PluginsChannelSetToolsProviderToServerPacketSessionInitializationFailedDict": "ProvideToolsInitFailedDict", + "PluginsChannelSetToolsProviderToServerPacketSessionInitialized": "ProvideToolsInitialized", + "PluginsChannelSetToolsProviderToServerPacketSessionInitializedDict": "ProvideToolsInitializedDict", + "PluginsChannelSetToolsProviderToServerPacketToolCallComplete": "PluginToolCallComplete", + "PluginsChannelSetToolsProviderToServerPacketToolCallCompleteDict": "PluginToolCallCompleteDict", + "PluginsChannelSetToolsProviderToServerPacketToolCallError": "PluginToolCallError", + "PluginsChannelSetToolsProviderToServerPacketToolCallErrorDict": "PluginToolCallErrorDict", + "PluginsChannelSetToolsProviderToServerPacketToolCallStatus": "PluginToolCallStatus", + "PluginsChannelSetToolsProviderToServerPacketToolCallStatusDict": "PluginToolCallStatusDict", + "PluginsChannelSetToolsProviderToServerPacketToolCallWarn": "PluginToolCallWarn", + "PluginsChannelSetToolsProviderToServerPacketToolCallWarnDict": "PluginToolCallWarnDict", # Prettier config handling type names "LlmRpcGetLoadConfigReturns": "SerializedKVConfigSettings", "LlmRpcGetLoadConfigReturnsDict": "SerializedKVConfigSettingsDict", diff --git a/src/lmstudio/_sdk_models/__init__.py b/src/lmstudio/_sdk_models/__init__.py index 5c56e37..02a3cdf 100644 --- a/src/lmstudio/_sdk_models/__init__.py +++ b/src/lmstudio/_sdk_models/__init__.py @@ -380,6 +380,14 @@ "ParsedFileIdentifierLocalDict", "PluginManifest", "PluginManifestDict", + "PluginToolCallComplete", + "PluginToolCallCompleteDict", + "PluginToolCallError", + "PluginToolCallErrorDict", + "PluginToolCallStatus", + "PluginToolCallStatusDict", + "PluginToolCallWarn", + "PluginToolCallWarnDict", "PluginsChannelRegisterDevelopmentPluginCreationParameter", "PluginsChannelRegisterDevelopmentPluginCreationParameterDict", "PluginsChannelRegisterDevelopmentPluginToClientPacketReady", @@ -420,26 +428,8 @@ "PluginsChannelSetPredictionLoopHandlerToServerPacketErrorDict", "PluginsChannelSetPromptPreprocessorToClientPacketAbort", "PluginsChannelSetPromptPreprocessorToClientPacketAbortDict", - "PluginsChannelSetToolsProviderToClientPacketAbortToolCall", - "PluginsChannelSetToolsProviderToClientPacketAbortToolCallDict", - "PluginsChannelSetToolsProviderToClientPacketCallTool", - "PluginsChannelSetToolsProviderToClientPacketCallToolDict", "PluginsChannelSetToolsProviderToClientPacketDiscardSession", "PluginsChannelSetToolsProviderToClientPacketDiscardSessionDict", - "PluginsChannelSetToolsProviderToClientPacketInitSession", - "PluginsChannelSetToolsProviderToClientPacketInitSessionDict", - "PluginsChannelSetToolsProviderToServerPacketSessionInitializationFailed", - "PluginsChannelSetToolsProviderToServerPacketSessionInitializationFailedDict", - "PluginsChannelSetToolsProviderToServerPacketSessionInitialized", - "PluginsChannelSetToolsProviderToServerPacketSessionInitializedDict", - "PluginsChannelSetToolsProviderToServerPacketToolCallComplete", - "PluginsChannelSetToolsProviderToServerPacketToolCallCompleteDict", - "PluginsChannelSetToolsProviderToServerPacketToolCallError", - "PluginsChannelSetToolsProviderToServerPacketToolCallErrorDict", - "PluginsChannelSetToolsProviderToServerPacketToolCallStatus", - "PluginsChannelSetToolsProviderToServerPacketToolCallStatusDict", - "PluginsChannelSetToolsProviderToServerPacketToolCallWarn", - "PluginsChannelSetToolsProviderToServerPacketToolCallWarnDict", "PluginsRpcProcessingGetOrLoadModelParameter", "PluginsRpcProcessingGetOrLoadModelParameterDict", "PluginsRpcProcessingGetOrLoadModelReturns", @@ -528,6 +518,16 @@ "PromptPreprocessingErrorDict", "PromptPreprocessingRequest", "PromptPreprocessingRequestDict", + "ProvideToolsAbortCall", + "ProvideToolsAbortCallDict", + "ProvideToolsCallTool", + "ProvideToolsCallToolDict", + "ProvideToolsInitFailed", + "ProvideToolsInitFailedDict", + "ProvideToolsInitSession", + "ProvideToolsInitSessionDict", + "ProvideToolsInitialized", + "ProvideToolsInitializedDict", "PseudoDiagnostics", "PseudoDiagnosticsChannelStreamLogs", "PseudoDiagnosticsChannelStreamLogsDict", @@ -5214,8 +5214,8 @@ class PluginsChannelSetToolsProviderToClientPacketDiscardSessionDict(TypedDict): sessionId: str -class PluginsChannelSetToolsProviderToClientPacketCallTool( - LMStudioStruct["PluginsChannelSetToolsProviderToClientPacketCallToolDict"], +class ProvideToolsCallTool( + LMStudioStruct["ProvideToolsCallToolDict"], kw_only=True, tag_field="type", tag="callTool", @@ -5227,7 +5227,7 @@ class PluginsChannelSetToolsProviderToClientPacketCallTool( parameters: JsonSerializable -class PluginsChannelSetToolsProviderToClientPacketCallToolDict(TypedDict): +class ProvideToolsCallToolDict(TypedDict): """Corresponding typed dictionary definition for PluginsChannelSetToolsProviderToClientPacketCallTool. NOTE: Multi-word keys are defined using their camelCase form, @@ -5241,8 +5241,8 @@ class PluginsChannelSetToolsProviderToClientPacketCallToolDict(TypedDict): parameters: JsonSerializable -class PluginsChannelSetToolsProviderToClientPacketAbortToolCall( - LMStudioStruct["PluginsChannelSetToolsProviderToClientPacketAbortToolCallDict"], +class ProvideToolsAbortCall( + LMStudioStruct["ProvideToolsAbortCallDict"], kw_only=True, tag_field="type", tag="abortToolCall", @@ -5254,7 +5254,7 @@ class PluginsChannelSetToolsProviderToClientPacketAbortToolCall( call_id: str = field(name="callId") -class PluginsChannelSetToolsProviderToClientPacketAbortToolCallDict(TypedDict): +class ProvideToolsAbortCallDict(TypedDict): """Corresponding typed dictionary definition for PluginsChannelSetToolsProviderToClientPacketAbortToolCall. NOTE: Multi-word keys are defined using their camelCase form, @@ -5266,10 +5266,8 @@ class PluginsChannelSetToolsProviderToClientPacketAbortToolCallDict(TypedDict): callId: str -class PluginsChannelSetToolsProviderToServerPacketSessionInitializationFailed( - LMStudioStruct[ - "PluginsChannelSetToolsProviderToServerPacketSessionInitializationFailedDict" - ], +class ProvideToolsInitFailed( + LMStudioStruct["ProvideToolsInitFailedDict"], kw_only=True, tag_field="type", tag="sessionInitializationFailed", @@ -5281,9 +5279,7 @@ class PluginsChannelSetToolsProviderToServerPacketSessionInitializationFailed( error: SerializedLMSExtendedError -class PluginsChannelSetToolsProviderToServerPacketSessionInitializationFailedDict( - TypedDict -): +class ProvideToolsInitFailedDict(TypedDict): """Corresponding typed dictionary definition for PluginsChannelSetToolsProviderToServerPacketSessionInitializationFailed. NOTE: Multi-word keys are defined using their camelCase form, @@ -5295,8 +5291,8 @@ class PluginsChannelSetToolsProviderToServerPacketSessionInitializationFailedDic error: SerializedLMSExtendedErrorDict -class PluginsChannelSetToolsProviderToServerPacketToolCallComplete( - LMStudioStruct["PluginsChannelSetToolsProviderToServerPacketToolCallCompleteDict"], +class PluginToolCallComplete( + LMStudioStruct["PluginToolCallCompleteDict"], kw_only=True, tag_field="type", tag="toolCallComplete", @@ -5309,7 +5305,7 @@ class PluginsChannelSetToolsProviderToServerPacketToolCallComplete( result: JsonSerializable -class PluginsChannelSetToolsProviderToServerPacketToolCallCompleteDict(TypedDict): +class PluginToolCallCompleteDict(TypedDict): """Corresponding typed dictionary definition for PluginsChannelSetToolsProviderToServerPacketToolCallComplete. NOTE: Multi-word keys are defined using their camelCase form, @@ -5322,8 +5318,8 @@ class PluginsChannelSetToolsProviderToServerPacketToolCallCompleteDict(TypedDict result: JsonSerializable -class PluginsChannelSetToolsProviderToServerPacketToolCallError( - LMStudioStruct["PluginsChannelSetToolsProviderToServerPacketToolCallErrorDict"], +class PluginToolCallError( + LMStudioStruct["PluginToolCallErrorDict"], kw_only=True, tag_field="type", tag="toolCallError", @@ -5336,7 +5332,7 @@ class PluginsChannelSetToolsProviderToServerPacketToolCallError( error: SerializedLMSExtendedError -class PluginsChannelSetToolsProviderToServerPacketToolCallErrorDict(TypedDict): +class PluginToolCallErrorDict(TypedDict): """Corresponding typed dictionary definition for PluginsChannelSetToolsProviderToServerPacketToolCallError. NOTE: Multi-word keys are defined using their camelCase form, @@ -5349,8 +5345,8 @@ class PluginsChannelSetToolsProviderToServerPacketToolCallErrorDict(TypedDict): error: SerializedLMSExtendedErrorDict -class PluginsChannelSetToolsProviderToServerPacketToolCallStatus( - LMStudioStruct["PluginsChannelSetToolsProviderToServerPacketToolCallStatusDict"], +class PluginToolCallStatus( + LMStudioStruct["PluginToolCallStatusDict"], kw_only=True, tag_field="type", tag="toolCallStatus", @@ -5363,7 +5359,7 @@ class PluginsChannelSetToolsProviderToServerPacketToolCallStatus( status_text: str = field(name="statusText") -class PluginsChannelSetToolsProviderToServerPacketToolCallStatusDict(TypedDict): +class PluginToolCallStatusDict(TypedDict): """Corresponding typed dictionary definition for PluginsChannelSetToolsProviderToServerPacketToolCallStatus. NOTE: Multi-word keys are defined using their camelCase form, @@ -5376,8 +5372,8 @@ class PluginsChannelSetToolsProviderToServerPacketToolCallStatusDict(TypedDict): statusText: str -class PluginsChannelSetToolsProviderToServerPacketToolCallWarn( - LMStudioStruct["PluginsChannelSetToolsProviderToServerPacketToolCallWarnDict"], +class PluginToolCallWarn( + LMStudioStruct["PluginToolCallWarnDict"], kw_only=True, tag_field="type", tag="toolCallWarn", @@ -5390,7 +5386,7 @@ class PluginsChannelSetToolsProviderToServerPacketToolCallWarn( warn_text: str = field(name="warnText") -class PluginsChannelSetToolsProviderToServerPacketToolCallWarnDict(TypedDict): +class PluginToolCallWarnDict(TypedDict): """Corresponding typed dictionary definition for PluginsChannelSetToolsProviderToServerPacketToolCallWarn. NOTE: Multi-word keys are defined using their camelCase form, @@ -7946,8 +7942,8 @@ class PluginsChannelSetPredictionLoopHandlerToClientPacketHandlePredictionLoopDi token: str -class PluginsChannelSetToolsProviderToClientPacketInitSession( - LMStudioStruct["PluginsChannelSetToolsProviderToClientPacketInitSessionDict"], +class ProvideToolsInitSession( + LMStudioStruct["ProvideToolsInitSessionDict"], kw_only=True, tag_field="type", tag="initSession", @@ -7961,7 +7957,7 @@ class PluginsChannelSetToolsProviderToClientPacketInitSession( session_id: str = field(name="sessionId") -class PluginsChannelSetToolsProviderToClientPacketInitSessionDict(TypedDict): +class ProvideToolsInitSessionDict(TypedDict): """Corresponding typed dictionary definition for PluginsChannelSetToolsProviderToClientPacketInitSession. NOTE: Multi-word keys are defined using their camelCase form, @@ -9129,15 +9125,15 @@ class PseudoPluginsChannelSetPredictionLoopHandlerDict(TypedDict): PluginsChannelSetToolsProviderToClientPacket = ( - PluginsChannelSetToolsProviderToClientPacketInitSession + ProvideToolsInitSession | PluginsChannelSetToolsProviderToClientPacketDiscardSession - | PluginsChannelSetToolsProviderToClientPacketCallTool - | PluginsChannelSetToolsProviderToClientPacketAbortToolCall + | ProvideToolsCallTool + | ProvideToolsAbortCall ) PluginsChannelSetToolsProviderToClientPacketDict = ( - PluginsChannelSetToolsProviderToClientPacketAbortToolCallDict - | PluginsChannelSetToolsProviderToClientPacketCallToolDict - | PluginsChannelSetToolsProviderToClientPacketInitSessionDict + ProvideToolsAbortCallDict + | ProvideToolsCallToolDict + | ProvideToolsInitSessionDict | PluginsChannelSetToolsProviderToClientPacketDiscardSessionDict ) PluginsChannelSetGeneratorToServerPacket = ( @@ -9339,10 +9335,8 @@ class LlmToolUseSettingToolArrayDict(TypedDict): force: NotRequired[bool | None] -class PluginsChannelSetToolsProviderToServerPacketSessionInitialized( - LMStudioStruct[ - "PluginsChannelSetToolsProviderToServerPacketSessionInitializedDict" - ], +class ProvideToolsInitialized( + LMStudioStruct["ProvideToolsInitializedDict"], kw_only=True, tag_field="type", tag="sessionInitialized", @@ -9354,7 +9348,7 @@ class PluginsChannelSetToolsProviderToServerPacketSessionInitialized( tool_definitions: Sequence[LlmTool] = field(name="toolDefinitions") -class PluginsChannelSetToolsProviderToServerPacketSessionInitializedDict(TypedDict): +class ProvideToolsInitializedDict(TypedDict): """Corresponding typed dictionary definition for PluginsChannelSetToolsProviderToServerPacketSessionInitialized. NOTE: Multi-word keys are defined using their camelCase form, @@ -9486,20 +9480,20 @@ class ProcessingUpdateToolStatusUpdateDict(TypedDict): PluginsChannelSetToolsProviderToServerPacket = ( - PluginsChannelSetToolsProviderToServerPacketSessionInitialized - | PluginsChannelSetToolsProviderToServerPacketSessionInitializationFailed - | PluginsChannelSetToolsProviderToServerPacketToolCallComplete - | PluginsChannelSetToolsProviderToServerPacketToolCallError - | PluginsChannelSetToolsProviderToServerPacketToolCallStatus - | PluginsChannelSetToolsProviderToServerPacketToolCallWarn + ProvideToolsInitialized + | ProvideToolsInitFailed + | PluginToolCallComplete + | PluginToolCallError + | PluginToolCallStatus + | PluginToolCallWarn ) PluginsChannelSetToolsProviderToServerPacketDict = ( - PluginsChannelSetToolsProviderToServerPacketToolCallWarnDict - | PluginsChannelSetToolsProviderToServerPacketToolCallStatusDict - | PluginsChannelSetToolsProviderToServerPacketToolCallErrorDict - | PluginsChannelSetToolsProviderToServerPacketToolCallCompleteDict - | PluginsChannelSetToolsProviderToServerPacketSessionInitializedDict - | PluginsChannelSetToolsProviderToServerPacketSessionInitializationFailedDict + PluginToolCallWarnDict + | PluginToolCallStatusDict + | PluginToolCallErrorDict + | PluginToolCallCompleteDict + | ProvideToolsInitializedDict + | ProvideToolsInitFailedDict ) diff --git a/src/lmstudio/_ws_impl.py b/src/lmstudio/_ws_impl.py index 47eb7ff..8a16f09 100644 --- a/src/lmstudio/_ws_impl.py +++ b/src/lmstudio/_ws_impl.py @@ -47,10 +47,6 @@ from ._logging import LogEventContext, new_logger -# Allow the core client websocket management to be shared across all SDK interaction APIs -# See https://discuss.python.org/t/daemon-threads-and-background-task-termination/77604 -# (Note: this implementation has the elements needed to run on *current* Python versions -# and omits the generalised features that the SDK doesn't need) T = TypeVar("T") __all__ = [ @@ -100,7 +96,7 @@ async def __aenter__(self) -> Self: async def __aexit__(self, *args: Any) -> None: await self.request_termination() - with move_on_after(self.TERMINATION_TIMEOUT): + with move_on_after(self.TERMINATION_TIMEOUT, shield=True): await self._terminated.wait() @classmethod diff --git a/src/lmstudio/_ws_thread.py b/src/lmstudio/_ws_thread.py index acfed06..8f2a24d 100644 --- a/src/lmstudio/_ws_thread.py +++ b/src/lmstudio/_ws_thread.py @@ -24,6 +24,10 @@ from ._logging import new_logger, LogEventContext from ._ws_impl import AsyncTaskManager, AsyncWebsocketHandler +# Allow the core client websocket management to be shared across all SDK interaction APIs +# See https://discuss.python.org/t/daemon-threads-and-background-task-termination/77604 +# (Note: this implementation has the elements needed to run on *current* Python versions +# and omits the generalised features that the SDK doesn't need) T = TypeVar("T") diff --git a/src/lmstudio/async_api.py b/src/lmstudio/async_api.py index ffaedcd..ee947d6 100644 --- a/src/lmstudio/async_api.py +++ b/src/lmstudio/async_api.py @@ -83,6 +83,7 @@ PromptProcessingCallback, RemoteCallHandler, ResponseSchema, + SendMessageAsync, TModelInfo, check_model_namespace, load_struct, @@ -132,7 +133,7 @@ def __init__( channel_id: int, get_message: Callable[[], Awaitable[Any]], endpoint: ChannelEndpoint[T, Any, Any], - send_json: Callable[[DictObject], Awaitable[None]], + send_json: SendMessageAsync, log_context: LogEventContext, ) -> None: """Initialize asynchronous websocket streaming channel.""" diff --git a/src/lmstudio/json_api.py b/src/lmstudio/json_api.py index 39f9d02..76662e2 100644 --- a/src/lmstudio/json_api.py +++ b/src/lmstudio/json_api.py @@ -21,6 +21,7 @@ from typing import ( Any, Callable, + Coroutine, Generator, Generic, Iterable, @@ -192,6 +193,9 @@ DEFAULT_API_HOST = "localhost:1234" DEFAULT_TTL = 60 * 60 # By default, leaves idle models loaded for an hour +# Require a coroutine (not just any awaitable) for run_coroutine_threadsafe compatibility +SendMessageAsync: TypeAlias = Callable[[DictObject], Coroutine[Any, Any, None]] + UnstructuredPrediction: TypeAlias = str StructuredPrediction: TypeAlias = DictObject AnyPrediction = StructuredPrediction | UnstructuredPrediction @@ -1445,6 +1449,7 @@ def _handle_failed_tool_request( ) return ToolCallResultData(content=json.dumps(err_msg), tool_call_id=request.id) + # TODO: Reduce code duplication with the tools_provider plugin hook runner def request_tool_call( self, request: ToolCallRequest ) -> Callable[[], ToolCallResultData]: diff --git a/src/lmstudio/plugin/__init__.py b/src/lmstudio/plugin/__init__.py index 0c75e92..6ba23a0 100644 --- a/src/lmstudio/plugin/__init__.py +++ b/src/lmstudio/plugin/__init__.py @@ -17,18 +17,19 @@ # * refactor to allow hook invocation error handling to be common across hook invocation tasks # * [DONE] gracefully handle app termination while a dev plugin is still running # * [DONE] gracefully handle using Ctrl-C to terminate a running dev plugin +# * add async tool handling support to SDK (as part of adding .act() to the async API) # # Controller APIs (may be limited to relevant hook controllers) # # * [DONE] status blocks (both simple "done" blocks, and blocks with in-place updates) # * citation blocks # * debug info blocks -# * tool status reporting +# * [DONE] tool status reporting # * full content block display # * chat history retrieval # * model handle retrieval # * UI block sender name configuration -# * interactive tool call request confirmation +# * [not necessary (handled directly by UI)] interactive tool call request confirmation # # Prompt preprocessing hook # * [DONE] emit a status notification block when the demo plugin fires @@ -48,13 +49,14 @@ # * catch hook invocation failures and send "Error" responses # # Tools provider hook -# * add example plugin or plugins for this (probably both dice rolling and Wikipedia lookup) -# * define the channel, hook invocation task and hook invocation controller for this hook -# * main request initiation message is "InitSession" (with Initialized/Failed responses) -# * handle "Abort" requests from server (including sending "Aborted" responses) -# * handle "CallTool" requests from server (including sending "CallComplete"/"CallError" response) -# * handle "DiscardSession" requests from server -# * add controller API for tool call status and warning reporting +# * [DONE] add example synchronous tool plugin (dice rolling) +# * add example asynchronous tool plugin (Wikipedia lookup) (note: requires async tool support in SDK) +# * [DONE] define the channel, hook invocation task and hook invocation controller for this hook +# * [DONE] main request initiation message is "InitSession" (with Initialized/Failed responses) +# * [DONE] handle "AbortToolCall" requests from server +# * [DONE] handle "CallTool" requests from server (including sending "CallComplete"/"CallError" response) +# * [DONE] handle "DiscardSession" requests from server +# * [DONE] add controller API for tool call status and warning reporting # # Plugin config field definitions # * define approach for specifying plugin config field constraints and style options (e.g. numeric sliders) diff --git a/src/lmstudio/plugin/hooks/__init__.py b/src/lmstudio/plugin/hooks/__init__.py index e2cae58..74fe37b 100644 --- a/src/lmstudio/plugin/hooks/__init__.py +++ b/src/lmstudio/plugin/hooks/__init__.py @@ -9,7 +9,10 @@ # Available as lmstudio.plugin.* __all__ = [ + "AsyncToolCallContext", "PromptPreprocessorController", "TokenGeneratorController", + "ToolCallContext", "ToolsProviderController", + "get_tool_call_context", ] diff --git a/src/lmstudio/plugin/hooks/common.py b/src/lmstudio/plugin/hooks/common.py index 80289d0..dd588c6 100644 --- a/src/lmstudio/plugin/hooks/common.py +++ b/src/lmstudio/plugin/hooks/common.py @@ -9,7 +9,6 @@ from typing import ( Any, AsyncIterator, - Awaitable, Callable, Generic, TypeAlias, @@ -19,11 +18,10 @@ from anyio import move_on_after from ...async_api import _AsyncSession -from ...schemas import DictObject from ..._sdk_models import ( # TODO: Define aliases at schema generation time PluginsChannelSetGeneratorToClientPacketGenerate as TokenGenerationRequest, - PluginsChannelSetToolsProviderToClientPacketInitSession as ProvideToolsInitSession, + ProvideToolsInitSession, PromptPreprocessingRequest, SerializedKVConfigSettings, StatusStepStatus, @@ -51,7 +49,6 @@ class _AsyncSessionPlugins(_AsyncSession): TPluginConfigSchema = TypeVar("TPluginConfigSchema", bound=BaseConfigSchema) TGlobalConfigSchema = TypeVar("TGlobalConfigSchema", bound=BaseConfigSchema) TConfig = TypeVar("TConfig", bound=BaseConfigSchema) -SendMessageCallback: TypeAlias = Callable[[DictObject], Awaitable[Any]] class ServerRequestError(RuntimeError): diff --git a/src/lmstudio/plugin/hooks/prompt_preprocessor.py b/src/lmstudio/plugin/hooks/prompt_preprocessor.py index a1a3175..a04f468 100644 --- a/src/lmstudio/plugin/hooks/prompt_preprocessor.py +++ b/src/lmstudio/plugin/hooks/prompt_preprocessor.py @@ -30,6 +30,7 @@ ChannelEndpoint, ChannelFinishedEvent, ChannelRxEvent, + SendMessageAsync, load_struct, ) from ..._sdk_models import ( @@ -49,7 +50,6 @@ from .common import ( _AsyncSessionPlugins, HookController, - SendMessageCallback, ServerRequestError, StatusBlockController, TPluginConfigSchema, @@ -250,7 +250,7 @@ async def process_requests( break async def _abort_hook_invocation( - self, task_id: str, send_response: SendMessageCallback + self, task_id: str, send_json: SendMessageAsync ) -> None: """Abort the specified hook invocation (if it is still running).""" abort_event = self._abort_events.get(task_id, None) @@ -260,7 +260,7 @@ async def _abort_hook_invocation( type="aborted", taskId=task_id, ) - await send_response(response) + await send_json(response) async def _cancel_on_event( self, tg: TaskGroup, event: asyncio.Event, message: str @@ -301,7 +301,7 @@ async def _registered_hook_invocation( async def _invoke_hook( self, ctl: PromptPreprocessorController[TPluginConfigSchema, TGlobalConfigSchema], - send_response: SendMessageCallback, + send_json: SendMessageAsync, ) -> None: logger = self._logger task_id = ctl.task_id @@ -368,7 +368,7 @@ async def _invoke_hook( taskId=task_id, processed=response_dict, ) - await send_response(channel_message) + await send_json(channel_message) async def run_prompt_preprocessor( diff --git a/src/lmstudio/plugin/hooks/tools_provider.py b/src/lmstudio/plugin/hooks/tools_provider.py index 67d2e3b..2fa08c4 100644 --- a/src/lmstudio/plugin/hooks/tools_provider.py +++ b/src/lmstudio/plugin/hooks/tools_provider.py @@ -1,42 +1,560 @@ """Invoking and supporting tools provider hook implementations.""" -from typing import Any, Awaitable, Callable, Iterable +import asyncio +from contextvars import ContextVar +from dataclasses import dataclass +from traceback import format_tb +from typing import Any, Awaitable, Callable, Generic, Iterable, TypeAlias, TypeVar +from typing_extensions import ( + # Native in 3.11+ + assert_never, +) + +from anyio import create_task_group +from anyio.abc import TaskGroup +from msgspec import convert, to_builtins + +from ..._logging import new_logger, LogEventContext +from ...schemas import DictObject, EmptyDict from ...json_api import ( + ChannelCommonRxEvent, + ChannelEndpoint, + ChannelFinishedEvent, + ChannelRxEvent, + ChatResponseEndpoint, + ClientToolMap, + SendMessageAsync, ToolDefinition, ) - from ..._sdk_models import ( - # TODO: Define aliases at schema generation time - PluginsChannelSetToolsProviderToClientPacketInitSession as ProvideToolsInitSession, + PluginToolCallComplete, + PluginToolCallCompleteDict, + PluginToolCallErrorDict, + PluginToolCallStatusDict, + PluginToolCallWarnDict, + ProvideToolsInitSession, + ProvideToolsAbortCall, + ProvideToolsCallTool, + ProvideToolsInitFailedDict, + ProvideToolsInitializedDict, + SerializedLMSExtendedErrorDict, ) from ..config_schemas import BaseConfigSchema +from ..sdk_api import LMStudioPluginRuntimeError from .common import ( _AsyncSessionPlugins, HookController, + ServerRequestError, TPluginConfigSchema, TGlobalConfigSchema, ) # Available as lmstudio.plugin.hooks.* __all__ = [ + "AsyncToolCallContext", + "ToolCallContext", "ToolsProviderController", "ToolsProviderHook", "run_tools_provider", + "get_tool_call_context", ] +class ProvideToolsDiscardSessionEvent(ChannelRxEvent[str]): + pass + + +class ProvideToolsInitSessionEvent(ChannelRxEvent[ProvideToolsInitSession]): + pass + + +class ProvideToolsCallToolEvent(ChannelRxEvent[ProvideToolsCallTool]): + pass + + +class ProvideToolsAbortCallEvent(ChannelRxEvent[ProvideToolsAbortCall]): + pass + + +PromptPreprocessingRxEvent: TypeAlias = ( + ProvideToolsDiscardSessionEvent + | ProvideToolsInitSessionEvent + | ProvideToolsCallToolEvent + | ProvideToolsAbortCallEvent + | ChannelCommonRxEvent +) + + +class ToolsProviderEndpoint( + ChannelEndpoint[tuple[str, str], PromptPreprocessingRxEvent, EmptyDict] +): + """API channel endpoint to accept prompt preprocessing requests.""" + + _API_ENDPOINT = "setToolsProvider" + _NOTICE_PREFIX = "Providing tools" + + def __init__(self) -> None: + super().__init__({}) + + def iter_message_events( + self, contents: DictObject | None + ) -> Iterable[PromptPreprocessingRxEvent]: + match contents: + case None: + # Server can only terminate the link by closing the websocket + pass + case {"type": "discardSession", "sessionId": str(session_id)}: + yield ProvideToolsDiscardSessionEvent(session_id) + case {"type": "initSession"} as init_session_dict: + init_session = ProvideToolsInitSession._from_any_api_dict( + init_session_dict + ) + yield ProvideToolsInitSessionEvent(init_session) + case {"type": "callTool"} as tool_call_dict: + tool_call = ProvideToolsCallTool._from_any_api_dict(tool_call_dict) + yield ProvideToolsCallToolEvent(tool_call) + case {"type": "abortToolCall"} as abort_tool_call_dict: + abort_tool_call = ProvideToolsAbortCall._from_any_api_dict( + abort_tool_call_dict + ) + yield ProvideToolsAbortCallEvent(abort_tool_call) + case unmatched: + self.report_unknown_message(unmatched) + + def handle_rx_event(self, event: PromptPreprocessingRxEvent) -> None: + match event: + case ProvideToolsDiscardSessionEvent(session_id): + self._logger.debug(f"Terminating {session_id}", session_id=session_id) + case ProvideToolsInitSessionEvent(request): + self._logger.debug( + "Received tools session request", session_id=request.session_id + ) + case ProvideToolsCallToolEvent(request): + self._logger.debug( + "Received tool call request", + session_id=request.session_id, + call_id=request.call_id, + ) + case ProvideToolsAbortCallEvent(request): + self._logger.debug( + "Received tool abort request", + session_id=request.session_id, + call_id=request.call_id, + ) + case ChannelFinishedEvent(_): + pass + case _: + assert_never(event) + + class ToolsProviderController( HookController[ProvideToolsInitSession, TPluginConfigSchema, TGlobalConfigSchema] ): """API access for tools provider hook implementations.""" + def __init__( + self, + session: _AsyncSessionPlugins, + request: ProvideToolsInitSession, + plugin_config_schema: type[TPluginConfigSchema], + global_config_schema: type[TGlobalConfigSchema], + ) -> None: + """Initialize prompt preprocessor hook controller.""" + super().__init__(session, request, plugin_config_schema, global_config_schema) + self.session_id = request.session_id + ToolsProviderHook = Callable[ [ToolsProviderController[Any, Any]], Awaitable[Iterable[ToolDefinition]] ] +T = TypeVar("T") + + +class _BaseToolCallContext: + """API access to update a tool call UI status block in-place.""" + + def __init__( + self, + session_id: str, + call_id: str, + send_json: SendMessageAsync, + ) -> None: + """Initialize status block controller.""" + self.session_id = session_id + self.call_id = call_id + self._send_json = send_json + + def _make_status(self, message: str) -> PluginToolCallStatusDict: + return PluginToolCallStatusDict( + type="toolCallStatus", + sessionId=self.session_id, + callId=self.call_id, + statusText=message, + ) + + def _make_warning(self, message: str) -> PluginToolCallWarnDict: + return PluginToolCallWarnDict( + type="toolCallWarn", + sessionId=self.session_id, + callId=self.call_id, + warnText=message, + ) + + +class AsyncToolCallContext(_BaseToolCallContext): + """Asynchronous API access to update a tool call UI status block in-place.""" + + async def notify_status(self, message: str) -> None: + """Report tool progress update in the task status block.""" + await self._send_json(self._make_status(message)) + + async def notify_warning(self, message: str) -> None: + """Report tool warning in the task status block.""" + await self._send_json(self._make_warning(message)) + + +class ToolCallContext(_BaseToolCallContext): + """Synchronous API access to update a tool call UI status block in-place.""" + + def __init__( + self, + session_id: str, + call_id: str, + send_json: SendMessageAsync, + ) -> None: + """Initialize synchronous status block controller.""" + super().__init__(session_id, call_id, send_json) + # Sync call context is created in the plugin's async comms loop + self._loop = asyncio.get_running_loop() + + def _send_json_sync(self, data: DictObject) -> None: + future = asyncio.run_coroutine_threadsafe(self._send_json(data), self._loop) + future.result() + + def notify_status(self, message: str) -> None: + """Report tool progress update in the task status block.""" + self._send_json_sync(self._make_status(message)) + + def notify_warning(self, message: str) -> None: + """Report tool warning in the task status block.""" + self._send_json_sync(self._make_warning(message)) + + +_LMS_TOOL_CALL_SYNC: ContextVar[ToolCallContext] = ContextVar("_LMS_TOOL_CALL_SYNC") +_LMS_TOOL_CALL_ASYNC: ContextVar[AsyncToolCallContext] = ContextVar( + "_LMS_TOOL_CALL_ASYNC" +) + + +def get_tool_call_context() -> ToolCallContext: + """Get synchronous tool call context.""" + if _LMS_TOOL_CALL_ASYNC.get(None) is not None: + msg = "Use 'get_tool_call_context_async()' in asynchronous tool definition" + raise LMStudioPluginRuntimeError(msg) + return _LMS_TOOL_CALL_SYNC.get() + + +def get_tool_call_context_async() -> AsyncToolCallContext: + """Get asynchronous tool call context.""" + if _LMS_TOOL_CALL_SYNC.get(None) is not None: + msg = "Use 'get_tool_call_context()' in synchronous tool definition" + raise LMStudioPluginRuntimeError(msg) + return _LMS_TOOL_CALL_ASYNC.get() + + +class ToolCallHandler: + def __init__( + self, + plugin_name: str, + session_id: str, + provided_tools: ClientToolMap, + log_context: LogEventContext, + ) -> None: + self.plugin_name = plugin_name + self.session_id = session_id + self._provided_tools = provided_tools + self._queue: asyncio.Queue[ProvideToolsCallTool | None] = asyncio.Queue() + self._abort_events: dict[str, asyncio.Event] = {} + self._logger = logger = new_logger(__name__) + logger.update_context(log_context, session_id=session_id) + + async def _cancel_on_event( + self, tg: TaskGroup, event: asyncio.Event, message: str + ) -> None: + await event.wait() + self._logger.info(message) + tg.cancel_scope.cancel() + + async def start_tool_call(self, tool_call: ProvideToolsCallTool) -> None: + await self._queue.put(tool_call) + + # TODO: Reduce code duplication with the ChatResponseEndpoint definition + def _call_sync_tool( + self, + call_id: str, + sync_tool: Callable[..., Any], + kwds: DictObject, + send_json: SendMessageAsync, + ) -> Awaitable[PluginToolCallCompleteDict]: + # Ensure synchronous tools can't block the plugin's async comms thread + call_context = ToolCallContext(self.session_id, call_id, send_json) + + def _call_requested_tool() -> PluginToolCallCompleteDict: + assert _LMS_TOOL_CALL_ASYNC.get(None) is None + _LMS_TOOL_CALL_SYNC.set(call_context) + call_result = sync_tool(**kwds) + return PluginToolCallComplete( + session_id=self.session_id, + call_id=call_id, + result=call_result, + ).to_dict() + + return asyncio.to_thread(_call_requested_tool) + + async def _call_tool_implementation( + self, tool_call: ProvideToolsCallTool, send_json: SendMessageAsync + ) -> PluginToolCallCompleteDict: + # Find tool implementation + tool_name = tool_call.tool_name + tool_details = self._provided_tools.get(tool_name, None) + if tool_details is None: + raise ServerRequestError( + f"Plugin does not provide a tool named {tool_name!r}." + ) + # Validate parameters against their specification + params_struct, tool_impl = tool_details + raw_kwds = tool_call.parameters + try: + parsed_kwds = convert(raw_kwds, params_struct) + except Exception as exc: + err_msg = f"Failed to parse arguments for tool {tool_name}: {exc}" + raise ServerRequestError(err_msg) + kwds = to_builtins(parsed_kwds) + # TODO: Also support async tool definitions and invocation + return await self._call_sync_tool(tool_call.call_id, tool_impl, kwds, send_json) + + # TODO: Reduce code duplication with the ChatResponseEndpoint definition + async def _call_tool( + self, tool_call: ProvideToolsCallTool, send_json: SendMessageAsync + ) -> None: + call_id = tool_call.call_id + abort_events = self._abort_events + if call_id in abort_events: + err_msg = f"Tool call already in progress for {call_id} in session {self.session_id}" + raise ServerRequestError(err_msg) + abort_events[call_id] = abort_event = asyncio.Event() + logger = new_logger(__name__) + logger.update_context(self._logger.event_context, call_id=call_id) + try: + async with create_task_group() as tg: + tg.start_soon( + self._cancel_on_event, + tg, + abort_event, + f"Aborting tool_call {call_id}", + ) + logger.info(f"Running tool call {call_id}") + # TODO: Set up context variable for status/warning message sending + tool_call_response: PluginToolCallCompleteDict | PluginToolCallErrorDict + try: + tool_call_response = await self._call_tool_implementation( + tool_call, send_json + ) + except Exception as exc: + # Only catch regular exceptions, + # allowing the server to time out for client process termination events + err_msg = "Error calling tool implementation" + logger.error(err_msg, exc_info=True, exc=repr(exc)) + # TODO: Determine if it's worth sending the stack trace to the server + tool_name = tool_call.tool_name + ui_cause = f"{type(exc).__name__}: {exc}" + # Tool calling UI only displays the title, so also embed the cause directly + error_title = f"Error calling tool {tool_name} in plugin {self.plugin_name!r} ({ui_cause})" + error_details = SerializedLMSExtendedErrorDict( + title=error_title, + rootTitle=error_title, + cause=ui_cause, + stack="\n".join(format_tb(exc.__traceback__)), + ) + tool_call_response = PluginToolCallErrorDict( + type="toolCallError", + sessionId=self.session_id, + callId=call_id, + error=error_details, + ) + await send_json(tool_call_response) + tg.cancel_scope.cancel() + finally: + self._abort_events.pop(call_id, None) + + def abort_tool_call(self, call_id: str) -> None: + abort_event = self._abort_events.get(call_id) + if abort_event is not None: + abort_event.set() + # Any server notification will be sent from the tool calling task + + def _abort_all_calls(self) -> None: + for abort_event in self._abort_events.values(): + abort_event.set() + # Any server notifications will be sent from the tool calling tasks + + async def discard_session(self) -> None: + await self._queue.put(None) + + async def receive_tool_calls(self, send_message: SendMessageAsync) -> None: + session_queue = self._queue + try: + while True: + tool_call = await session_queue.get() + if tool_call is None: + break + await self._call_tool(tool_call, send_message) + finally: + self._abort_all_calls() + + +# TODO: Define a common "PluginHookHandler" base class +@dataclass() +class ToolsProvider(Generic[TPluginConfigSchema, TGlobalConfigSchema]): + """Handle accepting tools provider session requests.""" + + plugin_name: str + hook_impl: ToolsProviderHook + plugin_config_schema: type[TPluginConfigSchema] + global_config_schema: type[TGlobalConfigSchema] + + def __post_init__(self) -> None: + self._logger = logger = new_logger(__name__) + logger.update_context(plugin_name=self.plugin_name) + self._call_handlers: dict[str, ToolCallHandler] = {} + + async def process_requests( + self, ws_session: _AsyncSessionPlugins, notify_ready: Callable[[], Any] + ) -> None: + """Create plugin channel and wait for server requests.""" + logger = self._logger + endpoint = ToolsProviderEndpoint() + # Async API expects timeouts to be handled via task groups, + # so there's no default timeout to override when creating the channel + async with ws_session._create_channel(endpoint) as channel: + notify_ready() + logger.info("Opened channel to receive tools session requests...") + send_message = channel.send_message + async with create_task_group() as tg: + logger.debug("Waiting for tools session requests...") + async for contents in channel.rx_stream(): + logger.debug(f"Handling tools provider channel message: {contents}") + for event in endpoint.iter_message_events(contents): + logger.debug("Handling tools provider channel event") + endpoint.handle_rx_event(event) + match event: + case ProvideToolsDiscardSessionEvent(): + await self._discard_session(event.arg) + case ProvideToolsInitSessionEvent(): + logger.debug("Running tools listing hook") + ctl = ToolsProviderController( + ws_session, + event.arg, + self.plugin_config_schema, + self.global_config_schema, + ) + tg.start_soon(self._invoke_hook, ctl, send_message) + case ProvideToolsCallToolEvent(_): + tg.start_soon(self._call_tool, event.arg) + case ProvideToolsAbortCallEvent(_): + self._abort_tool_call(event.arg) + case ChannelFinishedEvent(_): + pass + case _: + assert_never(event) + if endpoint.is_finished: + break + + async def _discard_session(self, session_id: str) -> None: + """Abort the specified tools session (if it is still running).""" + call_handler = self._call_handlers.get(session_id, None) + if call_handler is not None: + await call_handler.discard_session() + + async def _call_tool(self, tool_call_request: ProvideToolsCallTool) -> None: + """Call the specified tool.""" + call_handler = self._call_handlers.get(tool_call_request.session_id, None) + if call_handler is not None: + await call_handler.start_tool_call(tool_call_request) + + def _abort_tool_call(self, abort_request: ProvideToolsAbortCall) -> None: + """Abort the specified tool call (if it is still running).""" + call_handler = self._call_handlers.get(abort_request.session_id, None) + if call_handler is not None: + call_handler.abort_tool_call(abort_request.call_id) + + async def _run_tools_session( + self, + session_id: str, + provided_tools: ClientToolMap, + send_json: SendMessageAsync, + ) -> None: + logger = self._logger + call_handlers = self._call_handlers + if session_id in call_handlers: + err_msg = f"Tools session already in progress for {session_id}" + raise ServerRequestError(err_msg) + call_handler = call_handlers[session_id] = ToolCallHandler( + self.plugin_name, session_id, provided_tools, self._logger.event_context + ) + try: + logger.info(f"Running tools session {session_id}") + await call_handler.receive_tool_calls(send_json) + finally: + call_handlers.pop(session_id, None) + logger.info(f"Terminated tools session {session_id}") + + async def _invoke_hook( + self, + ctl: ToolsProviderController[TPluginConfigSchema, TGlobalConfigSchema], + send_json: SendMessageAsync, + ) -> None: + logger = self._logger + session_id = ctl.session_id + error_details: SerializedLMSExtendedErrorDict | None = None + try: + plugin_tools_list = await self.hook_impl(ctl) + llm_tools_array, provided_tools = ChatResponseEndpoint.parse_tools( + plugin_tools_list + ) + llm_tools_list = llm_tools_array.to_dict()["tools"] + assert llm_tools_list is not None # Ensured by the parse_tools method + except Exception as exc: + err_msg = "Error calling tools listing hook" + logger.error(err_msg, exc_info=True, exc=repr(exc)) + # TODO: Determine if it's worth sending the stack trace to the server + error_title = f"Tools listing error in plugin {self.plugin_name!r}" + ui_cause = f"{err_msg}\n({type(exc).__name__}: {exc})" + error_details = SerializedLMSExtendedErrorDict( + title=error_title, + rootTitle=error_title, + cause=ui_cause, + stack="\n".join(format_tb(exc.__traceback__)), + ) + error_message = ProvideToolsInitFailedDict( + type="sessionInitializationFailed", + sessionId=session_id, + error=error_details, + ) + await send_json(error_message) + return + init_message = ProvideToolsInitializedDict( + type="sessionInitialized", + sessionId=session_id, + toolDefinitions=llm_tools_list, + ) + await send_json(init_message) + # Wait for further messages (until the session is discarded) + await self._run_tools_session(session_id, provided_tools, send_json) + async def run_tools_provider( plugin_name: str, @@ -47,4 +565,7 @@ async def run_tools_provider( notify_ready: Callable[[], Any], ) -> None: """Accept tools provider session requests.""" - raise NotImplementedError + tools_provider = ToolsProvider( + plugin_name, hook_impl, plugin_config_schema, global_config_schema + ) + await tools_provider.process_requests(session, notify_ready)