Skip to content

Commit 0eb7ddd

Browse files
feat: Add default parameter values support for tool definitions
- Add parameter_defaults field to ToolFunctionDef and ToolFunctionDefDict - Implement automatic default extraction from function signatures using inspect - Add manual default specification support - Update JSON schema generation to include default values - Add comprehensive test coverage for default parameter values - Maintain backward compatibility with existing tool definitions Closes #90
1 parent 3a81ccb commit 0eb7ddd

File tree

3 files changed

+210
-4
lines changed

3 files changed

+210
-4
lines changed

src/lmstudio/json_api.py

Lines changed: 46 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
import asyncio
1313
import copy
14+
import inspect
1415
import json
1516
import uuid
1617
import warnings
@@ -1089,13 +1090,14 @@ def __init__(
10891090
super().__init__(model_key, params, on_load_progress)
10901091

10911092

1092-
class ToolFunctionDefDict(TypedDict):
1093+
class ToolFunctionDefDict(TypedDict, total=False):
10931094
"""SDK input format to specify an LLM tool call and its implementation (as a dict)."""
10941095

10951096
name: str
10961097
description: str
10971098
parameters: Mapping[str, Any]
10981099
implementation: Callable[..., Any]
1100+
parameter_defaults: Mapping[str, Any]
10991101

11001102

11011103
@dataclass(kw_only=True, frozen=True, slots=True)
@@ -1106,10 +1108,21 @@ class ToolFunctionDef:
11061108
description: str
11071109
parameters: Mapping[str, Any]
11081110
implementation: Callable[..., Any]
1111+
parameter_defaults: Mapping[str, Any] = field(default_factory=dict)
11091112

11101113
def _to_llm_tool_def(self) -> tuple[type[Struct], LlmTool]:
11111114
params_struct_name = f"{self.name.capitalize()}Parameters"
1112-
params_struct = defstruct(params_struct_name, self.parameters.items())
1115+
1116+
# Build fields list with defaults
1117+
fields: list[tuple[str, Any] | tuple[str, Any, Any]] = []
1118+
for param_name, param_type in self.parameters.items():
1119+
if param_name in self.parameter_defaults:
1120+
default_value = self.parameter_defaults[param_name]
1121+
fields.append((param_name, param_type, default_value))
1122+
else:
1123+
fields.append((param_name, param_type))
1124+
1125+
params_struct = defstruct(params_struct_name, fields, kw_only=True)
11131126
return params_struct, LlmTool._from_api_dict(
11141127
{
11151128
"type": "function",
@@ -1160,8 +1173,24 @@ def from_callable(
11601173
) from exc
11611174
# Tool definitions only annotate the input parameters, not the return type
11621175
parameters.pop("return", None)
1176+
1177+
# Extract default values from function signature
1178+
parameter_defaults: dict[str, Any] = {}
1179+
try:
1180+
sig = inspect.signature(f)
1181+
for param_name, param in sig.parameters.items():
1182+
if param.default is not inspect.Parameter.empty:
1183+
parameter_defaults[param_name] = param.default
1184+
except Exception as exc:
1185+
# If we can't extract defaults, continue without them
1186+
pass
1187+
11631188
return cls(
1164-
name=name, description=description, parameters=parameters, implementation=f
1189+
name=name,
1190+
description=description,
1191+
parameters=parameters,
1192+
implementation=f,
1193+
parameter_defaults=parameter_defaults
11651194
)
11661195

11671196

@@ -1601,7 +1630,20 @@ def parse_tools(
16011630
elif callable(tool):
16021631
tool_def = ToolFunctionDef.from_callable(tool)
16031632
else:
1604-
tool_def = ToolFunctionDef(**tool)
1633+
# Handle dictionary-based tool definition
1634+
tool_dict = cast(ToolFunctionDefDict, tool)
1635+
name = cast(str, tool_dict["name"])
1636+
description = cast(str, tool_dict["description"])
1637+
parameters = cast(Mapping[str, Any], tool_dict["parameters"])
1638+
implementation = cast(Callable[..., Any], tool_dict["implementation"])
1639+
parameter_defaults = cast(Mapping[str, Any], tool_dict.get("parameter_defaults", {}))
1640+
tool_def = ToolFunctionDef(
1641+
name=name,
1642+
description=description,
1643+
parameters=parameters,
1644+
implementation=implementation,
1645+
parameter_defaults=parameter_defaults
1646+
)
16051647
if tool_def.name in client_tool_map:
16061648
raise LMStudioValueError(
16071649
f"Duplicate tool names are not permitted ({tool_def.name!r} repeated)"

src/lmstudio/schemas.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,29 @@ def _to_json_schema(cls: type, *, omit: Sequence[str] = ()) -> DictSchema:
6565
for field in omit:
6666
named_schema.pop(field, None)
6767
json_schema.update(named_schema)
68+
69+
# Add default values to properties if they exist in the msgspec Struct
70+
if hasattr(cls, "__struct_fields__") and hasattr(cls, "__struct_defaults__"):
71+
properties = json_schema.get("properties", {})
72+
if properties:
73+
# Get ordered field names and default values
74+
field_names = cls.__struct_fields__
75+
default_values = cls.__struct_defaults__
76+
77+
# Map default values to field names by position
78+
# Only fields with defaults will have entries in default_values
79+
default_count = len(default_values)
80+
field_count = len(field_names)
81+
82+
# For kw_only=True structs, default values correspond to the last N fields
83+
# where N is the number of default values
84+
for i, field_name in enumerate(field_names):
85+
if field_name in properties:
86+
# Calculate the index into default_values
87+
default_index = i - (field_count - default_count)
88+
if 0 <= default_index < default_count:
89+
properties[field_name]["default"] = default_values[default_index]
90+
6891
return json_schema
6992

7093

tests/test_default_values.py

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
"""Tests for default parameter values in tool definitions."""
2+
3+
import pytest
4+
from typing import Any
5+
6+
from src.lmstudio.json_api import ToolFunctionDef, ToolFunctionDefDict
7+
from src.lmstudio.schemas import _to_json_schema
8+
9+
10+
def greet(name: str, greeting: str = "Hello", punctuation: str = "!") -> str:
11+
"""Greet someone with a customizable message.
12+
13+
Args:
14+
name: The name of the person to greet
15+
greeting: The greeting word to use (default: "Hello")
16+
punctuation: The punctuation to end with (default: "!")
17+
18+
Returns:
19+
A greeting message
20+
"""
21+
return f"{greeting}, {name}{punctuation}"
22+
23+
24+
def calculate(expression: str, precision: int = 2) -> str:
25+
"""Calculate a mathematical expression.
26+
27+
Args:
28+
expression: The mathematical expression to evaluate
29+
precision: Number of decimal places (default: 2)
30+
31+
Returns:
32+
The calculated result
33+
"""
34+
return f"Result: {eval(expression):.{precision}f}"
35+
36+
37+
class TestDefaultValues:
38+
"""Test default parameter value functionality."""
39+
40+
def test_extract_defaults_from_callable(self):
41+
"""Test extracting default values from function signature."""
42+
tool_def = ToolFunctionDef.from_callable(greet)
43+
44+
assert tool_def.name == "greet"
45+
assert tool_def.parameter_defaults == {
46+
"greeting": "Hello",
47+
"punctuation": "!"
48+
}
49+
assert "name" not in tool_def.parameter_defaults
50+
51+
def test_manual_defaults(self):
52+
"""Test manually specifying default values."""
53+
tool_def = ToolFunctionDef(
54+
name="calculate",
55+
description="Calculate a mathematical expression",
56+
parameters={"expression": str, "precision": int},
57+
implementation=calculate,
58+
parameter_defaults={"precision": 2}
59+
)
60+
61+
assert tool_def.parameter_defaults == {"precision": 2}
62+
assert "expression" not in tool_def.parameter_defaults
63+
64+
def test_json_schema_with_defaults(self):
65+
"""Test that JSON Schema includes default values."""
66+
tool_def = ToolFunctionDef.from_callable(greet)
67+
params_struct, _ = tool_def._to_llm_tool_def()
68+
json_schema = _to_json_schema(params_struct)
69+
70+
properties = json_schema["properties"]
71+
72+
# name should not have a default (required parameter)
73+
assert "name" in properties
74+
assert "default" not in properties["name"]
75+
76+
# greeting should have default "Hello"
77+
assert "greeting" in properties
78+
assert properties["greeting"]["default"] == "Hello"
79+
80+
# punctuation should have default "!"
81+
assert "punctuation" in properties
82+
assert properties["punctuation"]["default"] == "!"
83+
84+
# Only name should be required
85+
assert json_schema["required"] == ["name"]
86+
87+
def test_dict_based_definition(self):
88+
"""Test dictionary-based tool definition with defaults."""
89+
dict_tool: ToolFunctionDefDict = {
90+
"name": "format_text",
91+
"description": "Format text with specified style",
92+
"parameters": {"text": str, "style": str, "uppercase": bool},
93+
"implementation": lambda text, style="normal", uppercase=False: text.upper() if uppercase else text,
94+
"parameter_defaults": {"style": "normal", "uppercase": False}
95+
}
96+
97+
# This should work without errors
98+
tool_def = ToolFunctionDef(**dict_tool)
99+
assert tool_def.parameter_defaults == {"style": "normal", "uppercase": False}
100+
101+
def test_no_defaults(self):
102+
"""Test function with no default values."""
103+
def no_defaults(a: int, b: str) -> str:
104+
"""Function with no default parameters."""
105+
return f"{a}{b}"
106+
107+
tool_def = ToolFunctionDef.from_callable(no_defaults)
108+
assert tool_def.parameter_defaults == {}
109+
110+
params_struct, _ = tool_def._to_llm_tool_def()
111+
json_schema = _to_json_schema(params_struct)
112+
113+
# All parameters should be required
114+
assert json_schema["required"] == ["a", "b"]
115+
116+
# No properties should have defaults
117+
for prop in json_schema["properties"].values():
118+
assert "default" not in prop
119+
120+
def test_mixed_defaults(self):
121+
"""Test function with some parameters having defaults."""
122+
def mixed_defaults(required: str, optional1: int = 42, optional2: bool = True) -> str:
123+
"""Function with mixed required and optional parameters."""
124+
return f"{required}{optional1}{optional2}"
125+
126+
tool_def = ToolFunctionDef.from_callable(mixed_defaults)
127+
assert tool_def.parameter_defaults == {
128+
"optional1": 42,
129+
"optional2": True
130+
}
131+
132+
params_struct, _ = tool_def._to_llm_tool_def()
133+
json_schema = _to_json_schema(params_struct)
134+
135+
# Only required should be in required list
136+
assert json_schema["required"] == ["required"]
137+
138+
# Check defaults
139+
assert json_schema["properties"]["optional1"]["default"] == 42
140+
assert json_schema["properties"]["optional2"]["default"] is True
141+
assert "default" not in json_schema["properties"]["required"]

0 commit comments

Comments
 (0)