Skip to content

Commit a7cd8af

Browse files
Add default parameter value support for tool definitions (#150)
- Add support for parameter default values to ToolFunctionDef and ToolFunctionDefDict - Implement automatic default extraction from function signatures using inspect - Add comprehensive test coverage for default parameter values - Maintain backward compatibility with existing tool definitions Closes #90 --------- Co-authored-by: Alyssa Coghlan <[email protected]>
1 parent 1b28c0f commit a7cd8af

File tree

3 files changed

+267
-3
lines changed

3 files changed

+267
-3
lines changed

src/lmstudio/json_api.py

Lines changed: 72 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@
1111

1212
import asyncio
1313
import copy
14+
import inspect
1415
import json
16+
import sys
1517
import uuid
1618
import warnings
1719

@@ -40,6 +42,7 @@
4042
# Native in 3.11+
4143
assert_never,
4244
NoReturn,
45+
NotRequired,
4346
Self,
4447
)
4548

@@ -1089,27 +1092,75 @@ def __init__(
10891092
super().__init__(model_key, params, on_load_progress)
10901093

10911094

1095+
if sys.version_info < (3, 11):
1096+
# Generic typed dictionaries aren't supported in Python 3.10
1097+
# https://github.com/python/cpython/issues/89026
1098+
class ToolParamDefDict(TypedDict):
1099+
type: type[Any]
1100+
default: NotRequired[Any]
1101+
1102+
ParamDefDict: TypeAlias = ToolParamDefDict
1103+
else:
1104+
1105+
class ToolParamDefDict(TypedDict, Generic[T]):
1106+
type: type[T]
1107+
default: NotRequired[T]
1108+
1109+
ParamDefDict: TypeAlias = ToolParamDefDict[Any]
1110+
1111+
10921112
class ToolFunctionDefDict(TypedDict):
10931113
"""SDK input format to specify an LLM tool call and its implementation (as a dict)."""
10941114

10951115
name: str
10961116
description: str
1097-
parameters: Mapping[str, Any]
1117+
parameters: Mapping[str, type[Any] | ParamDefDict]
10981118
implementation: Callable[..., Any]
10991119

11001120

1121+
# Sentinel for parameters with no defined default value
1122+
_NO_DEFAULT = object()
1123+
1124+
11011125
@dataclass(kw_only=True, frozen=True, slots=True)
11021126
class ToolFunctionDef:
11031127
"""SDK input format to specify an LLM tool call and its implementation."""
11041128

11051129
name: str
11061130
description: str
1107-
parameters: Mapping[str, Any]
1131+
parameters: Mapping[str, type[Any] | ParamDefDict]
11081132
implementation: Callable[..., Any]
11091133

1134+
@staticmethod
1135+
def _extract_type_and_default(
1136+
param_value: type[Any] | ParamDefDict,
1137+
) -> tuple[type[Any], Any]:
1138+
"""Extract type and default value from parameter definition."""
1139+
if isinstance(param_value, dict):
1140+
# Inline format: {"type": type, "default": value}
1141+
param_type = param_value.get("type", None)
1142+
if param_type is None:
1143+
raise TypeError(
1144+
f"Missing 'type' key in tool parameter definition {param_value!r}"
1145+
)
1146+
default_value = param_value.get("default", _NO_DEFAULT)
1147+
return param_type, default_value
1148+
else:
1149+
# Simple format: just the type
1150+
return param_value, _NO_DEFAULT
1151+
11101152
def _to_llm_tool_def(self) -> tuple[type[Struct], LlmTool]:
11111153
params_struct_name = f"{self.name.capitalize()}Parameters"
1112-
params_struct = defstruct(params_struct_name, self.parameters.items())
1154+
# Build fields list with defaults
1155+
fields: list[tuple[str, type[Any]] | tuple[str, type[Any], Any]] = []
1156+
for param_name, param_value in self.parameters.items():
1157+
param_type, default_value = self._extract_type_and_default(param_value)
1158+
if default_value is _NO_DEFAULT:
1159+
fields.append((param_name, param_type))
1160+
else:
1161+
fields.append((param_name, param_type, default_value))
1162+
# Define msgspec struct and API tool definition from the field list
1163+
params_struct = defstruct(params_struct_name, fields, kw_only=True)
11131164
return params_struct, LlmTool._from_api_dict(
11141165
{
11151166
"type": "function",
@@ -1160,6 +1211,23 @@ def from_callable(
11601211
) from exc
11611212
# Tool definitions only annotate the input parameters, not the return type
11621213
parameters.pop("return", None)
1214+
1215+
# Extract default values from function signature and convert to inline format
1216+
try:
1217+
sig = inspect.signature(f)
1218+
except Exception:
1219+
# If we can't extract defaults, continue without them
1220+
pass
1221+
else:
1222+
for param_name, param in sig.parameters.items():
1223+
if param.default is not inspect.Parameter.empty:
1224+
# Convert to inline format: {"type": type, "default": value}
1225+
original_type = parameters[param_name]
1226+
parameters[param_name] = {
1227+
"type": original_type,
1228+
"default": param.default,
1229+
}
1230+
11631231
return cls(
11641232
name=name, description=description, parameters=parameters, implementation=f
11651233
)
@@ -1601,6 +1669,7 @@ def parse_tools(
16011669
elif callable(tool):
16021670
tool_def = ToolFunctionDef.from_callable(tool)
16031671
else:
1672+
# Handle dictionary-based tool definition
16041673
tool_def = ToolFunctionDef(**tool)
16051674
if tool_def.name in client_tool_map:
16061675
raise LMStudioValueError(

src/lmstudio/schemas.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,10 @@ def _to_json_schema(cls: type, *, omit: Sequence[str] = ()) -> DictSchema:
6565
for field in omit:
6666
named_schema.pop(field, None)
6767
json_schema.update(named_schema)
68+
69+
# msgspec automatically handles default values in the generated JSON schema
70+
# when they are properly defined in the Struct fields
71+
6872
return json_schema
6973

7074

tests/test_default_values.py

Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
"""Tests for default parameter values in tool definitions."""
2+
3+
import pytest
4+
5+
from msgspec import defstruct
6+
7+
from lmstudio.json_api import _NO_DEFAULT, ToolFunctionDef, ToolFunctionDefDict
8+
from lmstudio.schemas import _to_json_schema
9+
10+
11+
def greet(name: str, greeting: str = "Hello", punctuation: str = "!") -> str:
12+
"""Greet someone with a customizable message.
13+
14+
Args:
15+
name: The name of the person to greet
16+
greeting: The greeting word to use (default: "Hello")
17+
punctuation: The punctuation to end with (default: "!")
18+
19+
Returns:
20+
A greeting message
21+
"""
22+
return f"{greeting}, {name}{punctuation}"
23+
24+
25+
def calculate(expression: str, precision: int = 2) -> str:
26+
"""Calculate a mathematical expression.
27+
28+
Args:
29+
expression: The mathematical expression to evaluate
30+
precision: Number of decimal places (default: 2)
31+
32+
Returns:
33+
The calculated result as a string
34+
"""
35+
return f"Result: {eval(expression):.{precision}f}"
36+
37+
38+
class TestDefaultValues:
39+
"""Test cases for default parameter values in tool definitions."""
40+
41+
def test_extract_defaults_from_callable(self) -> None:
42+
"""Test extracting default values from a callable function."""
43+
tool_def = ToolFunctionDef.from_callable(greet)
44+
45+
assert tool_def.name == "greet"
46+
# Check that defaults are converted to inline format
47+
assert tool_def.parameters["greeting"] == {"type": str, "default": "Hello"}
48+
assert tool_def.parameters["punctuation"] == {"type": str, "default": "!"}
49+
assert tool_def.parameters["name"] is str # No default, just type
50+
51+
def test_manual_inline_defaults(self) -> None:
52+
"""Test manually specifying default values in inline format."""
53+
tool_def = ToolFunctionDef(
54+
name="calculate",
55+
description="Calculate a mathematical expression",
56+
parameters={"expression": str, "precision": {"type": int, "default": 2}},
57+
implementation=calculate,
58+
)
59+
60+
# Check that the inline format is preserved
61+
assert tool_def.parameters["precision"] == {"type": int, "default": 2}
62+
assert tool_def.parameters["expression"] is str # No default, just type
63+
64+
def test_json_schema_with_defaults(self) -> None:
65+
"""Test that JSON schema includes default values."""
66+
tool_def = ToolFunctionDef.from_callable(greet)
67+
params_struct, _ = tool_def._to_llm_tool_def()
68+
69+
json_schema = _to_json_schema(params_struct)
70+
71+
# Check that default values are included in the schema
72+
assert json_schema["properties"]["greeting"]["default"] == "Hello"
73+
assert json_schema["properties"]["punctuation"]["default"] == "!"
74+
assert "default" not in json_schema["properties"]["name"]
75+
76+
def test_dict_based_definition(self) -> None:
77+
"""Test dictionary-based tool definition with inline defaults."""
78+
dict_tool: ToolFunctionDefDict = {
79+
"name": "format_text",
80+
"description": "Format text with specified style",
81+
"parameters": {
82+
"text": str,
83+
"style": {"type": str, "default": "normal"},
84+
"uppercase": {"type": bool, "default": False},
85+
},
86+
"implementation": lambda text, style="normal", uppercase=False: text.upper()
87+
if uppercase
88+
else text,
89+
}
90+
91+
# This should work without errors
92+
tool_def = ToolFunctionDef(**dict_tool)
93+
assert tool_def.parameters["style"] == {"type": str, "default": "normal"}
94+
assert tool_def.parameters["uppercase"] == {"type": bool, "default": False}
95+
assert tool_def.parameters["text"] is str # No default, just type
96+
97+
def test_no_defaults(self) -> None:
98+
"""Test function with no default values."""
99+
100+
def no_defaults(a: int, b: str) -> str:
101+
"""Function with no default parameters."""
102+
return f"{a}: {b}"
103+
104+
tool_def = ToolFunctionDef.from_callable(no_defaults)
105+
# All parameters should be simple types without defaults
106+
assert tool_def.parameters["a"] is int
107+
assert tool_def.parameters["b"] is str
108+
109+
params_struct, _ = tool_def._to_llm_tool_def()
110+
json_schema = _to_json_schema(params_struct)
111+
112+
# No default values should be present
113+
assert "default" not in json_schema["properties"]["a"]
114+
assert "default" not in json_schema["properties"]["b"]
115+
116+
def test_mixed_defaults(self) -> None:
117+
"""Test function with some parameters having defaults and others not."""
118+
119+
def mixed_defaults(
120+
required: str, optional1: int = 42, optional2: bool = True
121+
) -> str:
122+
"""Function with mixed required and optional parameters."""
123+
return f"{required}: {optional1}, {optional2}"
124+
125+
tool_def = ToolFunctionDef.from_callable(mixed_defaults)
126+
# Check inline format for parameters with defaults
127+
assert tool_def.parameters["optional1"] == {"type": int, "default": 42}
128+
assert tool_def.parameters["optional2"] == {"type": bool, "default": True}
129+
assert tool_def.parameters["required"] is str # No default, just type
130+
131+
params_struct, _ = tool_def._to_llm_tool_def()
132+
json_schema = _to_json_schema(params_struct)
133+
134+
# Check that default values are correctly included in schema
135+
assert json_schema["properties"]["optional1"]["default"] == 42
136+
assert json_schema["properties"]["optional2"]["default"] is True
137+
assert "default" not in json_schema["properties"]["required"]
138+
139+
def test_extract_type_and_default_method(self) -> None:
140+
"""Test the _extract_type_and_default helper method."""
141+
142+
# Test simple type
143+
param_type, default = ToolFunctionDef._extract_type_and_default(str)
144+
assert param_type is str
145+
assert default is _NO_DEFAULT
146+
147+
# Test inline format with missing type key
148+
with pytest.raises(TypeError, match="Missing 'type' key"):
149+
param_type, default = ToolFunctionDef._extract_type_and_default(
150+
{"default": 42} # type: ignore[arg-type]
151+
)
152+
153+
# Test inline format with no default
154+
param_type, default = ToolFunctionDef._extract_type_and_default({"type": int})
155+
assert param_type is int
156+
assert default is _NO_DEFAULT
157+
158+
# Test inline format with default
159+
param_type, default = ToolFunctionDef._extract_type_and_default(
160+
{"type": int, "default": 42}
161+
)
162+
assert param_type is int
163+
assert default == 42
164+
165+
# Test complex default
166+
param_type, default = ToolFunctionDef._extract_type_and_default(
167+
{"type": list, "default": [1, 2, 3]}
168+
)
169+
assert param_type is list
170+
assert default == [1, 2, 3]
171+
172+
def test_msgspec_auto_defaults(self) -> None:
173+
"""msgspec automatically reflects default values in the JSON schema."""
174+
TestStruct = defstruct(
175+
"TestStruct",
176+
[
177+
("name", str),
178+
("age", int, 18),
179+
("active", bool, True),
180+
],
181+
kw_only=True,
182+
)
183+
184+
schema = _to_json_schema(TestStruct)
185+
properties = schema.get("properties", {})
186+
required = schema.get("required", [])
187+
188+
assert "name" in properties and "default" not in properties["name"]
189+
assert properties["age"].get("default") == 18
190+
assert properties["active"].get("default") is True
191+
assert "name" in required and "age" not in required and "active" not in required

0 commit comments

Comments
 (0)