Skip to content
64 changes: 59 additions & 5 deletions src/lmstudio/json_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import asyncio
import copy
import inspect
import json
import uuid
import warnings
Expand Down Expand Up @@ -1089,12 +1090,22 @@ def __init__(
super().__init__(model_key, params, on_load_progress)


# Add new type definitions for inline parameter format
from typing import TypeVar
from typing_extensions import NotRequired

T = TypeVar('T')

class ToolParamDefDict(TypedDict, Generic[T]):
type: type[T]
default: NotRequired[T]

class ToolFunctionDefDict(TypedDict):
"""SDK input format to specify an LLM tool call and its implementation (as a dict)."""

name: str
description: str
parameters: Mapping[str, Any]
parameters: Mapping[str, type[Any] | ToolParamDefDict[Any]]
implementation: Callable[..., Any]


Expand All @@ -1104,12 +1115,34 @@ class ToolFunctionDef:

name: str
description: str
parameters: Mapping[str, Any]
parameters: Mapping[str, type[Any] | ToolParamDefDict[Any]]
implementation: Callable[..., Any]

def _extract_type_and_default(self, param_name: str, param_value: type[Any] | ToolParamDefDict[Any]) -> tuple[type[Any], Any | None]:
"""Extract type and default value from parameter definition."""
if isinstance(param_value, dict) and "type" in param_value:
# Inline format: {"type": type, "default": value}
param_type = param_value["type"]
default_value = param_value.get("default")
return param_type, default_value
else:
# Simple format: just the type
return param_value, None

def _to_llm_tool_def(self) -> tuple[type[Struct], LlmTool]:
params_struct_name = f"{self.name.capitalize()}Parameters"
params_struct = defstruct(params_struct_name, self.parameters.items())

# Build fields list with defaults
fields: list[tuple[str, Any] | tuple[str, Any, Any]] = []
for param_name, param_value in self.parameters.items():
param_type, default_value = self._extract_type_and_default(param_name, param_value)

if default_value is not None:
fields.append((param_name, param_type, default_value))
else:
fields.append((param_name, param_type))

params_struct = defstruct(params_struct_name, fields, kw_only=True)
return params_struct, LlmTool._from_api_dict(
{
"type": "function",
Expand Down Expand Up @@ -1160,8 +1193,27 @@ def from_callable(
) from exc
# Tool definitions only annotate the input parameters, not the return type
parameters.pop("return", None)

# Extract default values from function signature and convert to inline format
try:
sig = inspect.signature(f)
for param_name, param in sig.parameters.items():
if param.default is not inspect.Parameter.empty:
# Convert to inline format: {"type": type, "default": value}
original_type = parameters[param_name]
parameters[param_name] = {
"type": original_type,
"default": param.default
}
except Exception as exc:
# If we can't extract defaults, continue without them
pass

return cls(
name=name, description=description, parameters=parameters, implementation=f
name=name,
description=description,
parameters=parameters,
implementation=f
)


Expand Down Expand Up @@ -1601,7 +1653,9 @@ def parse_tools(
elif callable(tool):
tool_def = ToolFunctionDef.from_callable(tool)
else:
tool_def = ToolFunctionDef(**tool)
# Handle dictionary-based tool definition
tool_dict = cast(ToolFunctionDefDict, tool)
tool_def = ToolFunctionDef(**tool_dict)
if tool_def.name in client_tool_map:
raise LMStudioValueError(
f"Duplicate tool names are not permitted ({tool_def.name!r} repeated)"
Expand Down
4 changes: 4 additions & 0 deletions src/lmstudio/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,10 @@ def _to_json_schema(cls: type, *, omit: Sequence[str] = ()) -> DictSchema:
for field in omit:
named_schema.pop(field, None)
json_schema.update(named_schema)

# msgspec automatically handles default values in the generated JSON schema
# when they are properly defined in the Struct fields

return json_schema


Expand Down
183 changes: 183 additions & 0 deletions tests/test_default_values.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
"""Tests for default parameter values in tool definitions."""

import pytest
from typing import Any
from msgspec import defstruct

from src.lmstudio.json_api import ToolFunctionDef, ToolFunctionDefDict
from src.lmstudio.schemas import _to_json_schema


def greet(name: str, greeting: str = "Hello", punctuation: str = "!") -> str:
"""Greet someone with a customizable message.

Args:
name: The name of the person to greet
greeting: The greeting word to use (default: "Hello")
punctuation: The punctuation to end with (default: "!")

Returns:
A greeting message
"""
return f"{greeting}, {name}{punctuation}"


def calculate(expression: str, precision: int = 2) -> str:
"""Calculate a mathematical expression.

Args:
expression: The mathematical expression to evaluate
precision: Number of decimal places (default: 2)

Returns:
The calculated result as a string
"""
return f"Result: {eval(expression):.{precision}f}"


class TestDefaultValues:
"""Test cases for default parameter values in tool definitions."""

def test_extract_defaults_from_callable(self):
"""Test extracting default values from a callable function."""
tool_def = ToolFunctionDef.from_callable(greet)

assert tool_def.name == "greet"
# Check that defaults are converted to inline format
assert tool_def.parameters["greeting"] == {"type": str, "default": "Hello"}
assert tool_def.parameters["punctuation"] == {"type": str, "default": "!"}
assert tool_def.parameters["name"] == str # No default, just type

def test_manual_inline_defaults(self):
"""Test manually specifying default values in inline format."""
tool_def = ToolFunctionDef(
name="calculate",
description="Calculate a mathematical expression",
parameters={
"expression": str,
"precision": {"type": int, "default": 2}
},
implementation=calculate
)

# Check that the inline format is preserved
assert tool_def.parameters["precision"] == {"type": int, "default": 2}
assert tool_def.parameters["expression"] == str # No default, just type

def test_json_schema_with_defaults(self):
"""Test that JSON schema includes default values."""
tool_def = ToolFunctionDef.from_callable(greet)
params_struct, _ = tool_def._to_llm_tool_def()

json_schema = _to_json_schema(params_struct)

# Check that default values are included in the schema
assert json_schema["properties"]["greeting"]["default"] == "Hello"
assert json_schema["properties"]["punctuation"]["default"] == "!"
assert "default" not in json_schema["properties"]["name"]

def test_dict_based_definition(self):
"""Test dictionary-based tool definition with inline defaults."""
dict_tool: ToolFunctionDefDict = {
"name": "format_text",
"description": "Format text with specified style",
"parameters": {
"text": str,
"style": {"type": str, "default": "normal"},
"uppercase": {"type": bool, "default": False}
},
"implementation": lambda text, style="normal", uppercase=False: text.upper() if uppercase else text
}

# This should work without errors
tool_def = ToolFunctionDef(**dict_tool)
assert tool_def.parameters["style"] == {"type": str, "default": "normal"}
assert tool_def.parameters["uppercase"] == {"type": bool, "default": False}
assert tool_def.parameters["text"] == str # No default, just type

def test_no_defaults(self):
"""Test function with no default values."""
def no_defaults(a: int, b: str) -> str:
"""Function with no default parameters."""
return f"{a}: {b}"

tool_def = ToolFunctionDef.from_callable(no_defaults)
# All parameters should be simple types without defaults
assert tool_def.parameters["a"] == int
assert tool_def.parameters["b"] == str

params_struct, _ = tool_def._to_llm_tool_def()
json_schema = _to_json_schema(params_struct)

# No default values should be present
assert "default" not in json_schema["properties"]["a"]
assert "default" not in json_schema["properties"]["b"]

def test_mixed_defaults(self):
"""Test function with some parameters having defaults and others not."""
def mixed_defaults(required: str, optional1: int = 42, optional2: bool = True) -> str:
"""Function with mixed required and optional parameters."""
return f"{required}: {optional1}, {optional2}"

tool_def = ToolFunctionDef.from_callable(mixed_defaults)
# Check inline format for parameters with defaults
assert tool_def.parameters["optional1"] == {"type": int, "default": 42}
assert tool_def.parameters["optional2"] == {"type": bool, "default": True}
assert tool_def.parameters["required"] == str # No default, just type

params_struct, _ = tool_def._to_llm_tool_def()
json_schema = _to_json_schema(params_struct)

# Check that default values are correctly included in schema
assert json_schema["properties"]["optional1"]["default"] == 42
assert json_schema["properties"]["optional2"]["default"] is True
assert "default" not in json_schema["properties"]["required"]

def test_extract_type_and_default_method(self):
"""Test the _extract_type_and_default helper method."""
tool_def = ToolFunctionDef(
name="test",
description="Test tool",
parameters={
"simple": str,
"with_default": {"type": int, "default": 42},
"complex_default": {"type": list, "default": [1, 2, 3]}
},
implementation=lambda x, y, z: None
)

# Test simple type
param_type, default = tool_def._extract_type_and_default("simple", str)
assert param_type == str
assert default is None

# Test inline format with default
param_type, default = tool_def._extract_type_and_default("with_default", {"type": int, "default": 42})
assert param_type == int
assert default == 42

# Test complex default
param_type, default = tool_def._extract_type_and_default("complex_default", {"type": list, "default": [1, 2, 3]})
assert param_type == list
assert default == [1, 2, 3]

def test_msgspec_auto_defaults(self):
"""msgspec automatically reflects default values in the JSON schema."""
TestStruct = defstruct(
"TestStruct",
[
("name", str),
("age", int, 18),
("active", bool, True),
],
kw_only=True,
)

schema = _to_json_schema(TestStruct)
properties = schema.get("properties", {})
required = schema.get("required", [])

assert "name" in properties and "default" not in properties["name"]
assert properties["age"].get("default") == 18
assert properties["active"].get("default") is True
assert "name" in required and "age" not in required and "active" not in required