Skip to content

Commit 05cf800

Browse files
committed
Cleanups from local PR testing
1 parent 6da4ed4 commit 05cf800

File tree

3 files changed

+109
-99
lines changed

3 files changed

+109
-99
lines changed

src/lmstudio/json_api.py

Lines changed: 33 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
# Native in 3.11+
4242
assert_never,
4343
NoReturn,
44+
NotRequired,
4445
Self,
4546
)
4647

@@ -1090,16 +1091,11 @@ def __init__(
10901091
super().__init__(model_key, params, on_load_progress)
10911092

10921093

1093-
# Add new type definitions for inline parameter format
1094-
from typing import TypeVar
1095-
from typing_extensions import NotRequired
1096-
1097-
T = TypeVar('T')
1098-
10991094
class ToolParamDefDict(TypedDict, Generic[T]):
11001095
type: type[T]
11011096
default: NotRequired[T]
11021097

1098+
11031099
class ToolFunctionDefDict(TypedDict):
11041100
"""SDK input format to specify an LLM tool call and its implementation (as a dict)."""
11051101

@@ -1109,6 +1105,10 @@ class ToolFunctionDefDict(TypedDict):
11091105
implementation: Callable[..., Any]
11101106

11111107

1108+
# Sentinel for parameters with no defined default value
1109+
_NO_DEFAULT = object()
1110+
1111+
11121112
@dataclass(kw_only=True, frozen=True, slots=True)
11131113
class ToolFunctionDef:
11141114
"""SDK input format to specify an LLM tool call and its implementation."""
@@ -1118,30 +1118,35 @@ class ToolFunctionDef:
11181118
parameters: Mapping[str, type[Any] | ToolParamDefDict[Any]]
11191119
implementation: Callable[..., Any]
11201120

1121-
def _extract_type_and_default(self, param_name: str, param_value: type[Any] | ToolParamDefDict[Any]) -> tuple[type[Any], Any | None]:
1121+
@staticmethod
1122+
def _extract_type_and_default(
1123+
param_value: type[Any] | ToolParamDefDict[Any],
1124+
) -> tuple[type[Any], Any]:
11221125
"""Extract type and default value from parameter definition."""
1123-
if isinstance(param_value, dict) and "type" in param_value:
1126+
if isinstance(param_value, dict):
11241127
# Inline format: {"type": type, "default": value}
1125-
param_type = param_value["type"]
1126-
default_value = param_value.get("default")
1128+
param_type = param_value.get("type", None)
1129+
if param_type is None:
1130+
raise TypeError(
1131+
f"Missing 'type' key in tool parameter definition {param_value!r}"
1132+
)
1133+
default_value = param_value.get("default", _NO_DEFAULT)
11271134
return param_type, default_value
11281135
else:
11291136
# Simple format: just the type
1130-
return param_value, None
1137+
return param_value, _NO_DEFAULT
11311138

11321139
def _to_llm_tool_def(self) -> tuple[type[Struct], LlmTool]:
11331140
params_struct_name = f"{self.name.capitalize()}Parameters"
1134-
11351141
# Build fields list with defaults
1136-
fields: list[tuple[str, Any] | tuple[str, Any, Any]] = []
1142+
fields: list[tuple[str, type[Any]] | tuple[str, type[Any], Any]] = []
11371143
for param_name, param_value in self.parameters.items():
1138-
param_type, default_value = self._extract_type_and_default(param_name, param_value)
1139-
1140-
if default_value is not None:
1141-
fields.append((param_name, param_type, default_value))
1142-
else:
1144+
param_type, default_value = self._extract_type_and_default(param_value)
1145+
if default_value is _NO_DEFAULT:
11431146
fields.append((param_name, param_type))
1144-
1147+
else:
1148+
fields.append((param_name, param_type, default_value))
1149+
# Define msgspec struct and API tool definition from the field list
11451150
params_struct = defstruct(params_struct_name, fields, kw_only=True)
11461151
return params_struct, LlmTool._from_api_dict(
11471152
{
@@ -1193,27 +1198,25 @@ def from_callable(
11931198
) from exc
11941199
# Tool definitions only annotate the input parameters, not the return type
11951200
parameters.pop("return", None)
1196-
1201+
11971202
# Extract default values from function signature and convert to inline format
11981203
try:
11991204
sig = inspect.signature(f)
1205+
except Exception:
1206+
# If we can't extract defaults, continue without them
1207+
pass
1208+
else:
12001209
for param_name, param in sig.parameters.items():
12011210
if param.default is not inspect.Parameter.empty:
12021211
# Convert to inline format: {"type": type, "default": value}
12031212
original_type = parameters[param_name]
12041213
parameters[param_name] = {
12051214
"type": original_type,
1206-
"default": param.default
1215+
"default": param.default,
12071216
}
1208-
except Exception as exc:
1209-
# If we can't extract defaults, continue without them
1210-
pass
1211-
1217+
12121218
return cls(
1213-
name=name,
1214-
description=description,
1215-
parameters=parameters,
1216-
implementation=f
1219+
name=name, description=description, parameters=parameters, implementation=f
12171220
)
12181221

12191222

@@ -1654,8 +1657,7 @@ def parse_tools(
16541657
tool_def = ToolFunctionDef.from_callable(tool)
16551658
else:
16561659
# Handle dictionary-based tool definition
1657-
tool_dict = cast(ToolFunctionDefDict, tool)
1658-
tool_def = ToolFunctionDef(**tool_dict)
1660+
tool_def = ToolFunctionDef(**tool)
16591661
if tool_def.name in client_tool_map:
16601662
raise LMStudioValueError(
16611663
f"Duplicate tool names are not permitted ({tool_def.name!r} repeated)"

src/lmstudio/schemas.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,10 +65,10 @@ def _to_json_schema(cls: type, *, omit: Sequence[str] = ()) -> DictSchema:
6565
for field in omit:
6666
named_schema.pop(field, None)
6767
json_schema.update(named_schema)
68-
68+
6969
# msgspec automatically handles default values in the generated JSON schema
7070
# when they are properly defined in the Struct fields
71-
71+
7272
return json_schema
7373

7474

tests/test_default_values.py

Lines changed: 74 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,21 @@
11
"""Tests for default parameter values in tool definitions."""
22

33
import pytest
4-
from typing import Any
4+
55
from msgspec import defstruct
66

7-
from src.lmstudio.json_api import ToolFunctionDef, ToolFunctionDefDict
8-
from src.lmstudio.schemas import _to_json_schema
7+
from lmstudio.json_api import _NO_DEFAULT, ToolFunctionDef, ToolFunctionDefDict
8+
from lmstudio.schemas import _to_json_schema
99

1010

1111
def greet(name: str, greeting: str = "Hello", punctuation: str = "!") -> str:
1212
"""Greet someone with a customizable message.
13-
13+
1414
Args:
1515
name: The name of the person to greet
1616
greeting: The greeting word to use (default: "Hello")
1717
punctuation: The punctuation to end with (default: "!")
18-
18+
1919
Returns:
2020
A greeting message
2121
"""
@@ -24,11 +24,11 @@ def greet(name: str, greeting: str = "Hello", punctuation: str = "!") -> str:
2424

2525
def calculate(expression: str, precision: int = 2) -> str:
2626
"""Calculate a mathematical expression.
27-
27+
2828
Args:
2929
expression: The mathematical expression to evaluate
3030
precision: Number of decimal places (default: 2)
31-
31+
3232
Returns:
3333
The calculated result as a string
3434
"""
@@ -38,130 +38,138 @@ def calculate(expression: str, precision: int = 2) -> str:
3838
class TestDefaultValues:
3939
"""Test cases for default parameter values in tool definitions."""
4040

41-
def test_extract_defaults_from_callable(self):
41+
def test_extract_defaults_from_callable(self) -> None:
4242
"""Test extracting default values from a callable function."""
4343
tool_def = ToolFunctionDef.from_callable(greet)
44-
44+
4545
assert tool_def.name == "greet"
4646
# Check that defaults are converted to inline format
4747
assert tool_def.parameters["greeting"] == {"type": str, "default": "Hello"}
4848
assert tool_def.parameters["punctuation"] == {"type": str, "default": "!"}
49-
assert tool_def.parameters["name"] == str # No default, just type
50-
51-
def test_manual_inline_defaults(self):
49+
assert tool_def.parameters["name"] is str # No default, just type
50+
51+
def test_manual_inline_defaults(self) -> None:
5252
"""Test manually specifying default values in inline format."""
5353
tool_def = ToolFunctionDef(
5454
name="calculate",
5555
description="Calculate a mathematical expression",
56-
parameters={
57-
"expression": str,
58-
"precision": {"type": int, "default": 2}
59-
},
60-
implementation=calculate
56+
parameters={"expression": str, "precision": {"type": int, "default": 2}},
57+
implementation=calculate,
6158
)
62-
59+
6360
# Check that the inline format is preserved
6461
assert tool_def.parameters["precision"] == {"type": int, "default": 2}
65-
assert tool_def.parameters["expression"] == str # No default, just type
66-
67-
def test_json_schema_with_defaults(self):
62+
assert tool_def.parameters["expression"] is str # No default, just type
63+
64+
def test_json_schema_with_defaults(self) -> None:
6865
"""Test that JSON schema includes default values."""
6966
tool_def = ToolFunctionDef.from_callable(greet)
7067
params_struct, _ = tool_def._to_llm_tool_def()
71-
68+
7269
json_schema = _to_json_schema(params_struct)
73-
70+
7471
# Check that default values are included in the schema
7572
assert json_schema["properties"]["greeting"]["default"] == "Hello"
7673
assert json_schema["properties"]["punctuation"]["default"] == "!"
7774
assert "default" not in json_schema["properties"]["name"]
78-
79-
def test_dict_based_definition(self):
75+
76+
def test_dict_based_definition(self) -> None:
8077
"""Test dictionary-based tool definition with inline defaults."""
8178
dict_tool: ToolFunctionDefDict = {
8279
"name": "format_text",
8380
"description": "Format text with specified style",
8481
"parameters": {
8582
"text": str,
8683
"style": {"type": str, "default": "normal"},
87-
"uppercase": {"type": bool, "default": False}
84+
"uppercase": {"type": bool, "default": False},
8885
},
89-
"implementation": lambda text, style="normal", uppercase=False: text.upper() if uppercase else text
86+
"implementation": lambda text, style="normal", uppercase=False: text.upper()
87+
if uppercase
88+
else text,
9089
}
91-
90+
9291
# This should work without errors
9392
tool_def = ToolFunctionDef(**dict_tool)
9493
assert tool_def.parameters["style"] == {"type": str, "default": "normal"}
9594
assert tool_def.parameters["uppercase"] == {"type": bool, "default": False}
96-
assert tool_def.parameters["text"] == str # No default, just type
97-
98-
def test_no_defaults(self):
95+
assert tool_def.parameters["text"] is str # No default, just type
96+
97+
def test_no_defaults(self) -> None:
9998
"""Test function with no default values."""
99+
100100
def no_defaults(a: int, b: str) -> str:
101101
"""Function with no default parameters."""
102102
return f"{a}: {b}"
103-
103+
104104
tool_def = ToolFunctionDef.from_callable(no_defaults)
105105
# All parameters should be simple types without defaults
106-
assert tool_def.parameters["a"] == int
107-
assert tool_def.parameters["b"] == str
108-
106+
assert tool_def.parameters["a"] is int
107+
assert tool_def.parameters["b"] is str
108+
109109
params_struct, _ = tool_def._to_llm_tool_def()
110110
json_schema = _to_json_schema(params_struct)
111-
111+
112112
# No default values should be present
113113
assert "default" not in json_schema["properties"]["a"]
114114
assert "default" not in json_schema["properties"]["b"]
115-
116-
def test_mixed_defaults(self):
115+
116+
def test_mixed_defaults(self) -> None:
117117
"""Test function with some parameters having defaults and others not."""
118-
def mixed_defaults(required: str, optional1: int = 42, optional2: bool = True) -> str:
118+
119+
def mixed_defaults(
120+
required: str, optional1: int = 42, optional2: bool = True
121+
) -> str:
119122
"""Function with mixed required and optional parameters."""
120123
return f"{required}: {optional1}, {optional2}"
121-
124+
122125
tool_def = ToolFunctionDef.from_callable(mixed_defaults)
123126
# Check inline format for parameters with defaults
124127
assert tool_def.parameters["optional1"] == {"type": int, "default": 42}
125128
assert tool_def.parameters["optional2"] == {"type": bool, "default": True}
126-
assert tool_def.parameters["required"] == str # No default, just type
127-
129+
assert tool_def.parameters["required"] is str # No default, just type
130+
128131
params_struct, _ = tool_def._to_llm_tool_def()
129132
json_schema = _to_json_schema(params_struct)
130-
133+
131134
# Check that default values are correctly included in schema
132135
assert json_schema["properties"]["optional1"]["default"] == 42
133136
assert json_schema["properties"]["optional2"]["default"] is True
134137
assert "default" not in json_schema["properties"]["required"]
135-
136-
def test_extract_type_and_default_method(self):
138+
139+
def test_extract_type_and_default_method(self) -> None:
137140
"""Test the _extract_type_and_default helper method."""
138-
tool_def = ToolFunctionDef(
139-
name="test",
140-
description="Test tool",
141-
parameters={
142-
"simple": str,
143-
"with_default": {"type": int, "default": 42},
144-
"complex_default": {"type": list, "default": [1, 2, 3]}
145-
},
146-
implementation=lambda x, y, z: None
147-
)
148-
141+
149142
# Test simple type
150-
param_type, default = tool_def._extract_type_and_default("simple", str)
151-
assert param_type == str
152-
assert default is None
153-
143+
param_type, default = ToolFunctionDef._extract_type_and_default(str)
144+
assert param_type is str
145+
assert default is _NO_DEFAULT
146+
147+
# Test inline format with missing type key
148+
with pytest.raises(TypeError, match="Missing 'type' key"):
149+
param_type, default = ToolFunctionDef._extract_type_and_default(
150+
{"default": 42} # type: ignore[arg-type]
151+
)
152+
153+
# Test inline format with no default
154+
param_type, default = ToolFunctionDef._extract_type_and_default({"type": int})
155+
assert param_type is int
156+
assert default is _NO_DEFAULT
157+
154158
# Test inline format with default
155-
param_type, default = tool_def._extract_type_and_default("with_default", {"type": int, "default": 42})
156-
assert param_type == int
159+
param_type, default = ToolFunctionDef._extract_type_and_default(
160+
{"type": int, "default": 42}
161+
)
162+
assert param_type is int
157163
assert default == 42
158-
164+
159165
# Test complex default
160-
param_type, default = tool_def._extract_type_and_default("complex_default", {"type": list, "default": [1, 2, 3]})
161-
assert param_type == list
166+
param_type, default = ToolFunctionDef._extract_type_and_default(
167+
{"type": list, "default": [1, 2, 3]}
168+
)
169+
assert param_type is list
162170
assert default == [1, 2, 3]
163171

164-
def test_msgspec_auto_defaults(self):
172+
def test_msgspec_auto_defaults(self) -> None:
165173
"""msgspec automatically reflects default values in the JSON schema."""
166174
TestStruct = defstruct(
167175
"TestStruct",

0 commit comments

Comments
 (0)