|
11 | 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 | 12 | # See the License for the specific language governing permissions and
|
13 | 13 | # limitations under the License.
|
| 14 | + |
14 | 15 | import pytest
|
15 | 16 | import pytest_asyncio
|
16 | 17 | from pydantic import ValidationError
|
| 18 | +from inspect import signature, Parameter |
| 19 | +from typing import Optional |
17 | 20 |
|
18 | 21 | from toolbox_core.client import ToolboxClient
|
19 | 22 | from toolbox_core.tool import ToolboxTool
|
@@ -217,3 +220,56 @@ async def test_run_tool_param_auth_no_field(
|
217 | 220 | match="no field named row_data in claims",
|
218 | 221 | ):
|
219 | 222 | await tool()
|
| 223 | + |
| 224 | +@pytest.mark.asyncio |
| 225 | +@pytest.mark.usefixtures("optional_param_server") |
| 226 | +class TestOptionalParams: |
| 227 | + """ |
| 228 | + End-to-end tests for tools with optional parameters. |
| 229 | + """ |
| 230 | + |
| 231 | + async def test_tool_signature_is_correct(self, optional_toolbox: ToolboxClient): |
| 232 | + """Verify the client correctly constructs the signature for a tool with optional params.""" |
| 233 | + tool = await optional_toolbox.load_tool("search-rows") |
| 234 | + sig = signature(tool) |
| 235 | + |
| 236 | + assert "query" in sig.parameters |
| 237 | + assert "limit" in sig.parameters |
| 238 | + |
| 239 | + # The required parameter should have no default |
| 240 | + assert sig.parameters["query"].default is Parameter.empty |
| 241 | + assert sig.parameters["query"].annotation is str |
| 242 | + |
| 243 | + # The optional parameter should have a default of None |
| 244 | + assert sig.parameters["limit"].default is None |
| 245 | + assert sig.parameters["limit"].annotation is Optional[int] |
| 246 | + |
| 247 | + async def test_run_tool_with_optional_param_omitted( |
| 248 | + self, optional_toolbox: ToolboxClient |
| 249 | + ): |
| 250 | + """Invoke a tool providing only the required parameter.""" |
| 251 | + tool = await optional_toolbox.load_tool("search-rows") |
| 252 | + |
| 253 | + response = await tool(query="test query") |
| 254 | + assert isinstance(response, str) |
| 255 | + assert 'query="test query"' in response |
| 256 | + assert "limit" not in response |
| 257 | + |
| 258 | + async def test_run_tool_with_optional_param_provided( |
| 259 | + self, optional_toolbox: ToolboxClient |
| 260 | + ): |
| 261 | + """Invoke a tool providing both required and optional parameters.""" |
| 262 | + tool = await optional_toolbox.load_tool("search-rows") |
| 263 | + |
| 264 | + response = await tool(query="test query", limit=10) |
| 265 | + assert isinstance(response, str) |
| 266 | + assert 'query="test query"' in response |
| 267 | + assert "limit=10" in response |
| 268 | + |
| 269 | + async def test_run_tool_with_missing_required_param( |
| 270 | + self, optional_toolbox: ToolboxClient |
| 271 | + ): |
| 272 | + """Invoke a tool without its required parameter.""" |
| 273 | + tool = await optional_toolbox.load_tool("search-rows") |
| 274 | + with pytest.raises(TypeError, match="missing a required argument: 'query'"): |
| 275 | + await tool(limit=5) |
0 commit comments