|
| 1 | +"""Tests for default parameter values in tool definitions.""" |
| 2 | + |
| 3 | +import pytest |
| 4 | + |
| 5 | +from msgspec import defstruct |
| 6 | + |
| 7 | +from lmstudio.json_api import _NO_DEFAULT, ToolFunctionDef, ToolFunctionDefDict |
| 8 | +from lmstudio.schemas import _to_json_schema |
| 9 | + |
| 10 | + |
| 11 | +def greet(name: str, greeting: str = "Hello", punctuation: str = "!") -> str: |
| 12 | + """Greet someone with a customizable message. |
| 13 | +
|
| 14 | + Args: |
| 15 | + name: The name of the person to greet |
| 16 | + greeting: The greeting word to use (default: "Hello") |
| 17 | + punctuation: The punctuation to end with (default: "!") |
| 18 | +
|
| 19 | + Returns: |
| 20 | + A greeting message |
| 21 | + """ |
| 22 | + return f"{greeting}, {name}{punctuation}" |
| 23 | + |
| 24 | + |
| 25 | +def calculate(expression: str, precision: int = 2) -> str: |
| 26 | + """Calculate a mathematical expression. |
| 27 | +
|
| 28 | + Args: |
| 29 | + expression: The mathematical expression to evaluate |
| 30 | + precision: Number of decimal places (default: 2) |
| 31 | +
|
| 32 | + Returns: |
| 33 | + The calculated result as a string |
| 34 | + """ |
| 35 | + return f"Result: {eval(expression):.{precision}f}" |
| 36 | + |
| 37 | + |
| 38 | +class TestDefaultValues: |
| 39 | + """Test cases for default parameter values in tool definitions.""" |
| 40 | + |
| 41 | + def test_extract_defaults_from_callable(self) -> None: |
| 42 | + """Test extracting default values from a callable function.""" |
| 43 | + tool_def = ToolFunctionDef.from_callable(greet) |
| 44 | + |
| 45 | + assert tool_def.name == "greet" |
| 46 | + # Check that defaults are converted to inline format |
| 47 | + assert tool_def.parameters["greeting"] == {"type": str, "default": "Hello"} |
| 48 | + assert tool_def.parameters["punctuation"] == {"type": str, "default": "!"} |
| 49 | + assert tool_def.parameters["name"] is str # No default, just type |
| 50 | + |
| 51 | + def test_manual_inline_defaults(self) -> None: |
| 52 | + """Test manually specifying default values in inline format.""" |
| 53 | + tool_def = ToolFunctionDef( |
| 54 | + name="calculate", |
| 55 | + description="Calculate a mathematical expression", |
| 56 | + parameters={"expression": str, "precision": {"type": int, "default": 2}}, |
| 57 | + implementation=calculate, |
| 58 | + ) |
| 59 | + |
| 60 | + # Check that the inline format is preserved |
| 61 | + assert tool_def.parameters["precision"] == {"type": int, "default": 2} |
| 62 | + assert tool_def.parameters["expression"] is str # No default, just type |
| 63 | + |
| 64 | + def test_json_schema_with_defaults(self) -> None: |
| 65 | + """Test that JSON schema includes default values.""" |
| 66 | + tool_def = ToolFunctionDef.from_callable(greet) |
| 67 | + params_struct, _ = tool_def._to_llm_tool_def() |
| 68 | + |
| 69 | + json_schema = _to_json_schema(params_struct) |
| 70 | + |
| 71 | + # Check that default values are included in the schema |
| 72 | + assert json_schema["properties"]["greeting"]["default"] == "Hello" |
| 73 | + assert json_schema["properties"]["punctuation"]["default"] == "!" |
| 74 | + assert "default" not in json_schema["properties"]["name"] |
| 75 | + |
| 76 | + def test_dict_based_definition(self) -> None: |
| 77 | + """Test dictionary-based tool definition with inline defaults.""" |
| 78 | + dict_tool: ToolFunctionDefDict = { |
| 79 | + "name": "format_text", |
| 80 | + "description": "Format text with specified style", |
| 81 | + "parameters": { |
| 82 | + "text": str, |
| 83 | + "style": {"type": str, "default": "normal"}, |
| 84 | + "uppercase": {"type": bool, "default": False}, |
| 85 | + }, |
| 86 | + "implementation": lambda text, style="normal", uppercase=False: text.upper() |
| 87 | + if uppercase |
| 88 | + else text, |
| 89 | + } |
| 90 | + |
| 91 | + # This should work without errors |
| 92 | + tool_def = ToolFunctionDef(**dict_tool) |
| 93 | + assert tool_def.parameters["style"] == {"type": str, "default": "normal"} |
| 94 | + assert tool_def.parameters["uppercase"] == {"type": bool, "default": False} |
| 95 | + assert tool_def.parameters["text"] is str # No default, just type |
| 96 | + |
| 97 | + def test_no_defaults(self) -> None: |
| 98 | + """Test function with no default values.""" |
| 99 | + |
| 100 | + def no_defaults(a: int, b: str) -> str: |
| 101 | + """Function with no default parameters.""" |
| 102 | + return f"{a}: {b}" |
| 103 | + |
| 104 | + tool_def = ToolFunctionDef.from_callable(no_defaults) |
| 105 | + # All parameters should be simple types without defaults |
| 106 | + assert tool_def.parameters["a"] is int |
| 107 | + assert tool_def.parameters["b"] is str |
| 108 | + |
| 109 | + params_struct, _ = tool_def._to_llm_tool_def() |
| 110 | + json_schema = _to_json_schema(params_struct) |
| 111 | + |
| 112 | + # No default values should be present |
| 113 | + assert "default" not in json_schema["properties"]["a"] |
| 114 | + assert "default" not in json_schema["properties"]["b"] |
| 115 | + |
| 116 | + def test_mixed_defaults(self) -> None: |
| 117 | + """Test function with some parameters having defaults and others not.""" |
| 118 | + |
| 119 | + def mixed_defaults( |
| 120 | + required: str, optional1: int = 42, optional2: bool = True |
| 121 | + ) -> str: |
| 122 | + """Function with mixed required and optional parameters.""" |
| 123 | + return f"{required}: {optional1}, {optional2}" |
| 124 | + |
| 125 | + tool_def = ToolFunctionDef.from_callable(mixed_defaults) |
| 126 | + # Check inline format for parameters with defaults |
| 127 | + assert tool_def.parameters["optional1"] == {"type": int, "default": 42} |
| 128 | + assert tool_def.parameters["optional2"] == {"type": bool, "default": True} |
| 129 | + assert tool_def.parameters["required"] is str # No default, just type |
| 130 | + |
| 131 | + params_struct, _ = tool_def._to_llm_tool_def() |
| 132 | + json_schema = _to_json_schema(params_struct) |
| 133 | + |
| 134 | + # Check that default values are correctly included in schema |
| 135 | + assert json_schema["properties"]["optional1"]["default"] == 42 |
| 136 | + assert json_schema["properties"]["optional2"]["default"] is True |
| 137 | + assert "default" not in json_schema["properties"]["required"] |
| 138 | + |
| 139 | + def test_extract_type_and_default_method(self) -> None: |
| 140 | + """Test the _extract_type_and_default helper method.""" |
| 141 | + |
| 142 | + # Test simple type |
| 143 | + param_type, default = ToolFunctionDef._extract_type_and_default(str) |
| 144 | + assert param_type is str |
| 145 | + assert default is _NO_DEFAULT |
| 146 | + |
| 147 | + # Test inline format with missing type key |
| 148 | + with pytest.raises(TypeError, match="Missing 'type' key"): |
| 149 | + param_type, default = ToolFunctionDef._extract_type_and_default( |
| 150 | + {"default": 42} # type: ignore[arg-type] |
| 151 | + ) |
| 152 | + |
| 153 | + # Test inline format with no default |
| 154 | + param_type, default = ToolFunctionDef._extract_type_and_default({"type": int}) |
| 155 | + assert param_type is int |
| 156 | + assert default is _NO_DEFAULT |
| 157 | + |
| 158 | + # Test inline format with default |
| 159 | + param_type, default = ToolFunctionDef._extract_type_and_default( |
| 160 | + {"type": int, "default": 42} |
| 161 | + ) |
| 162 | + assert param_type is int |
| 163 | + assert default == 42 |
| 164 | + |
| 165 | + # Test complex default |
| 166 | + param_type, default = ToolFunctionDef._extract_type_and_default( |
| 167 | + {"type": list, "default": [1, 2, 3]} |
| 168 | + ) |
| 169 | + assert param_type is list |
| 170 | + assert default == [1, 2, 3] |
| 171 | + |
| 172 | + def test_msgspec_auto_defaults(self) -> None: |
| 173 | + """msgspec automatically reflects default values in the JSON schema.""" |
| 174 | + TestStruct = defstruct( |
| 175 | + "TestStruct", |
| 176 | + [ |
| 177 | + ("name", str), |
| 178 | + ("age", int, 18), |
| 179 | + ("active", bool, True), |
| 180 | + ], |
| 181 | + kw_only=True, |
| 182 | + ) |
| 183 | + |
| 184 | + schema = _to_json_schema(TestStruct) |
| 185 | + properties = schema.get("properties", {}) |
| 186 | + required = schema.get("required", []) |
| 187 | + |
| 188 | + assert "name" in properties and "default" not in properties["name"] |
| 189 | + assert properties["age"].get("default") == 18 |
| 190 | + assert properties["active"].get("default") is True |
| 191 | + assert "name" in required and "age" not in required and "active" not in required |
0 commit comments