Skip to content

Commit c5cee65

Browse files
Implement inline default format for tool parameters - Replace separate parameter_defaults with inline format - Use generic TypedDict with NotRequired for better type safety - Remove manual default value injection, rely on msgspec auto-handling - Update all tests to use inline format instead of separate mapping
1 parent 0eb7ddd commit c5cee65

File tree

3 files changed

+136
-101
lines changed

3 files changed

+136
-101
lines changed

src/lmstudio/json_api.py

Lines changed: 36 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1090,14 +1090,23 @@ def __init__(
10901090
super().__init__(model_key, params, on_load_progress)
10911091

10921092

1093+
# Add new type definitions for inline parameter format
1094+
from typing import TypeVar
1095+
from typing_extensions import NotRequired
1096+
1097+
T = TypeVar('T')
1098+
1099+
class ToolParamDefDict(TypedDict, Generic[T]):
1100+
type: type[T]
1101+
default: NotRequired[T]
1102+
10931103
class ToolFunctionDefDict(TypedDict, total=False):
10941104
"""SDK input format to specify an LLM tool call and its implementation (as a dict)."""
10951105

10961106
name: str
10971107
description: str
1098-
parameters: Mapping[str, Any]
1108+
parameters: Mapping[str, type[Any] | ToolParamDefDict[Any]]
10991109
implementation: Callable[..., Any]
1100-
parameter_defaults: Mapping[str, Any]
11011110

11021111

11031112
@dataclass(kw_only=True, frozen=True, slots=True)
@@ -1106,18 +1115,29 @@ class ToolFunctionDef:
11061115

11071116
name: str
11081117
description: str
1109-
parameters: Mapping[str, Any]
1118+
parameters: Mapping[str, type[Any] | ToolParamDefDict[Any]]
11101119
implementation: Callable[..., Any]
1111-
parameter_defaults: Mapping[str, Any] = field(default_factory=dict)
1120+
1121+
def _extract_type_and_default(self, param_name: str, param_value: type[Any] | ToolParamDefDict[Any]) -> tuple[type[Any], Any | None]:
1122+
"""Extract type and default value from parameter definition."""
1123+
if isinstance(param_value, dict) and "type" in param_value:
1124+
# Inline format: {"type": type, "default": value}
1125+
param_type = param_value["type"]
1126+
default_value = param_value.get("default")
1127+
return param_type, default_value
1128+
else:
1129+
# Simple format: just the type
1130+
return param_value, None
11121131

11131132
def _to_llm_tool_def(self) -> tuple[type[Struct], LlmTool]:
11141133
params_struct_name = f"{self.name.capitalize()}Parameters"
11151134

11161135
# Build fields list with defaults
11171136
fields: list[tuple[str, Any] | tuple[str, Any, Any]] = []
1118-
for param_name, param_type in self.parameters.items():
1119-
if param_name in self.parameter_defaults:
1120-
default_value = self.parameter_defaults[param_name]
1137+
for param_name, param_value in self.parameters.items():
1138+
param_type, default_value = self._extract_type_and_default(param_name, param_value)
1139+
1140+
if default_value is not None:
11211141
fields.append((param_name, param_type, default_value))
11221142
else:
11231143
fields.append((param_name, param_type))
@@ -1174,13 +1194,17 @@ def from_callable(
11741194
# Tool definitions only annotate the input parameters, not the return type
11751195
parameters.pop("return", None)
11761196

1177-
# Extract default values from function signature
1178-
parameter_defaults: dict[str, Any] = {}
1197+
# Extract default values from function signature and convert to inline format
11791198
try:
11801199
sig = inspect.signature(f)
11811200
for param_name, param in sig.parameters.items():
11821201
if param.default is not inspect.Parameter.empty:
1183-
parameter_defaults[param_name] = param.default
1202+
# Convert to inline format: {"type": type, "default": value}
1203+
original_type = parameters[param_name]
1204+
parameters[param_name] = {
1205+
"type": original_type,
1206+
"default": param.default
1207+
}
11841208
except Exception as exc:
11851209
# If we can't extract defaults, continue without them
11861210
pass
@@ -1189,8 +1213,7 @@ def from_callable(
11891213
name=name,
11901214
description=description,
11911215
parameters=parameters,
1192-
implementation=f,
1193-
parameter_defaults=parameter_defaults
1216+
implementation=f
11941217
)
11951218

11961219

@@ -1632,18 +1655,7 @@ def parse_tools(
16321655
else:
16331656
# Handle dictionary-based tool definition
16341657
tool_dict = cast(ToolFunctionDefDict, tool)
1635-
name = cast(str, tool_dict["name"])
1636-
description = cast(str, tool_dict["description"])
1637-
parameters = cast(Mapping[str, Any], tool_dict["parameters"])
1638-
implementation = cast(Callable[..., Any], tool_dict["implementation"])
1639-
parameter_defaults = cast(Mapping[str, Any], tool_dict.get("parameter_defaults", {}))
1640-
tool_def = ToolFunctionDef(
1641-
name=name,
1642-
description=description,
1643-
parameters=parameters,
1644-
implementation=implementation,
1645-
parameter_defaults=parameter_defaults
1646-
)
1658+
tool_def = ToolFunctionDef(**tool_dict)
16471659
if tool_def.name in client_tool_map:
16481660
raise LMStudioValueError(
16491661
f"Duplicate tool names are not permitted ({tool_def.name!r} repeated)"

src/lmstudio/schemas.py

Lines changed: 2 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -66,27 +66,8 @@ def _to_json_schema(cls: type, *, omit: Sequence[str] = ()) -> DictSchema:
6666
named_schema.pop(field, None)
6767
json_schema.update(named_schema)
6868

69-
# Add default values to properties if they exist in the msgspec Struct
70-
if hasattr(cls, "__struct_fields__") and hasattr(cls, "__struct_defaults__"):
71-
properties = json_schema.get("properties", {})
72-
if properties:
73-
# Get ordered field names and default values
74-
field_names = cls.__struct_fields__
75-
default_values = cls.__struct_defaults__
76-
77-
# Map default values to field names by position
78-
# Only fields with defaults will have entries in default_values
79-
default_count = len(default_values)
80-
field_count = len(field_names)
81-
82-
# For kw_only=True structs, default values correspond to the last N fields
83-
# where N is the number of default values
84-
for i, field_name in enumerate(field_names):
85-
if field_name in properties:
86-
# Calculate the index into default_values
87-
default_index = i - (field_count - default_count)
88-
if 0 <= default_index < default_count:
89-
properties[field_name]["default"] = default_values[default_index]
69+
# msgspec automatically handles default values in the generated JSON schema
70+
# when they are properly defined in the Struct fields
9071

9172
return json_schema
9273

tests/test_default_values.py

Lines changed: 98 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import pytest
44
from typing import Any
5+
from msgspec import defstruct
56

67
from src.lmstudio.json_api import ToolFunctionDef, ToolFunctionDefDict
78
from src.lmstudio.schemas import _to_json_schema
@@ -29,113 +30,154 @@ def calculate(expression: str, precision: int = 2) -> str:
2930
precision: Number of decimal places (default: 2)
3031
3132
Returns:
32-
The calculated result
33+
The calculated result as a string
3334
"""
3435
return f"Result: {eval(expression):.{precision}f}"
3536

3637

3738
class TestDefaultValues:
38-
"""Test default parameter value functionality."""
39-
39+
"""Test cases for default parameter values in tool definitions."""
40+
4041
def test_extract_defaults_from_callable(self):
41-
"""Test extracting default values from function signature."""
42+
"""Test extracting default values from a callable function."""
4243
tool_def = ToolFunctionDef.from_callable(greet)
4344

4445
assert tool_def.name == "greet"
45-
assert tool_def.parameter_defaults == {
46-
"greeting": "Hello",
47-
"punctuation": "!"
48-
}
49-
assert "name" not in tool_def.parameter_defaults
46+
# Check that defaults are converted to inline format
47+
assert tool_def.parameters["greeting"] == {"type": str, "default": "Hello"}
48+
assert tool_def.parameters["punctuation"] == {"type": str, "default": "!"}
49+
assert tool_def.parameters["name"] == str # No default, just type
5050

51-
def test_manual_defaults(self):
52-
"""Test manually specifying default values."""
51+
def test_manual_inline_defaults(self):
52+
"""Test manually specifying default values in inline format."""
5353
tool_def = ToolFunctionDef(
5454
name="calculate",
5555
description="Calculate a mathematical expression",
56-
parameters={"expression": str, "precision": int},
57-
implementation=calculate,
58-
parameter_defaults={"precision": 2}
56+
parameters={
57+
"expression": str,
58+
"precision": {"type": int, "default": 2}
59+
},
60+
implementation=calculate
5961
)
6062

61-
assert tool_def.parameter_defaults == {"precision": 2}
62-
assert "expression" not in tool_def.parameter_defaults
63+
# Check that the inline format is preserved
64+
assert tool_def.parameters["precision"] == {"type": int, "default": 2}
65+
assert tool_def.parameters["expression"] == str # No default, just type
6366

6467
def test_json_schema_with_defaults(self):
65-
"""Test that JSON Schema includes default values."""
68+
"""Test that JSON schema includes default values."""
6669
tool_def = ToolFunctionDef.from_callable(greet)
6770
params_struct, _ = tool_def._to_llm_tool_def()
68-
json_schema = _to_json_schema(params_struct)
69-
70-
properties = json_schema["properties"]
71-
72-
# name should not have a default (required parameter)
73-
assert "name" in properties
74-
assert "default" not in properties["name"]
7571

76-
# greeting should have default "Hello"
77-
assert "greeting" in properties
78-
assert properties["greeting"]["default"] == "Hello"
79-
80-
# punctuation should have default "!"
81-
assert "punctuation" in properties
82-
assert properties["punctuation"]["default"] == "!"
72+
json_schema = _to_json_schema(params_struct)
8373

84-
# Only name should be required
85-
assert json_schema["required"] == ["name"]
74+
# Check that default values are included in the schema
75+
assert json_schema["properties"]["greeting"]["default"] == "Hello"
76+
assert json_schema["properties"]["punctuation"]["default"] == "!"
77+
assert "default" not in json_schema["properties"]["name"]
8678

8779
def test_dict_based_definition(self):
88-
"""Test dictionary-based tool definition with defaults."""
80+
"""Test dictionary-based tool definition with inline defaults."""
8981
dict_tool: ToolFunctionDefDict = {
9082
"name": "format_text",
9183
"description": "Format text with specified style",
92-
"parameters": {"text": str, "style": str, "uppercase": bool},
93-
"implementation": lambda text, style="normal", uppercase=False: text.upper() if uppercase else text,
94-
"parameter_defaults": {"style": "normal", "uppercase": False}
84+
"parameters": {
85+
"text": str,
86+
"style": {"type": str, "default": "normal"},
87+
"uppercase": {"type": bool, "default": False}
88+
},
89+
"implementation": lambda text, style="normal", uppercase=False: text.upper() if uppercase else text
9590
}
9691

9792
# This should work without errors
9893
tool_def = ToolFunctionDef(**dict_tool)
99-
assert tool_def.parameter_defaults == {"style": "normal", "uppercase": False}
94+
assert tool_def.parameters["style"] == {"type": str, "default": "normal"}
95+
assert tool_def.parameters["uppercase"] == {"type": bool, "default": False}
96+
assert tool_def.parameters["text"] == str # No default, just type
10097

10198
def test_no_defaults(self):
10299
"""Test function with no default values."""
103100
def no_defaults(a: int, b: str) -> str:
104101
"""Function with no default parameters."""
105-
return f"{a}{b}"
102+
return f"{a}: {b}"
106103

107104
tool_def = ToolFunctionDef.from_callable(no_defaults)
108-
assert tool_def.parameter_defaults == {}
105+
# All parameters should be simple types without defaults
106+
assert tool_def.parameters["a"] == int
107+
assert tool_def.parameters["b"] == str
109108

110109
params_struct, _ = tool_def._to_llm_tool_def()
111110
json_schema = _to_json_schema(params_struct)
112111

113-
# All parameters should be required
114-
assert json_schema["required"] == ["a", "b"]
115-
116-
# No properties should have defaults
117-
for prop in json_schema["properties"].values():
118-
assert "default" not in prop
112+
# No default values should be present
113+
assert "default" not in json_schema["properties"]["a"]
114+
assert "default" not in json_schema["properties"]["b"]
119115

120116
def test_mixed_defaults(self):
121-
"""Test function with some parameters having defaults."""
117+
"""Test function with some parameters having defaults and others not."""
122118
def mixed_defaults(required: str, optional1: int = 42, optional2: bool = True) -> str:
123119
"""Function with mixed required and optional parameters."""
124-
return f"{required}{optional1}{optional2}"
120+
return f"{required}: {optional1}, {optional2}"
125121

126122
tool_def = ToolFunctionDef.from_callable(mixed_defaults)
127-
assert tool_def.parameter_defaults == {
128-
"optional1": 42,
129-
"optional2": True
130-
}
123+
# Check inline format for parameters with defaults
124+
assert tool_def.parameters["optional1"] == {"type": int, "default": 42}
125+
assert tool_def.parameters["optional2"] == {"type": bool, "default": True}
126+
assert tool_def.parameters["required"] == str # No default, just type
131127

132128
params_struct, _ = tool_def._to_llm_tool_def()
133129
json_schema = _to_json_schema(params_struct)
134130

135-
# Only required should be in required list
136-
assert json_schema["required"] == ["required"]
137-
138-
# Check defaults
131+
# Check that default values are correctly included in schema
139132
assert json_schema["properties"]["optional1"]["default"] == 42
140133
assert json_schema["properties"]["optional2"]["default"] is True
141134
assert "default" not in json_schema["properties"]["required"]
135+
136+
def test_extract_type_and_default_method(self):
137+
"""Test the _extract_type_and_default helper method."""
138+
tool_def = ToolFunctionDef(
139+
name="test",
140+
description="Test tool",
141+
parameters={
142+
"simple": str,
143+
"with_default": {"type": int, "default": 42},
144+
"complex_default": {"type": list, "default": [1, 2, 3]}
145+
},
146+
implementation=lambda x, y, z: None
147+
)
148+
149+
# Test simple type
150+
param_type, default = tool_def._extract_type_and_default("simple", str)
151+
assert param_type == str
152+
assert default is None
153+
154+
# Test inline format with default
155+
param_type, default = tool_def._extract_type_and_default("with_default", {"type": int, "default": 42})
156+
assert param_type == int
157+
assert default == 42
158+
159+
# Test complex default
160+
param_type, default = tool_def._extract_type_and_default("complex_default", {"type": list, "default": [1, 2, 3]})
161+
assert param_type == list
162+
assert default == [1, 2, 3]
163+
164+
def test_msgspec_auto_defaults(self):
165+
"""msgspec automatically reflects default values in the JSON schema."""
166+
TestStruct = defstruct(
167+
"TestStruct",
168+
[
169+
("name", str),
170+
("age", int, 18),
171+
("active", bool, True),
172+
],
173+
kw_only=True,
174+
)
175+
176+
schema = _to_json_schema(TestStruct)
177+
properties = schema.get("properties", {})
178+
required = schema.get("required", [])
179+
180+
assert "name" in properties and "default" not in properties["name"]
181+
assert properties["age"].get("default") == 18
182+
assert properties["active"].get("default") is True
183+
assert "name" in required and "age" not in required and "active" not in required

0 commit comments

Comments
 (0)