diff --git a/examples/plugins/dice-tool/README.md b/examples/plugins/dice-tool/README.md index 5330e3f..2b1f3ba 100644 --- a/examples/plugins/dice-tool/README.md +++ b/examples/plugins/dice-tool/README.md @@ -1,6 +1,9 @@ # `lmstudio/dice-tool` -Python tools provider plugin example +Python tools provider plugin example showcasing synchronous tool definitions +by adding support for random number generation by rolling simulated dice. + +Also includes a tool call status update demo. Running a local dev instance: diff --git a/examples/plugins/dice-tool/src/plugin.py b/examples/plugins/dice-tool/src/plugin.py index b878ab3..34d44aa 100644 --- a/examples/plugins/dice-tool/src/plugin.py +++ b/examples/plugins/dice-tool/src/plugin.py @@ -1,4 +1,4 @@ -"""Example plugin that provide dice rolling tools.""" +"""Example plugin that provides dice rolling tools.""" import time diff --git a/examples/plugins/prompt-prefix/README.md b/examples/plugins/prompt-prefix/README.md index eecba27..01a3cb9 100644 --- a/examples/plugins/prompt-prefix/README.md +++ b/examples/plugins/prompt-prefix/README.md @@ -2,6 +2,8 @@ Python prompt preprocessing plugin example +Also includes an in-place update demo for status block notifications. + Running a local dev instance: pdm run python -m lmstudio.plugin --dev examples/plugins/prompt-prefix diff --git a/examples/plugins/wikipedia/README.md b/examples/plugins/wikipedia/README.md new file mode 100644 index 0000000..18def3f --- /dev/null +++ b/examples/plugins/wikipedia/README.md @@ -0,0 +1,8 @@ +# `lmstudio/wikipedia` + +Python tools provider plugin example showcasing asynchronous tool definitions +by adding support for searching Wikipedia and retrieving specific pages. + +Running a local dev instance: + + pdm run python -m lmstudio.plugin --dev examples/plugins/wikipedia diff --git a/examples/plugins/wikipedia/manifest.json b/examples/plugins/wikipedia/manifest.json new file mode 100644 index 0000000..0a9ddc6 --- /dev/null +++ b/examples/plugins/wikipedia/manifest.json @@ -0,0 +1,7 @@ +{ + "type": "plugin", + "runner": "python", + "owner": "lmstudio", + "name": "py-wikipedia", + "revision": 1 +} diff --git a/examples/plugins/wikipedia/src/plugin.py b/examples/plugins/wikipedia/src/plugin.py new file mode 100644 index 0000000..e2f5097 --- /dev/null +++ b/examples/plugins/wikipedia/src/plugin.py @@ -0,0 +1,175 @@ +"""Example plugin that provides tools for querying Wikipedia.""" + +from typing import Any, TypeAlias, TypedDict + +from lmstudio.plugin import ( + BaseConfigSchema, + ToolsProviderController, + config_field, + get_tool_call_context_async, +) +from lmstudio import ToolDefinition + +# Python plugins don't support dependency declarations yet, +# but the lmstudio SDK is always available and uses httpx +import httpx + + +# Assigning ConfigSchema = SomeOtherSchemaClass also works +class ConfigSchema(BaseConfigSchema): + """The name 'ConfigSchema' implicitly registers this as the per-chat plugin config schema.""" + + wikipedia_base_url: str = config_field( + label="Wikipedia Base URL", + hint="The base URL for the Wikipedia API.", + default="https://en.wikipedia.org", + ) + + +# This example plugin has no global configuration settings defined. +# For a type hinted plugin with no configuration settings of a given type, +# BaseConfigSchema may be used in the hook controller type hint. +# Defining a config schema subclass with no fields is also a valid approach. + + +# When reporting multiple values from a tool call, dictionaries +# are the preferred format, as the field names allow the LLM +# to potentially interpret the result correctly. +# Unlike parameter details, no return value schema is sent to the server, +# so relevant information needs to be part of the JSON serialisation. +class WikipediaSearchEntry(TypedDict): + """A single entry in a Wikipedia search result.""" + + title: str + summary: str + page_id: int + + +class WikipediaSearchResult(TypedDict): + """The collected results of a Wikipedia search.""" + + results: list[WikipediaSearchEntry] + hint: str + + +class WikipediaPage(TypedDict): + """Details of a retrieved wikipedia page.""" + + title: str + content: str + + +ErrorResult: TypeAlias = str | dict[str, Any] + +PAGE_RETRIEVAL_HINT = """\ +If any of the search results are relevant, ALWAYS use `get_wikipedia_page` to retrieve +the full content of the page using the `page_id`. The `summary` is just a brief +snippet and can have missing information. If not, try to search again using a more +canonical term, or search for a different term that is more likely to contain the relevant +information. +""" + + +def _strip_search_markup(text: str) -> str: + """Remove search markup inserted by Wikipedia API.""" + return text.replace('', "").replace("", "") + + +# Assigning list_provided_tools = some_other_callable also works +async def list_provided_tools( + ctl: ToolsProviderController[ConfigSchema, BaseConfigSchema], +) -> list[ToolDefinition]: + """Naming the function 'list_provided_tools' implicitly registers it.""" + base_url = httpx.URL(ctl.plugin_config.wikipedia_base_url) + api_url = base_url.join("/w/api.php") + + async def _query_wikipedia( + query_type: str, query_params: dict[str, Any] + ) -> tuple[Any, ErrorResult | None]: + tcc = get_tool_call_context_async() + await tcc.notify_status(f"Fetching {query_type} from Wikipedia...") + async with httpx.AsyncClient() as web_client: + result = await web_client.get(api_url, params=query_params) + if result.status_code != httpx.codes.OK: + warning_message = f"Failed to fetch {query_type} from Wikipedia (status: {result.status_code})" + await tcc.notify_warning(warning_message) + return None, f"Error: {warning_message}" + data = result.json() + err_data = data.get("error", None) + if err_data is not None: + warning_message = f"Wikipedia API returned an error: ${err_data['info']}" + await tcc.notify_warning(warning_message) + return None, err_data + return data, None + + # Tool definitions may use any of the formats described in + # https://lmstudio.ai/docs/python/agent/tools + async def search_wikipedia(query: str) -> WikipediaSearchResult | ErrorResult: + """Searches wikipedia using the given `query` string. + + Returns a list of search results. Each search result contains + a `title`, a `summary`, and a `page_id` which can be used to + retrieve the full page content using get_wikipedia_page. + + Note: this tool searches using Wikipedia, meaning, instead of using natural language queries, + you should search for terms that you expect there will be an Wikipedia article of. For + example, if the user asks about "the inventions of Thomas Edison", don't search for "what are + the inventions of Thomas Edison". Instead, search for "Thomas Edison". + + If a particular query did not return a result that you expect, you should try to search again + using a more canonical term, or search for a different term that is more likely to contain the + relevant information. + + ALWAYS use `get_wikipedia_page` to retrieve the full content of the page afterwards. NEVER + try to answer merely based on summary in the search results. + """ + search_params = { + "action": "query", + "list": "search", + "srsearch": query, + "format": "json", + "utf8": "1", + } + data, error = await _query_wikipedia("search results", search_params) + if error is not None: + return error + raw_results = data["query"]["search"] + results = [ + WikipediaSearchEntry( + title=r["title"], + summary=_strip_search_markup(r["snippet"]), + page_id=int(r["pageid"]), + ) + for r in raw_results + ] + return WikipediaSearchResult(results=results, hint=PAGE_RETRIEVAL_HINT) + + async def get_wikipedia_page(page_id: int) -> WikipediaPage | ErrorResult: + """Retrieves the full content of a Wikipedia page using the given `page_id`. + + Returns the title and content of a page. + Use `search_wikipedia` first to get the `page_id`. + """ + str_page_id = str(page_id) + fetch_params = { + "action": "query", + "prop": "extracts", + "explaintext": "1", + "pageids": str_page_id, + "format": "json", + "utf8": "1", + } + data, error = await _query_wikipedia("page content", fetch_params) + if error is not None: + return error + raw_page = data["query"]["pages"][str_page_id] + title = raw_page["title"] + content = raw_page.get("extract", None) + if content is None: + content = "No content available for this page." + return WikipediaPage(title=title, content=content) + + return [search_wikipedia, get_wikipedia_page] + + +print(f"{__name__} initialized from {__file__}") diff --git a/src/lmstudio/plugin/__init__.py b/src/lmstudio/plugin/__init__.py index 6ba23a0..e6fe83d 100644 --- a/src/lmstudio/plugin/__init__.py +++ b/src/lmstudio/plugin/__init__.py @@ -17,7 +17,7 @@ # * refactor to allow hook invocation error handling to be common across hook invocation tasks # * [DONE] gracefully handle app termination while a dev plugin is still running # * [DONE] gracefully handle using Ctrl-C to terminate a running dev plugin -# * add async tool handling support to SDK (as part of adding .act() to the async API) +# * [DONE] add async tool handling support to SDK (as part of adding .act() to the async API) # # Controller APIs (may be limited to relevant hook controllers) # @@ -50,7 +50,7 @@ # # Tools provider hook # * [DONE] add example synchronous tool plugin (dice rolling) -# * add example asynchronous tool plugin (Wikipedia lookup) (note: requires async tool support in SDK) +# * [DONE] add example asynchronous tool plugin (Wikipedia lookup) (note: requires async tool support in SDK) # * [DONE] define the channel, hook invocation task and hook invocation controller for this hook # * [DONE] main request initiation message is "InitSession" (with Initialized/Failed responses) # * [DONE] handle "AbortToolCall" requests from server diff --git a/src/lmstudio/plugin/hooks/__init__.py b/src/lmstudio/plugin/hooks/__init__.py index 74fe37b..fa93eca 100644 --- a/src/lmstudio/plugin/hooks/__init__.py +++ b/src/lmstudio/plugin/hooks/__init__.py @@ -15,4 +15,5 @@ "ToolCallContext", "ToolsProviderController", "get_tool_call_context", + "get_tool_call_context_async", ] diff --git a/src/lmstudio/plugin/hooks/tools_provider.py b/src/lmstudio/plugin/hooks/tools_provider.py index 77c16fa..76afa67 100644 --- a/src/lmstudio/plugin/hooks/tools_provider.py +++ b/src/lmstudio/plugin/hooks/tools_provider.py @@ -59,6 +59,7 @@ "ToolsProviderHook", "run_tools_provider", "get_tool_call_context", + "get_tool_call_context_async", ]