|
7 | 7 | import controlflow |
8 | 8 | from controlflow.agents.agent import Agent |
9 | 9 | from controlflow.llm.messages import ToolMessage |
10 | | -from controlflow.tools.tools import ( |
11 | | - Tool, |
12 | | - handle_tool_call, |
13 | | - tool, |
14 | | -) |
| 10 | +from controlflow.tools.tools import Tool, handle_tool_call, tool |
15 | 11 |
|
16 | 12 |
|
17 | 13 | @pytest.mark.parametrize("style", ["decorator", "class"]) |
@@ -170,6 +166,77 @@ def add(a: int, b: float) -> float: |
170 | 166 | elif style == "decorator": |
171 | 167 | tool(add) |
172 | 168 |
|
| 169 | + def test_custom_parameters(self, style): |
| 170 | + """Test that custom parameters override generated ones.""" |
| 171 | + |
| 172 | + def add(a: int, b: float): |
| 173 | + return a + b |
| 174 | + |
| 175 | + custom_params = { |
| 176 | + "type": "object", |
| 177 | + "properties": { |
| 178 | + "x": {"type": "number", "description": "Custom parameter"}, |
| 179 | + "y": {"type": "string"}, |
| 180 | + }, |
| 181 | + "required": ["x"], |
| 182 | + } |
| 183 | + |
| 184 | + if style == "class": |
| 185 | + tool_obj = Tool.from_function(add, parameters=custom_params) |
| 186 | + elif style == "decorator": |
| 187 | + tool_obj = tool(add, parameters=custom_params) |
| 188 | + |
| 189 | + assert tool_obj.parameters == custom_params |
| 190 | + assert "a" not in tool_obj.parameters["properties"] |
| 191 | + assert "b" not in tool_obj.parameters["properties"] |
| 192 | + assert ( |
| 193 | + tool_obj.parameters["properties"]["x"]["description"] == "Custom parameter" |
| 194 | + ) |
| 195 | + |
| 196 | + def test_custom_parameters_with_annotations(self, style): |
| 197 | + """Test that annotations still work with custom parameters if param names match.""" |
| 198 | + |
| 199 | + def process(x: Annotated[float, "The x value"], y: str): |
| 200 | + return x |
| 201 | + |
| 202 | + custom_params = { |
| 203 | + "type": "object", |
| 204 | + "properties": {"x": {"type": "number"}, "y": {"type": "string"}}, |
| 205 | + "required": ["x"], |
| 206 | + } |
| 207 | + |
| 208 | + if style == "class": |
| 209 | + tool_obj = Tool.from_function(process, parameters=custom_params) |
| 210 | + elif style == "decorator": |
| 211 | + tool_obj = tool(process, parameters=custom_params) |
| 212 | + |
| 213 | + assert tool_obj.parameters["properties"]["x"]["description"] == "The x value" |
| 214 | + assert "description" not in tool_obj.parameters["properties"]["y"] |
| 215 | + |
| 216 | + def test_custom_parameters_ignore_descriptions(self, style): |
| 217 | + """Test that include_param_descriptions=False works with custom parameters.""" |
| 218 | + |
| 219 | + def process(x: Annotated[float, "The x value"], y: str): |
| 220 | + return x |
| 221 | + |
| 222 | + custom_params = { |
| 223 | + "type": "object", |
| 224 | + "properties": {"x": {"type": "number"}, "y": {"type": "string"}}, |
| 225 | + "required": ["x"], |
| 226 | + } |
| 227 | + |
| 228 | + if style == "class": |
| 229 | + tool_obj = Tool.from_function( |
| 230 | + process, parameters=custom_params, include_param_descriptions=False |
| 231 | + ) |
| 232 | + elif style == "decorator": |
| 233 | + tool_obj = tool( |
| 234 | + process, parameters=custom_params, include_param_descriptions=False |
| 235 | + ) |
| 236 | + |
| 237 | + assert "description" not in tool_obj.parameters["properties"]["x"] |
| 238 | + assert "description" not in tool_obj.parameters["properties"]["y"] |
| 239 | + |
173 | 240 |
|
174 | 241 | class TestToolFunctions: |
175 | 242 | def test_non_serializable_return_value(self): |
|
0 commit comments