Skip to content
75 changes: 72 additions & 3 deletions src/lmstudio/json_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@

import asyncio
import copy
import inspect
import json
import sys
import uuid
import warnings

Expand Down Expand Up @@ -40,6 +42,7 @@
# Native in 3.11+
assert_never,
NoReturn,
NotRequired,
Self,
)

Expand Down Expand Up @@ -1089,27 +1092,75 @@ def __init__(
super().__init__(model_key, params, on_load_progress)


if sys.version_info < (3, 11):
# Generic typed dictionaries aren't supported in Python 3.10
# https://github.com/python/cpython/issues/89026
class ToolParamDefDict(TypedDict):
type: type[Any]
default: NotRequired[Any]

ParamDefDict: TypeAlias = ToolParamDefDict
else:

class ToolParamDefDict(TypedDict, Generic[T]):
type: type[T]
default: NotRequired[T]

ParamDefDict: TypeAlias = ToolParamDefDict[Any]


class ToolFunctionDefDict(TypedDict):
"""SDK input format to specify an LLM tool call and its implementation (as a dict)."""

name: str
description: str
parameters: Mapping[str, Any]
parameters: Mapping[str, type[Any] | ParamDefDict]
implementation: Callable[..., Any]


# Sentinel for parameters with no defined default value
_NO_DEFAULT = object()


@dataclass(kw_only=True, frozen=True, slots=True)
class ToolFunctionDef:
"""SDK input format to specify an LLM tool call and its implementation."""

name: str
description: str
parameters: Mapping[str, Any]
parameters: Mapping[str, type[Any] | ParamDefDict]
implementation: Callable[..., Any]

@staticmethod
def _extract_type_and_default(
param_value: type[Any] | ParamDefDict,
) -> tuple[type[Any], Any]:
"""Extract type and default value from parameter definition."""
if isinstance(param_value, dict):
# Inline format: {"type": type, "default": value}
param_type = param_value.get("type", None)
if param_type is None:
raise TypeError(
f"Missing 'type' key in tool parameter definition {param_value!r}"
)
default_value = param_value.get("default", _NO_DEFAULT)
return param_type, default_value
else:
# Simple format: just the type
return param_value, _NO_DEFAULT

def _to_llm_tool_def(self) -> tuple[type[Struct], LlmTool]:
params_struct_name = f"{self.name.capitalize()}Parameters"
params_struct = defstruct(params_struct_name, self.parameters.items())
# Build fields list with defaults
fields: list[tuple[str, type[Any]] | tuple[str, type[Any], Any]] = []
for param_name, param_value in self.parameters.items():
param_type, default_value = self._extract_type_and_default(param_value)
if default_value is _NO_DEFAULT:
fields.append((param_name, param_type))
else:
fields.append((param_name, param_type, default_value))
# Define msgspec struct and API tool definition from the field list
params_struct = defstruct(params_struct_name, fields, kw_only=True)
return params_struct, LlmTool._from_api_dict(
{
"type": "function",
Expand Down Expand Up @@ -1160,6 +1211,23 @@ def from_callable(
) from exc
# Tool definitions only annotate the input parameters, not the return type
parameters.pop("return", None)

# Extract default values from function signature and convert to inline format
try:
sig = inspect.signature(f)
except Exception:
# If we can't extract defaults, continue without them
pass
else:
for param_name, param in sig.parameters.items():
if param.default is not inspect.Parameter.empty:
# Convert to inline format: {"type": type, "default": value}
original_type = parameters[param_name]
parameters[param_name] = {
"type": original_type,
"default": param.default,
}

return cls(
name=name, description=description, parameters=parameters, implementation=f
)
Expand Down Expand Up @@ -1601,6 +1669,7 @@ def parse_tools(
elif callable(tool):
tool_def = ToolFunctionDef.from_callable(tool)
else:
# Handle dictionary-based tool definition
tool_def = ToolFunctionDef(**tool)
if tool_def.name in client_tool_map:
raise LMStudioValueError(
Expand Down
4 changes: 4 additions & 0 deletions src/lmstudio/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,10 @@ def _to_json_schema(cls: type, *, omit: Sequence[str] = ()) -> DictSchema:
for field in omit:
named_schema.pop(field, None)
json_schema.update(named_schema)

# msgspec automatically handles default values in the generated JSON schema
# when they are properly defined in the Struct fields

return json_schema


Expand Down
191 changes: 191 additions & 0 deletions tests/test_default_values.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
"""Tests for default parameter values in tool definitions."""

import pytest

from msgspec import defstruct

from lmstudio.json_api import _NO_DEFAULT, ToolFunctionDef, ToolFunctionDefDict
from lmstudio.schemas import _to_json_schema


def greet(name: str, greeting: str = "Hello", punctuation: str = "!") -> str:
"""Greet someone with a customizable message.

Args:
name: The name of the person to greet
greeting: The greeting word to use (default: "Hello")
punctuation: The punctuation to end with (default: "!")

Returns:
A greeting message
"""
return f"{greeting}, {name}{punctuation}"


def calculate(expression: str, precision: int = 2) -> str:
"""Calculate a mathematical expression.

Args:
expression: The mathematical expression to evaluate
precision: Number of decimal places (default: 2)

Returns:
The calculated result as a string
"""
return f"Result: {eval(expression):.{precision}f}"


class TestDefaultValues:
"""Test cases for default parameter values in tool definitions."""

def test_extract_defaults_from_callable(self) -> None:
"""Test extracting default values from a callable function."""
tool_def = ToolFunctionDef.from_callable(greet)

assert tool_def.name == "greet"
# Check that defaults are converted to inline format
assert tool_def.parameters["greeting"] == {"type": str, "default": "Hello"}
assert tool_def.parameters["punctuation"] == {"type": str, "default": "!"}
assert tool_def.parameters["name"] is str # No default, just type

def test_manual_inline_defaults(self) -> None:
"""Test manually specifying default values in inline format."""
tool_def = ToolFunctionDef(
name="calculate",
description="Calculate a mathematical expression",
parameters={"expression": str, "precision": {"type": int, "default": 2}},
implementation=calculate,
)

# Check that the inline format is preserved
assert tool_def.parameters["precision"] == {"type": int, "default": 2}
assert tool_def.parameters["expression"] is str # No default, just type

def test_json_schema_with_defaults(self) -> None:
"""Test that JSON schema includes default values."""
tool_def = ToolFunctionDef.from_callable(greet)
params_struct, _ = tool_def._to_llm_tool_def()

json_schema = _to_json_schema(params_struct)

# Check that default values are included in the schema
assert json_schema["properties"]["greeting"]["default"] == "Hello"
assert json_schema["properties"]["punctuation"]["default"] == "!"
assert "default" not in json_schema["properties"]["name"]

def test_dict_based_definition(self) -> None:
"""Test dictionary-based tool definition with inline defaults."""
dict_tool: ToolFunctionDefDict = {
"name": "format_text",
"description": "Format text with specified style",
"parameters": {
"text": str,
"style": {"type": str, "default": "normal"},
"uppercase": {"type": bool, "default": False},
},
"implementation": lambda text, style="normal", uppercase=False: text.upper()
if uppercase
else text,
}

# This should work without errors
tool_def = ToolFunctionDef(**dict_tool)
assert tool_def.parameters["style"] == {"type": str, "default": "normal"}
assert tool_def.parameters["uppercase"] == {"type": bool, "default": False}
assert tool_def.parameters["text"] is str # No default, just type

def test_no_defaults(self) -> None:
"""Test function with no default values."""

def no_defaults(a: int, b: str) -> str:
"""Function with no default parameters."""
return f"{a}: {b}"

tool_def = ToolFunctionDef.from_callable(no_defaults)
# All parameters should be simple types without defaults
assert tool_def.parameters["a"] is int
assert tool_def.parameters["b"] is str

params_struct, _ = tool_def._to_llm_tool_def()
json_schema = _to_json_schema(params_struct)

# No default values should be present
assert "default" not in json_schema["properties"]["a"]
assert "default" not in json_schema["properties"]["b"]

def test_mixed_defaults(self) -> None:
"""Test function with some parameters having defaults and others not."""

def mixed_defaults(
required: str, optional1: int = 42, optional2: bool = True
) -> str:
"""Function with mixed required and optional parameters."""
return f"{required}: {optional1}, {optional2}"

tool_def = ToolFunctionDef.from_callable(mixed_defaults)
# Check inline format for parameters with defaults
assert tool_def.parameters["optional1"] == {"type": int, "default": 42}
assert tool_def.parameters["optional2"] == {"type": bool, "default": True}
assert tool_def.parameters["required"] is str # No default, just type

params_struct, _ = tool_def._to_llm_tool_def()
json_schema = _to_json_schema(params_struct)

# Check that default values are correctly included in schema
assert json_schema["properties"]["optional1"]["default"] == 42
assert json_schema["properties"]["optional2"]["default"] is True
assert "default" not in json_schema["properties"]["required"]

def test_extract_type_and_default_method(self) -> None:
"""Test the _extract_type_and_default helper method."""

# Test simple type
param_type, default = ToolFunctionDef._extract_type_and_default(str)
assert param_type is str
assert default is _NO_DEFAULT

# Test inline format with missing type key
with pytest.raises(TypeError, match="Missing 'type' key"):
param_type, default = ToolFunctionDef._extract_type_and_default(
{"default": 42} # type: ignore[arg-type]
)

# Test inline format with no default
param_type, default = ToolFunctionDef._extract_type_and_default({"type": int})
assert param_type is int
assert default is _NO_DEFAULT

# Test inline format with default
param_type, default = ToolFunctionDef._extract_type_and_default(
{"type": int, "default": 42}
)
assert param_type is int
assert default == 42

# Test complex default
param_type, default = ToolFunctionDef._extract_type_and_default(
{"type": list, "default": [1, 2, 3]}
)
assert param_type is list
assert default == [1, 2, 3]

def test_msgspec_auto_defaults(self) -> None:
"""msgspec automatically reflects default values in the JSON schema."""
TestStruct = defstruct(
"TestStruct",
[
("name", str),
("age", int, 18),
("active", bool, True),
],
kw_only=True,
)

schema = _to_json_schema(TestStruct)
properties = schema.get("properties", {})
required = schema.get("required", [])

assert "name" in properties and "default" not in properties["name"]
assert properties["age"].get("default") == 18
assert properties["active"].get("default") is True
assert "name" in required and "age" not in required and "active" not in required