Skip to content

Commit d455734

Browse files
authored
Add async tool plugin example (#137)
1 parent 6660e52 commit d455734

File tree

9 files changed

+201
-4
lines changed

9 files changed

+201
-4
lines changed

examples/plugins/dice-tool/README.md

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
# `lmstudio/dice-tool`
22

3-
Python tools provider plugin example
3+
Python tools provider plugin example showcasing synchronous tool definitions
4+
by adding support for random number generation by rolling simulated dice.
5+
6+
Also includes a tool call status update demo.
47

58
Running a local dev instance:
69

examples/plugins/dice-tool/src/plugin.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""Example plugin that provide dice rolling tools."""
1+
"""Example plugin that provides dice rolling tools."""
22

33
import time
44

examples/plugins/prompt-prefix/README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
Python prompt preprocessing plugin example
44

5+
Also includes an in-place update demo for status block notifications.
6+
57
Running a local dev instance:
68

79
pdm run python -m lmstudio.plugin --dev examples/plugins/prompt-prefix
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
# `lmstudio/wikipedia`
2+
3+
Python tools provider plugin example showcasing asynchronous tool definitions
4+
by adding support for searching Wikipedia and retrieving specific pages.
5+
6+
Running a local dev instance:
7+
8+
pdm run python -m lmstudio.plugin --dev examples/plugins/wikipedia
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
{
2+
"type": "plugin",
3+
"runner": "python",
4+
"owner": "lmstudio",
5+
"name": "py-wikipedia",
6+
"revision": 1
7+
}
Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
"""Example plugin that provides tools for querying Wikipedia."""
2+
3+
from typing import Any, TypeAlias, TypedDict
4+
5+
from lmstudio.plugin import (
6+
BaseConfigSchema,
7+
ToolsProviderController,
8+
config_field,
9+
get_tool_call_context_async,
10+
)
11+
from lmstudio import ToolDefinition
12+
13+
# Python plugins don't support dependency declarations yet,
14+
# but the lmstudio SDK is always available and uses httpx
15+
import httpx
16+
17+
18+
# Assigning ConfigSchema = SomeOtherSchemaClass also works
19+
class ConfigSchema(BaseConfigSchema):
20+
"""The name 'ConfigSchema' implicitly registers this as the per-chat plugin config schema."""
21+
22+
wikipedia_base_url: str = config_field(
23+
label="Wikipedia Base URL",
24+
hint="The base URL for the Wikipedia API.",
25+
default="https://en.wikipedia.org",
26+
)
27+
28+
29+
# This example plugin has no global configuration settings defined.
30+
# For a type hinted plugin with no configuration settings of a given type,
31+
# BaseConfigSchema may be used in the hook controller type hint.
32+
# Defining a config schema subclass with no fields is also a valid approach.
33+
34+
35+
# When reporting multiple values from a tool call, dictionaries
36+
# are the preferred format, as the field names allow the LLM
37+
# to potentially interpret the result correctly.
38+
# Unlike parameter details, no return value schema is sent to the server,
39+
# so relevant information needs to be part of the JSON serialisation.
40+
class WikipediaSearchEntry(TypedDict):
41+
"""A single entry in a Wikipedia search result."""
42+
43+
title: str
44+
summary: str
45+
page_id: int
46+
47+
48+
class WikipediaSearchResult(TypedDict):
49+
"""The collected results of a Wikipedia search."""
50+
51+
results: list[WikipediaSearchEntry]
52+
hint: str
53+
54+
55+
class WikipediaPage(TypedDict):
56+
"""Details of a retrieved wikipedia page."""
57+
58+
title: str
59+
content: str
60+
61+
62+
ErrorResult: TypeAlias = str | dict[str, Any]
63+
64+
PAGE_RETRIEVAL_HINT = """\
65+
If any of the search results are relevant, ALWAYS use `get_wikipedia_page` to retrieve
66+
the full content of the page using the `page_id`. The `summary` is just a brief
67+
snippet and can have missing information. If not, try to search again using a more
68+
canonical term, or search for a different term that is more likely to contain the relevant
69+
information.
70+
"""
71+
72+
73+
def _strip_search_markup(text: str) -> str:
74+
"""Remove search markup inserted by Wikipedia API."""
75+
return text.replace('<span class="searchmatch">', "").replace("</span>", "")
76+
77+
78+
# Assigning list_provided_tools = some_other_callable also works
79+
async def list_provided_tools(
80+
ctl: ToolsProviderController[ConfigSchema, BaseConfigSchema],
81+
) -> list[ToolDefinition]:
82+
"""Naming the function 'list_provided_tools' implicitly registers it."""
83+
base_url = httpx.URL(ctl.plugin_config.wikipedia_base_url)
84+
api_url = base_url.join("/w/api.php")
85+
86+
async def _query_wikipedia(
87+
query_type: str, query_params: dict[str, Any]
88+
) -> tuple[Any, ErrorResult | None]:
89+
tcc = get_tool_call_context_async()
90+
await tcc.notify_status(f"Fetching {query_type} from Wikipedia...")
91+
async with httpx.AsyncClient() as web_client:
92+
result = await web_client.get(api_url, params=query_params)
93+
if result.status_code != httpx.codes.OK:
94+
warning_message = f"Failed to fetch {query_type} from Wikipedia (status: {result.status_code})"
95+
await tcc.notify_warning(warning_message)
96+
return None, f"Error: {warning_message}"
97+
data = result.json()
98+
err_data = data.get("error", None)
99+
if err_data is not None:
100+
warning_message = f"Wikipedia API returned an error: ${err_data['info']}"
101+
await tcc.notify_warning(warning_message)
102+
return None, err_data
103+
return data, None
104+
105+
# Tool definitions may use any of the formats described in
106+
# https://lmstudio.ai/docs/python/agent/tools
107+
async def search_wikipedia(query: str) -> WikipediaSearchResult | ErrorResult:
108+
"""Searches wikipedia using the given `query` string.
109+
110+
Returns a list of search results. Each search result contains
111+
a `title`, a `summary`, and a `page_id` which can be used to
112+
retrieve the full page content using get_wikipedia_page.
113+
114+
Note: this tool searches using Wikipedia, meaning, instead of using natural language queries,
115+
you should search for terms that you expect there will be an Wikipedia article of. For
116+
example, if the user asks about "the inventions of Thomas Edison", don't search for "what are
117+
the inventions of Thomas Edison". Instead, search for "Thomas Edison".
118+
119+
If a particular query did not return a result that you expect, you should try to search again
120+
using a more canonical term, or search for a different term that is more likely to contain the
121+
relevant information.
122+
123+
ALWAYS use `get_wikipedia_page` to retrieve the full content of the page afterwards. NEVER
124+
try to answer merely based on summary in the search results.
125+
"""
126+
search_params = {
127+
"action": "query",
128+
"list": "search",
129+
"srsearch": query,
130+
"format": "json",
131+
"utf8": "1",
132+
}
133+
data, error = await _query_wikipedia("search results", search_params)
134+
if error is not None:
135+
return error
136+
raw_results = data["query"]["search"]
137+
results = [
138+
WikipediaSearchEntry(
139+
title=r["title"],
140+
summary=_strip_search_markup(r["snippet"]),
141+
page_id=int(r["pageid"]),
142+
)
143+
for r in raw_results
144+
]
145+
return WikipediaSearchResult(results=results, hint=PAGE_RETRIEVAL_HINT)
146+
147+
async def get_wikipedia_page(page_id: int) -> WikipediaPage | ErrorResult:
148+
"""Retrieves the full content of a Wikipedia page using the given `page_id`.
149+
150+
Returns the title and content of a page.
151+
Use `search_wikipedia` first to get the `page_id`.
152+
"""
153+
str_page_id = str(page_id)
154+
fetch_params = {
155+
"action": "query",
156+
"prop": "extracts",
157+
"explaintext": "1",
158+
"pageids": str_page_id,
159+
"format": "json",
160+
"utf8": "1",
161+
}
162+
data, error = await _query_wikipedia("page content", fetch_params)
163+
if error is not None:
164+
return error
165+
raw_page = data["query"]["pages"][str_page_id]
166+
title = raw_page["title"]
167+
content = raw_page.get("extract", None)
168+
if content is None:
169+
content = "No content available for this page."
170+
return WikipediaPage(title=title, content=content)
171+
172+
return [search_wikipedia, get_wikipedia_page]
173+
174+
175+
print(f"{__name__} initialized from {__file__}")

src/lmstudio/plugin/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
# * refactor to allow hook invocation error handling to be common across hook invocation tasks
1818
# * [DONE] gracefully handle app termination while a dev plugin is still running
1919
# * [DONE] gracefully handle using Ctrl-C to terminate a running dev plugin
20-
# * add async tool handling support to SDK (as part of adding .act() to the async API)
20+
# * [DONE] add async tool handling support to SDK (as part of adding .act() to the async API)
2121
#
2222
# Controller APIs (may be limited to relevant hook controllers)
2323
#
@@ -50,7 +50,7 @@
5050
#
5151
# Tools provider hook
5252
# * [DONE] add example synchronous tool plugin (dice rolling)
53-
# * add example asynchronous tool plugin (Wikipedia lookup) (note: requires async tool support in SDK)
53+
# * [DONE] add example asynchronous tool plugin (Wikipedia lookup) (note: requires async tool support in SDK)
5454
# * [DONE] define the channel, hook invocation task and hook invocation controller for this hook
5555
# * [DONE] main request initiation message is "InitSession" (with Initialized/Failed responses)
5656
# * [DONE] handle "AbortToolCall" requests from server

src/lmstudio/plugin/hooks/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,5 @@
1515
"ToolCallContext",
1616
"ToolsProviderController",
1717
"get_tool_call_context",
18+
"get_tool_call_context_async",
1819
]

src/lmstudio/plugin/hooks/tools_provider.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
"ToolsProviderHook",
6060
"run_tools_provider",
6161
"get_tool_call_context",
62+
"get_tool_call_context_async",
6263
]
6364

6465

0 commit comments

Comments
 (0)