|
1 | 1 | """Example plugin that provide dice rolling tools.""" |
2 | 2 |
|
3 | | -# Not yet implemented, currently used to check plugins with no hooks defined |
| 3 | +import time |
| 4 | + |
| 5 | +from random import randint |
| 6 | +from typing import TypedDict |
| 7 | + |
| 8 | +from lmstudio.plugin import ( |
| 9 | + BaseConfigSchema, |
| 10 | + ToolsProviderController, |
| 11 | + config_field, |
| 12 | + get_tool_call_context, |
| 13 | +) |
| 14 | +from lmstudio import ToolDefinition |
| 15 | + |
| 16 | + |
| 17 | +# Assigning ConfigSchema = SomeOtherSchemaClass also works |
| 18 | +class ConfigSchema(BaseConfigSchema): |
| 19 | + """The name 'ConfigSchema' implicitly registers this as the per-chat plugin config schema.""" |
| 20 | + |
| 21 | + enable_inplace_status_demo: bool = config_field( |
| 22 | + label="Enable in-place status demo", |
| 23 | + hint="The plugin will run an in-place task status updating demo when invoked", |
| 24 | + default=True, |
| 25 | + ) |
| 26 | + inplace_status_duration: float = config_field( |
| 27 | + label="In-place status total duration (s)", |
| 28 | + hint="The number of seconds to spend displaying the in-place task status update", |
| 29 | + default=5.0, |
| 30 | + ) |
| 31 | + restrict_die_types: bool = config_field( |
| 32 | + label="Require polyhedral dice", |
| 33 | + hint="Require conventional polyhedral dice (4, 6, 8, 10, 12, 20, or 100 sides)", |
| 34 | + default=True, |
| 35 | + ) |
| 36 | + |
| 37 | + |
| 38 | +# This example plugin has no global configuration settings defined. |
| 39 | +# For a type hinted plugin with no configuration settings of a given type, |
| 40 | +# BaseConfigSchema may be used in the hook controller type hint. |
| 41 | +# Defining a config schema subclass with no fields is also a valid approach. |
| 42 | + |
| 43 | + |
| 44 | +# When reporting multiple values from a tool call, dictionaries |
| 45 | +# are the preferred format, as the field names allow the LLM |
| 46 | +# to potentially interpret the result correctly. |
| 47 | +# Unlike parameter details, no return value schema is sent to the server, |
| 48 | +# so relevant information needs to be part of the JSON serialisation. |
| 49 | +class DiceRollResult(TypedDict): |
| 50 | + """The result of a dice rolling request.""" |
| 51 | + |
| 52 | + rolls: list[int] |
| 53 | + total: int |
| 54 | + |
| 55 | + |
| 56 | +# Assigning list_provided_tools = some_other_callable also works |
| 57 | +async def list_provided_tools( |
| 58 | + ctl: ToolsProviderController[ConfigSchema, BaseConfigSchema], |
| 59 | +) -> list[ToolDefinition]: |
| 60 | + """Naming the function 'list_provided_tools' implicitly registers it.""" |
| 61 | + config = ctl.plugin_config |
| 62 | + if config.enable_inplace_status_demo: |
| 63 | + inplace_status_duration = config.inplace_status_duration |
| 64 | + else: |
| 65 | + inplace_status_duration = 0 |
| 66 | + if config.restrict_die_types: |
| 67 | + permitted_sides = {4, 6, 8, 10, 12, 20, 100} |
| 68 | + else: |
| 69 | + permitted_sides = None |
| 70 | + |
| 71 | + # Tool definitions may use any of the formats described in |
| 72 | + # https://lmstudio.ai/docs/python/agent/tools |
| 73 | + def roll_dice(count: int, sides: int) -> DiceRollResult: |
| 74 | + """Roll a specified number of dice with specified number of faces. |
| 75 | +
|
| 76 | + For example, to roll 2 six-sided dice (i.e. 2d6), you should call the function |
| 77 | + `roll_dice` with the parameters { count: 2, sides: 6 }. |
| 78 | + """ |
| 79 | + if inplace_status_duration: |
| 80 | + tcc = get_tool_call_context() |
| 81 | + status_updates = ( |
| 82 | + (tcc.notify_status, "Display status update in UI."), |
| 83 | + (tcc.notify_warning, "Display task warning in UI."), |
| 84 | + (tcc.notify_status, "Post-warning status update in UI."), |
| 85 | + ) |
| 86 | + status_duration = inplace_status_duration / len(status_updates) |
| 87 | + for send_notification, status_text in status_updates: |
| 88 | + time.sleep(status_duration) |
| 89 | + send_notification(status_text) |
| 90 | + if permitted_sides and sides not in permitted_sides: |
| 91 | + expected_die_types = ",".join(map(str, sorted(permitted_sides))) |
| 92 | + err_msg = f"{sides} is not a conventional polyhedral die type ({expected_die_types})" |
| 93 | + raise ValueError(err_msg) |
| 94 | + rolls = [randint(1, sides) for _ in range(count)] |
| 95 | + return DiceRollResult(rolls=rolls, total=sum(rolls)) |
| 96 | + |
| 97 | + return [roll_dice] |
| 98 | + |
| 99 | + |
| 100 | +print(f"{__name__} initialized from {__file__}") |
0 commit comments