Skip to content

Commit 70331ff

Browse files
authored
Add tools provider plugin example (#131)
1 parent 926be41 commit 70331ff

File tree

15 files changed

+751
-107
lines changed

15 files changed

+751
-107
lines changed
Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1-
# `lmstudio/pydice`
1+
# `lmstudio/dice-tool`
22

3-
TODO: Example Python tools provider plugin
3+
Python tools provider plugin example
4+
5+
Running a local dev instance:
6+
7+
pdm run python -m lmstudio.plugin --dev examples/plugins/dice-tool
Lines changed: 98 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,100 @@
11
"""Example plugin that provide dice rolling tools."""
22

3-
# Not yet implemented, currently used to check plugins with no hooks defined
3+
import time
4+
5+
from random import randint
6+
from typing import TypedDict
7+
8+
from lmstudio.plugin import (
9+
BaseConfigSchema,
10+
ToolsProviderController,
11+
config_field,
12+
get_tool_call_context,
13+
)
14+
from lmstudio import ToolDefinition
15+
16+
17+
# Assigning ConfigSchema = SomeOtherSchemaClass also works
18+
class ConfigSchema(BaseConfigSchema):
19+
"""The name 'ConfigSchema' implicitly registers this as the per-chat plugin config schema."""
20+
21+
enable_inplace_status_demo: bool = config_field(
22+
label="Enable in-place status demo",
23+
hint="The plugin will run an in-place task status updating demo when invoked",
24+
default=True,
25+
)
26+
inplace_status_duration: float = config_field(
27+
label="In-place status total duration (s)",
28+
hint="The number of seconds to spend displaying the in-place task status update",
29+
default=5.0,
30+
)
31+
restrict_die_types: bool = config_field(
32+
label="Require polyhedral dice",
33+
hint="Require conventional polyhedral dice (4, 6, 8, 10, 12, 20, or 100 sides)",
34+
default=True,
35+
)
36+
37+
38+
# This example plugin has no global configuration settings defined.
39+
# For a type hinted plugin with no configuration settings of a given type,
40+
# BaseConfigSchema may be used in the hook controller type hint.
41+
# Defining a config schema subclass with no fields is also a valid approach.
42+
43+
44+
# When reporting multiple values from a tool call, dictionaries
45+
# are the preferred format, as the field names allow the LLM
46+
# to potentially interpret the result correctly.
47+
# Unlike parameter details, no return value schema is sent to the server,
48+
# so relevant information needs to be part of the JSON serialisation.
49+
class DiceRollResult(TypedDict):
50+
"""The result of a dice rolling request."""
51+
52+
rolls: list[int]
53+
total: int
54+
55+
56+
# Assigning list_provided_tools = some_other_callable also works
57+
async def list_provided_tools(
58+
ctl: ToolsProviderController[ConfigSchema, BaseConfigSchema],
59+
) -> list[ToolDefinition]:
60+
"""Naming the function 'list_provided_tools' implicitly registers it."""
61+
config = ctl.plugin_config
62+
if config.enable_inplace_status_demo:
63+
inplace_status_duration = config.inplace_status_duration
64+
else:
65+
inplace_status_duration = 0
66+
if config.restrict_die_types:
67+
permitted_sides = {4, 6, 8, 10, 12, 20, 100}
68+
else:
69+
permitted_sides = None
70+
71+
# Tool definitions may use any of the formats described in
72+
# https://lmstudio.ai/docs/python/agent/tools
73+
def roll_dice(count: int, sides: int) -> DiceRollResult:
74+
"""Roll a specified number of dice with specified number of faces.
75+
76+
For example, to roll 2 six-sided dice (i.e. 2d6), you should call the function
77+
`roll_dice` with the parameters { count: 2, sides: 6 }.
78+
"""
79+
if inplace_status_duration:
80+
tcc = get_tool_call_context()
81+
status_updates = (
82+
(tcc.notify_status, "Display status update in UI."),
83+
(tcc.notify_warning, "Display task warning in UI."),
84+
(tcc.notify_status, "Post-warning status update in UI."),
85+
)
86+
status_duration = inplace_status_duration / len(status_updates)
87+
for send_notification, status_text in status_updates:
88+
time.sleep(status_duration)
89+
send_notification(status_text)
90+
if permitted_sides and sides not in permitted_sides:
91+
expected_die_types = ",".join(map(str, sorted(permitted_sides)))
92+
err_msg = f"{sides} is not a conventional polyhedral die type ({expected_die_types})"
93+
raise ValueError(err_msg)
94+
rolls = [randint(1, sides) for _ in range(count)]
95+
return DiceRollResult(rolls=rolls, total=sum(rolls))
96+
97+
return [roll_dice]
98+
99+
100+
print(f"{__name__} initialized from {__file__}")
Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
# `lmstudio/pyprompt`
1+
# `lmstudio/prompt-prefix`
22

33
Python prompt preprocessing plugin example
44

5-
Note: there's no `python` runner in LM Studio yet, so use
6-
`python -m lmstudio.plugin --dev path/to/plugin` to run a dev instance
5+
Running a local dev instance:
6+
7+
pdm run python -m lmstudio.plugin --dev examples/plugins/prompt-prefix

examples/plugins/prompt-prefix/src/plugin.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,9 @@ async def preprocess_prompt(
5656
status_updates
5757
)
5858
async with status_block.notify_aborted("Task genuinely cancelled."):
59-
for notification, status_text in status_updates:
59+
for send_notification, status_text in status_updates:
6060
await asyncio.sleep(status_duration)
61-
await notification(status_text)
61+
await send_notification(status_text)
6262

6363
modified_message = message.to_dict()
6464
# Add a prefix to all user messages

sdk-schema/sync-sdk-schema.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -363,7 +363,7 @@ def _infer_schema_unions() -> None:
363363
"LlmChannelPredictCreationParameterDict": "PredictionChannelRequestDict",
364364
"RepositoryChannelDownloadModelCreationParameter": "DownloadModelChannelRequest",
365365
"RepositoryChannelDownloadModelCreationParameterDict": "DownloadModelChannelRequestDict",
366-
# Prettier plugin channel message names
366+
# Prettier prompt preprocessing plugin channel message names
367367
"PluginsChannelSetPromptPreprocessorToClientPacketPreprocess": "PromptPreprocessingRequest",
368368
"PluginsChannelSetPromptPreprocessorToClientPacketPreprocessDict": "PromptPreprocessingRequestDict",
369369
"PluginsChannelSetPromptPreprocessorToServerPacketAborted": "PromptPreprocessingAborted",
@@ -372,6 +372,25 @@ def _infer_schema_unions() -> None:
372372
"PluginsChannelSetPromptPreprocessorToServerPacketCompleteDict": "PromptPreprocessingCompleteDict",
373373
"PluginsChannelSetPromptPreprocessorToServerPacketError": "PromptPreprocessingError",
374374
"PluginsChannelSetPromptPreprocessorToServerPacketErrorDict": "PromptPreprocessingErrorDict",
375+
# Prettier tools provider plugin channel message names
376+
"PluginsChannelSetToolsProviderToClientPacketInitSession": "ProvideToolsInitSession",
377+
"PluginsChannelSetToolsProviderToClientPacketInitSessionDict": "ProvideToolsInitSessionDict",
378+
"PluginsChannelSetToolsProviderToClientPacketAbortToolCall": "ProvideToolsAbortCall",
379+
"PluginsChannelSetToolsProviderToClientPacketAbortToolCallDict": "ProvideToolsAbortCallDict",
380+
"PluginsChannelSetToolsProviderToClientPacketCallTool": "ProvideToolsCallTool",
381+
"PluginsChannelSetToolsProviderToClientPacketCallToolDict": "ProvideToolsCallToolDict",
382+
"PluginsChannelSetToolsProviderToServerPacketSessionInitializationFailed": "ProvideToolsInitFailed",
383+
"PluginsChannelSetToolsProviderToServerPacketSessionInitializationFailedDict": "ProvideToolsInitFailedDict",
384+
"PluginsChannelSetToolsProviderToServerPacketSessionInitialized": "ProvideToolsInitialized",
385+
"PluginsChannelSetToolsProviderToServerPacketSessionInitializedDict": "ProvideToolsInitializedDict",
386+
"PluginsChannelSetToolsProviderToServerPacketToolCallComplete": "PluginToolCallComplete",
387+
"PluginsChannelSetToolsProviderToServerPacketToolCallCompleteDict": "PluginToolCallCompleteDict",
388+
"PluginsChannelSetToolsProviderToServerPacketToolCallError": "PluginToolCallError",
389+
"PluginsChannelSetToolsProviderToServerPacketToolCallErrorDict": "PluginToolCallErrorDict",
390+
"PluginsChannelSetToolsProviderToServerPacketToolCallStatus": "PluginToolCallStatus",
391+
"PluginsChannelSetToolsProviderToServerPacketToolCallStatusDict": "PluginToolCallStatusDict",
392+
"PluginsChannelSetToolsProviderToServerPacketToolCallWarn": "PluginToolCallWarn",
393+
"PluginsChannelSetToolsProviderToServerPacketToolCallWarnDict": "PluginToolCallWarnDict",
375394
# Prettier config handling type names
376395
"LlmRpcGetLoadConfigReturns": "SerializedKVConfigSettings",
377396
"LlmRpcGetLoadConfigReturnsDict": "SerializedKVConfigSettingsDict",

0 commit comments

Comments
 (0)