|
2 | 2 |
|
3 | 3 | import pytest |
4 | 4 | from typing import Any |
| 5 | +from msgspec import defstruct |
5 | 6 |
|
6 | 7 | from src.lmstudio.json_api import ToolFunctionDef, ToolFunctionDefDict |
7 | 8 | from src.lmstudio.schemas import _to_json_schema |
@@ -29,113 +30,154 @@ def calculate(expression: str, precision: int = 2) -> str: |
29 | 30 | precision: Number of decimal places (default: 2) |
30 | 31 | |
31 | 32 | Returns: |
32 | | - The calculated result |
| 33 | + The calculated result as a string |
33 | 34 | """ |
34 | 35 | return f"Result: {eval(expression):.{precision}f}" |
35 | 36 |
|
36 | 37 |
|
37 | 38 | class TestDefaultValues: |
38 | | - """Test default parameter value functionality.""" |
39 | | - |
| 39 | + """Test cases for default parameter values in tool definitions.""" |
| 40 | + |
40 | 41 | def test_extract_defaults_from_callable(self): |
41 | | - """Test extracting default values from function signature.""" |
| 42 | + """Test extracting default values from a callable function.""" |
42 | 43 | tool_def = ToolFunctionDef.from_callable(greet) |
43 | 44 |
|
44 | 45 | assert tool_def.name == "greet" |
45 | | - assert tool_def.parameter_defaults == { |
46 | | - "greeting": "Hello", |
47 | | - "punctuation": "!" |
48 | | - } |
49 | | - assert "name" not in tool_def.parameter_defaults |
| 46 | + # Check that defaults are converted to inline format |
| 47 | + assert tool_def.parameters["greeting"] == {"type": str, "default": "Hello"} |
| 48 | + assert tool_def.parameters["punctuation"] == {"type": str, "default": "!"} |
| 49 | + assert tool_def.parameters["name"] == str # No default, just type |
50 | 50 |
|
51 | | - def test_manual_defaults(self): |
52 | | - """Test manually specifying default values.""" |
| 51 | + def test_manual_inline_defaults(self): |
| 52 | + """Test manually specifying default values in inline format.""" |
53 | 53 | tool_def = ToolFunctionDef( |
54 | 54 | name="calculate", |
55 | 55 | description="Calculate a mathematical expression", |
56 | | - parameters={"expression": str, "precision": int}, |
57 | | - implementation=calculate, |
58 | | - parameter_defaults={"precision": 2} |
| 56 | + parameters={ |
| 57 | + "expression": str, |
| 58 | + "precision": {"type": int, "default": 2} |
| 59 | + }, |
| 60 | + implementation=calculate |
59 | 61 | ) |
60 | 62 |
|
61 | | - assert tool_def.parameter_defaults == {"precision": 2} |
62 | | - assert "expression" not in tool_def.parameter_defaults |
| 63 | + # Check that the inline format is preserved |
| 64 | + assert tool_def.parameters["precision"] == {"type": int, "default": 2} |
| 65 | + assert tool_def.parameters["expression"] == str # No default, just type |
63 | 66 |
|
64 | 67 | def test_json_schema_with_defaults(self): |
65 | | - """Test that JSON Schema includes default values.""" |
| 68 | + """Test that JSON schema includes default values.""" |
66 | 69 | tool_def = ToolFunctionDef.from_callable(greet) |
67 | 70 | params_struct, _ = tool_def._to_llm_tool_def() |
68 | | - json_schema = _to_json_schema(params_struct) |
69 | | - |
70 | | - properties = json_schema["properties"] |
71 | | - |
72 | | - # name should not have a default (required parameter) |
73 | | - assert "name" in properties |
74 | | - assert "default" not in properties["name"] |
75 | 71 |
|
76 | | - # greeting should have default "Hello" |
77 | | - assert "greeting" in properties |
78 | | - assert properties["greeting"]["default"] == "Hello" |
79 | | - |
80 | | - # punctuation should have default "!" |
81 | | - assert "punctuation" in properties |
82 | | - assert properties["punctuation"]["default"] == "!" |
| 72 | + json_schema = _to_json_schema(params_struct) |
83 | 73 |
|
84 | | - # Only name should be required |
85 | | - assert json_schema["required"] == ["name"] |
| 74 | + # Check that default values are included in the schema |
| 75 | + assert json_schema["properties"]["greeting"]["default"] == "Hello" |
| 76 | + assert json_schema["properties"]["punctuation"]["default"] == "!" |
| 77 | + assert "default" not in json_schema["properties"]["name"] |
86 | 78 |
|
87 | 79 | def test_dict_based_definition(self): |
88 | | - """Test dictionary-based tool definition with defaults.""" |
| 80 | + """Test dictionary-based tool definition with inline defaults.""" |
89 | 81 | dict_tool: ToolFunctionDefDict = { |
90 | 82 | "name": "format_text", |
91 | 83 | "description": "Format text with specified style", |
92 | | - "parameters": {"text": str, "style": str, "uppercase": bool}, |
93 | | - "implementation": lambda text, style="normal", uppercase=False: text.upper() if uppercase else text, |
94 | | - "parameter_defaults": {"style": "normal", "uppercase": False} |
| 84 | + "parameters": { |
| 85 | + "text": str, |
| 86 | + "style": {"type": str, "default": "normal"}, |
| 87 | + "uppercase": {"type": bool, "default": False} |
| 88 | + }, |
| 89 | + "implementation": lambda text, style="normal", uppercase=False: text.upper() if uppercase else text |
95 | 90 | } |
96 | 91 |
|
97 | 92 | # This should work without errors |
98 | 93 | tool_def = ToolFunctionDef(**dict_tool) |
99 | | - assert tool_def.parameter_defaults == {"style": "normal", "uppercase": False} |
| 94 | + assert tool_def.parameters["style"] == {"type": str, "default": "normal"} |
| 95 | + assert tool_def.parameters["uppercase"] == {"type": bool, "default": False} |
| 96 | + assert tool_def.parameters["text"] == str # No default, just type |
100 | 97 |
|
101 | 98 | def test_no_defaults(self): |
102 | 99 | """Test function with no default values.""" |
103 | 100 | def no_defaults(a: int, b: str) -> str: |
104 | 101 | """Function with no default parameters.""" |
105 | | - return f"{a}{b}" |
| 102 | + return f"{a}: {b}" |
106 | 103 |
|
107 | 104 | tool_def = ToolFunctionDef.from_callable(no_defaults) |
108 | | - assert tool_def.parameter_defaults == {} |
| 105 | + # All parameters should be simple types without defaults |
| 106 | + assert tool_def.parameters["a"] == int |
| 107 | + assert tool_def.parameters["b"] == str |
109 | 108 |
|
110 | 109 | params_struct, _ = tool_def._to_llm_tool_def() |
111 | 110 | json_schema = _to_json_schema(params_struct) |
112 | 111 |
|
113 | | - # All parameters should be required |
114 | | - assert json_schema["required"] == ["a", "b"] |
115 | | - |
116 | | - # No properties should have defaults |
117 | | - for prop in json_schema["properties"].values(): |
118 | | - assert "default" not in prop |
| 112 | + # No default values should be present |
| 113 | + assert "default" not in json_schema["properties"]["a"] |
| 114 | + assert "default" not in json_schema["properties"]["b"] |
119 | 115 |
|
120 | 116 | def test_mixed_defaults(self): |
121 | | - """Test function with some parameters having defaults.""" |
| 117 | + """Test function with some parameters having defaults and others not.""" |
122 | 118 | def mixed_defaults(required: str, optional1: int = 42, optional2: bool = True) -> str: |
123 | 119 | """Function with mixed required and optional parameters.""" |
124 | | - return f"{required}{optional1}{optional2}" |
| 120 | + return f"{required}: {optional1}, {optional2}" |
125 | 121 |
|
126 | 122 | tool_def = ToolFunctionDef.from_callable(mixed_defaults) |
127 | | - assert tool_def.parameter_defaults == { |
128 | | - "optional1": 42, |
129 | | - "optional2": True |
130 | | - } |
| 123 | + # Check inline format for parameters with defaults |
| 124 | + assert tool_def.parameters["optional1"] == {"type": int, "default": 42} |
| 125 | + assert tool_def.parameters["optional2"] == {"type": bool, "default": True} |
| 126 | + assert tool_def.parameters["required"] == str # No default, just type |
131 | 127 |
|
132 | 128 | params_struct, _ = tool_def._to_llm_tool_def() |
133 | 129 | json_schema = _to_json_schema(params_struct) |
134 | 130 |
|
135 | | - # Only required should be in required list |
136 | | - assert json_schema["required"] == ["required"] |
137 | | - |
138 | | - # Check defaults |
| 131 | + # Check that default values are correctly included in schema |
139 | 132 | assert json_schema["properties"]["optional1"]["default"] == 42 |
140 | 133 | assert json_schema["properties"]["optional2"]["default"] is True |
141 | 134 | assert "default" not in json_schema["properties"]["required"] |
| 135 | + |
| 136 | + def test_extract_type_and_default_method(self): |
| 137 | + """Test the _extract_type_and_default helper method.""" |
| 138 | + tool_def = ToolFunctionDef( |
| 139 | + name="test", |
| 140 | + description="Test tool", |
| 141 | + parameters={ |
| 142 | + "simple": str, |
| 143 | + "with_default": {"type": int, "default": 42}, |
| 144 | + "complex_default": {"type": list, "default": [1, 2, 3]} |
| 145 | + }, |
| 146 | + implementation=lambda x, y, z: None |
| 147 | + ) |
| 148 | + |
| 149 | + # Test simple type |
| 150 | + param_type, default = tool_def._extract_type_and_default("simple", str) |
| 151 | + assert param_type == str |
| 152 | + assert default is None |
| 153 | + |
| 154 | + # Test inline format with default |
| 155 | + param_type, default = tool_def._extract_type_and_default("with_default", {"type": int, "default": 42}) |
| 156 | + assert param_type == int |
| 157 | + assert default == 42 |
| 158 | + |
| 159 | + # Test complex default |
| 160 | + param_type, default = tool_def._extract_type_and_default("complex_default", {"type": list, "default": [1, 2, 3]}) |
| 161 | + assert param_type == list |
| 162 | + assert default == [1, 2, 3] |
| 163 | + |
| 164 | + def test_msgspec_auto_defaults(self): |
| 165 | + """msgspec automatically reflects default values in the JSON schema.""" |
| 166 | + TestStruct = defstruct( |
| 167 | + "TestStruct", |
| 168 | + [ |
| 169 | + ("name", str), |
| 170 | + ("age", int, 18), |
| 171 | + ("active", bool, True), |
| 172 | + ], |
| 173 | + kw_only=True, |
| 174 | + ) |
| 175 | + |
| 176 | + schema = _to_json_schema(TestStruct) |
| 177 | + properties = schema.get("properties", {}) |
| 178 | + required = schema.get("required", []) |
| 179 | + |
| 180 | + assert "name" in properties and "default" not in properties["name"] |
| 181 | + assert properties["age"].get("default") == 18 |
| 182 | + assert properties["active"].get("default") is True |
| 183 | + assert "name" in required and "age" not in required and "active" not in required |
0 commit comments