diff --git a/packages/toolbox-core/integration.cloudbuild.yaml b/packages/toolbox-core/integration.cloudbuild.yaml index 61fd6afd..89132be9 100644 --- a/packages/toolbox-core/integration.cloudbuild.yaml +++ b/packages/toolbox-core/integration.cloudbuild.yaml @@ -43,4 +43,4 @@ options: logging: CLOUD_LOGGING_ONLY substitutions: _VERSION: '3.13' - _TOOLBOX_VERSION: '0.7.0' + _TOOLBOX_VERSION: '0.8.0' diff --git a/packages/toolbox-core/src/toolbox_core/protocol.py b/packages/toolbox-core/src/toolbox_core/protocol.py index e2071ba2..6606ef93 100644 --- a/packages/toolbox-core/src/toolbox_core/protocol.py +++ b/packages/toolbox-core/src/toolbox_core/protocol.py @@ -25,31 +25,39 @@ class ParameterSchema(BaseModel): name: str type: str + required: bool = True description: str authSources: Optional[list[str]] = None items: Optional["ParameterSchema"] = None def __get_type(self) -> Type: + base_type: Type if self.type == "string": - return str + base_type = str elif self.type == "integer": - return int + base_type = int elif self.type == "float": - return float + base_type = float elif self.type == "boolean": - return bool + base_type = bool elif self.type == "array": if self.items is None: raise Exception("Unexpected value: type is 'list' but items is None") - return list[self.items.__get_type()] # type: ignore + base_type = list[self.items.__get_type()] # type: ignore + else: + raise ValueError(f"Unsupported schema type: {self.type}") - raise ValueError(f"Unsupported schema type: {self.type}") + if not self.required: + return Optional[base_type] # type: ignore + + return base_type def to_param(self) -> Parameter: return Parameter( self.name, Parameter.POSITIONAL_OR_KEYWORD, annotation=self.__get_type(), + default=Parameter.empty if self.required else None, ) diff --git a/packages/toolbox-core/src/toolbox_core/tool.py b/packages/toolbox-core/src/toolbox_core/tool.py index 48a31dab..d9ec2883 100644 --- a/packages/toolbox-core/src/toolbox_core/tool.py +++ b/packages/toolbox-core/src/toolbox_core/tool.py @@ -13,6 +13,8 @@ # limitations under the License. import copy +import itertools +from collections import OrderedDict from inspect import Signature from types import MappingProxyType from typing import Any, Awaitable, Callable, Mapping, Optional, Sequence, Union @@ -89,7 +91,13 @@ def __init__( self.__params = params self.__pydantic_model = params_to_pydantic_model(name, self.__params) - inspect_type_params = [param.to_param() for param in self.__params] + # Separate parameters into required (no default) and optional (with + # default) to prevent the "non-default argument follows default + # argument" error when creating the function signature. + required_params = (p for p in self.__params if p.required) + optional_params = (p for p in self.__params if not p.required) + ordered_params = itertools.chain(required_params, optional_params) + inspect_type_params = [param.to_param() for param in ordered_params] # the following properties are set to help anyone that might inspect it determine usage self.__name__ = name @@ -268,7 +276,9 @@ async def __call__(self, *args: Any, **kwargs: Any) -> str: # validate inputs to this call using the signature all_args = self.__signature__.bind(*args, **kwargs) - all_args.apply_defaults() # Include default values if not provided + + # The payload will only contain arguments explicitly provided by the user. + # Optional arguments not provided by the user will not be in the payload. payload = all_args.arguments # Perform argument type validations using pydantic @@ -278,6 +288,11 @@ async def __call__(self, *args: Any, **kwargs: Any) -> str: for param, value in self.__bound_parameters.items(): payload[param] = await resolve_value(value) + # Remove None values to prevent server-side type errors. The Toolbox + # server requires specific types for each parameter and will raise an + # error if it receives a None value, which it cannot convert. + payload = OrderedDict({k: v for k, v in payload.items() if v is not None}) + # create headers for auth services headers = {} for client_header_name, client_header_val in self.__client_headers.items(): diff --git a/packages/toolbox-core/src/toolbox_core/utils.py b/packages/toolbox-core/src/toolbox_core/utils.py index 615a23ec..08a87a45 100644 --- a/packages/toolbox-core/src/toolbox_core/utils.py +++ b/packages/toolbox-core/src/toolbox_core/utils.py @@ -38,9 +38,8 @@ def create_func_docstring(description: str, params: Sequence[ParameterSchema]) - return docstring docstring += "\n\nArgs:" for p in params: - docstring += ( - f"\n {p.name} ({p.to_param().annotation.__name__}): {p.description}" - ) + annotation = p.to_param().annotation + docstring += f"\n {p.name} ({getattr(annotation, '__name__', str(annotation))}): {p.description}" return docstring @@ -111,11 +110,20 @@ def params_to_pydantic_model( """Converts the given parameters to a Pydantic BaseModel class.""" field_definitions = {} for field in params: + + # Determine the default value based on the 'required' flag. + # '...' (Ellipsis) signifies a required field in Pydantic. + # 'None' makes the field optional with a default value of None. + default_value = ... if field.required else None + field_definitions[field.name] = cast( Any, ( field.to_param().annotation, - Field(description=field.description), + Field( + description=field.description, + default=default_value, + ), ), ) return create_model(tool_name, **field_definitions) diff --git a/packages/toolbox-core/tests/test_e2e.py b/packages/toolbox-core/tests/test_e2e.py index 8920bc3b..52f0ba56 100644 --- a/packages/toolbox-core/tests/test_e2e.py +++ b/packages/toolbox-core/tests/test_e2e.py @@ -11,6 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from inspect import Parameter, signature +from typing import Optional + import pytest import pytest_asyncio from pydantic import ValidationError @@ -64,7 +68,7 @@ async def test_load_toolset_specific( async def test_load_toolset_default(self, toolbox: ToolboxClient): """Load the default toolset, i.e. all tools.""" toolset = await toolbox.load_toolset() - assert len(toolset) == 5 + assert len(toolset) == 6 tool_names = {tool.__name__ for tool in toolset} expected_tools = [ "get-row-by-content-auth", @@ -72,6 +76,7 @@ async def test_load_toolset_default(self, toolbox: ToolboxClient): "get-row-by-id-auth", "get-row-by-id", "get-n-rows", + "search-rows", ] assert tool_names == set(expected_tools) @@ -217,3 +222,160 @@ async def test_run_tool_param_auth_no_field( match="no field named row_data in claims", ): await tool() + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("toolbox_server") +class TestOptionalParams: + """ + End-to-end tests for tools with optional parameters. + """ + + async def test_tool_signature_is_correct(self, toolbox: ToolboxClient): + """Verify the client correctly constructs the signature for a tool with optional params.""" + tool = await toolbox.load_tool("search-rows") + sig = signature(tool) + + assert "email" in sig.parameters + assert "data" in sig.parameters + assert "id" in sig.parameters + + # The required parameter should have no default + assert sig.parameters["email"].default is Parameter.empty + assert sig.parameters["email"].annotation is str + + # The optional parameter should have a default of None + assert sig.parameters["data"].default is None + assert sig.parameters["data"].annotation is Optional[str] + + # The optional parameter should have a default of None + assert sig.parameters["id"].default is None + assert sig.parameters["id"].annotation is Optional[int] + + async def test_run_tool_with_optional_params_omitted(self, toolbox: ToolboxClient): + """Invoke a tool providing only the required parameter.""" + tool = await toolbox.load_tool("search-rows") + + response = await tool(email="twishabansal@google.com") + assert isinstance(response, str) + assert '"email":"twishabansal@google.com"' in response + assert "row1" not in response + assert "row2" in response + assert "row3" not in response + assert "row4" not in response + assert "row5" not in response + assert "row6" not in response + + async def test_run_tool_with_optional_data_provided(self, toolbox: ToolboxClient): + """Invoke a tool providing both required and optional parameters.""" + tool = await toolbox.load_tool("search-rows") + + response = await tool(email="twishabansal@google.com", data="row3") + assert isinstance(response, str) + assert '"email":"twishabansal@google.com"' in response + assert "row1" not in response + assert "row2" not in response + assert "row3" in response + assert "row4" not in response + assert "row5" not in response + assert "row6" not in response + + async def test_run_tool_with_optional_data_null(self, toolbox: ToolboxClient): + """Invoke a tool providing both required and optional parameters.""" + tool = await toolbox.load_tool("search-rows") + + response = await tool(email="twishabansal@google.com", data=None) + assert isinstance(response, str) + assert '"email":"twishabansal@google.com"' in response + assert "row1" not in response + assert "row2" in response + assert "row3" not in response + assert "row4" not in response + assert "row5" not in response + assert "row6" not in response + + async def test_run_tool_with_optional_id_provided(self, toolbox: ToolboxClient): + """Invoke a tool providing both required and optional parameters.""" + tool = await toolbox.load_tool("search-rows") + + response = await tool(email="twishabansal@google.com", id=1) + assert isinstance(response, str) + assert response == "null" + + async def test_run_tool_with_optional_id_null(self, toolbox: ToolboxClient): + """Invoke a tool providing both required and optional parameters.""" + tool = await toolbox.load_tool("search-rows") + + response = await tool(email="twishabansal@google.com", id=None) + assert isinstance(response, str) + assert '"email":"twishabansal@google.com"' in response + assert "row1" not in response + assert "row2" in response + assert "row3" not in response + assert "row4" not in response + assert "row5" not in response + assert "row6" not in response + + async def test_run_tool_with_missing_required_param(self, toolbox: ToolboxClient): + """Invoke a tool without its required parameter.""" + tool = await toolbox.load_tool("search-rows") + with pytest.raises(TypeError, match="missing a required argument: 'email'"): + await tool(id=5, data="row5") + + async def test_run_tool_with_required_param_null(self, toolbox: ToolboxClient): + """Invoke a tool without its required parameter.""" + tool = await toolbox.load_tool("search-rows") + with pytest.raises(ValidationError, match="email"): + await tool(email=None, id=5, data="row5") + + async def test_run_tool_with_all_default_params(self, toolbox: ToolboxClient): + """Invoke a tool providing all parameters.""" + tool = await toolbox.load_tool("search-rows") + + response = await tool(email="twishabansal@google.com", id=0, data="row2") + assert isinstance(response, str) + assert '"email":"twishabansal@google.com"' in response + assert "row1" not in response + assert "row2" in response + assert "row3" not in response + assert "row4" not in response + assert "row5" not in response + assert "row6" not in response + + async def test_run_tool_with_all_valid_params(self, toolbox: ToolboxClient): + """Invoke a tool providing all parameters.""" + tool = await toolbox.load_tool("search-rows") + + response = await tool(email="twishabansal@google.com", id=3, data="row3") + assert isinstance(response, str) + assert '"email":"twishabansal@google.com"' in response + assert "row1" not in response + assert "row2" not in response + assert "row3" in response + assert "row4" not in response + assert "row5" not in response + assert "row6" not in response + + async def test_run_tool_with_different_email(self, toolbox: ToolboxClient): + """Invoke a tool providing all parameters but with a different email.""" + tool = await toolbox.load_tool("search-rows") + + response = await tool(email="anubhavdhawan@google.com", id=3, data="row3") + assert isinstance(response, str) + assert response == "null" + + async def test_run_tool_with_different_data(self, toolbox: ToolboxClient): + """Invoke a tool providing all parameters but with a different data.""" + tool = await toolbox.load_tool("search-rows") + + response = await tool(email="twishabansal@google.com", id=3, data="row4") + assert isinstance(response, str) + assert response == "null" + + async def test_run_tool_with_different_id(self, toolbox: ToolboxClient): + """Invoke a tool providing all parameters but with a different data.""" + tool = await toolbox.load_tool("search-rows") + + response = await tool(email="twishabansal@google.com", id=4, data="row3") + assert isinstance(response, str) + assert response == "null" diff --git a/packages/toolbox-core/tests/test_protocol.py b/packages/toolbox-core/tests/test_protocol.py index a70fa3fe..b7650792 100644 --- a/packages/toolbox-core/tests/test_protocol.py +++ b/packages/toolbox-core/tests/test_protocol.py @@ -14,6 +14,7 @@ from inspect import Parameter +from typing import Optional import pytest @@ -106,3 +107,66 @@ def test_parameter_schema_unsupported_type_error(): with pytest.raises(ValueError, match=expected_error_msg): schema.to_param() + + +def test_parameter_schema_string_optional(): + """Tests an optional ParameterSchema with type 'string'.""" + schema = ParameterSchema( + name="nickname", + type="string", + description="An optional nickname", + required=False, + ) + expected_type = Optional[str] + + # Test __get_type() + assert schema._ParameterSchema__get_type() == expected_type + + # Test to_param() + param = schema.to_param() + assert isinstance(param, Parameter) + assert param.name == "nickname" + assert param.annotation == expected_type + assert param.kind == Parameter.POSITIONAL_OR_KEYWORD + assert param.default is None + + +def test_parameter_schema_required_by_default(): + """Tests that a parameter is required by default.""" + # 'required' is not specified, so it should default to True. + schema = ParameterSchema(name="id", type="integer", description="A required ID") + expected_type = int + + # Test __get_type() + assert schema._ParameterSchema__get_type() == expected_type + + # Test to_param() + param = schema.to_param() + assert isinstance(param, Parameter) + assert param.name == "id" + assert param.annotation == expected_type + assert param.default == Parameter.empty + + +def test_parameter_schema_array_optional(): + """Tests an optional ParameterSchema with type 'array'.""" + item_schema = ParameterSchema(name="", type="integer", description="") + schema = ParameterSchema( + name="optional_scores", + type="array", + description="An optional list of scores", + items=item_schema, + required=False, + ) + expected_type = Optional[list[int]] + + # Test __get_type() + assert schema._ParameterSchema__get_type() == expected_type + + # Test to_param() + param = schema.to_param() + assert isinstance(param, Parameter) + assert param.name == "optional_scores" + assert param.annotation == expected_type + assert param.kind == Parameter.POSITIONAL_OR_KEYWORD + assert param.default is None diff --git a/packages/toolbox-core/tests/test_utils.py b/packages/toolbox-core/tests/test_utils.py index c07f44cb..b3ddd7c3 100644 --- a/packages/toolbox-core/tests/test_utils.py +++ b/packages/toolbox-core/tests/test_utils.py @@ -34,6 +34,7 @@ def create_param_mock(name: str, description: str, annotation: Type) -> Mock: param_mock = Mock(spec=ParameterSchema) param_mock.name = name param_mock.description = description + param_mock.required = True mock_param_info = Mock() mock_param_info.annotation = annotation diff --git a/packages/toolbox-langchain/integration.cloudbuild.yaml b/packages/toolbox-langchain/integration.cloudbuild.yaml index b5ae5510..0deb3a94 100644 --- a/packages/toolbox-langchain/integration.cloudbuild.yaml +++ b/packages/toolbox-langchain/integration.cloudbuild.yaml @@ -44,4 +44,4 @@ options: logging: CLOUD_LOGGING_ONLY substitutions: _VERSION: '3.13' - _TOOLBOX_VERSION: '0.7.0' + _TOOLBOX_VERSION: '0.8.0' diff --git a/packages/toolbox-langchain/tests/test_e2e.py b/packages/toolbox-langchain/tests/test_e2e.py index 12002717..64371ead 100644 --- a/packages/toolbox-langchain/tests/test_e2e.py +++ b/packages/toolbox-langchain/tests/test_e2e.py @@ -75,13 +75,14 @@ async def test_aload_toolset_specific( async def test_aload_toolset_all(self, toolbox): toolset = await toolbox.aload_toolset() - assert len(toolset) == 5 + assert len(toolset) == 6 tool_names = [ "get-n-rows", "get-row-by-id", "get-row-by-id-auth", "get-row-by-email-auth", "get-row-by-content-auth", + "search-rows", ] for tool in toolset: name = tool._ToolboxTool__core_tool.__name__ @@ -220,13 +221,14 @@ def test_load_toolset_specific( def test_aload_toolset_all(self, toolbox): toolset = toolbox.load_toolset() - assert len(toolset) == 5 + assert len(toolset) == 6 tool_names = [ "get-n-rows", "get-row-by-id", "get-row-by-id-auth", "get-row-by-email-auth", "get-row-by-content-auth", + "search-rows", ] for tool in toolset: name = tool._ToolboxTool__core_tool.__name__ diff --git a/packages/toolbox-llamaindex/integration.cloudbuild.yaml b/packages/toolbox-llamaindex/integration.cloudbuild.yaml index 6e15f8a1..9b0b4e5d 100644 --- a/packages/toolbox-llamaindex/integration.cloudbuild.yaml +++ b/packages/toolbox-llamaindex/integration.cloudbuild.yaml @@ -44,4 +44,4 @@ options: logging: CLOUD_LOGGING_ONLY substitutions: _VERSION: '3.13' - _TOOLBOX_VERSION: '0.7.0' + _TOOLBOX_VERSION: '0.8.0' diff --git a/packages/toolbox-llamaindex/tests/test_e2e.py b/packages/toolbox-llamaindex/tests/test_e2e.py index 5f389b86..580059a7 100644 --- a/packages/toolbox-llamaindex/tests/test_e2e.py +++ b/packages/toolbox-llamaindex/tests/test_e2e.py @@ -75,13 +75,14 @@ async def test_aload_toolset_specific( async def test_aload_toolset_all(self, toolbox): toolset = await toolbox.aload_toolset() - assert len(toolset) == 5 + assert len(toolset) == 6 tool_names = [ "get-n-rows", "get-row-by-id", "get-row-by-id-auth", "get-row-by-email-auth", "get-row-by-content-auth", + "search-rows", ] for tool in toolset: name = tool._ToolboxTool__core_tool.__name__ @@ -220,13 +221,14 @@ def test_load_toolset_specific( def test_aload_toolset_all(self, toolbox): toolset = toolbox.load_toolset() - assert len(toolset) == 5 + assert len(toolset) == 6 tool_names = [ "get-n-rows", "get-row-by-id", "get-row-by-id-auth", "get-row-by-email-auth", "get-row-by-content-auth", + "search-rows", ] for tool in toolset: name = tool._ToolboxTool__core_tool.__name__