diff --git a/packages/toolbox-core/integration.cloudbuild.yaml b/packages/toolbox-core/integration.cloudbuild.yaml index 89132be9..b966a688 100644 --- a/packages/toolbox-core/integration.cloudbuild.yaml +++ b/packages/toolbox-core/integration.cloudbuild.yaml @@ -43,4 +43,4 @@ options: logging: CLOUD_LOGGING_ONLY substitutions: _VERSION: '3.13' - _TOOLBOX_VERSION: '0.8.0' + _TOOLBOX_VERSION: '0.10.0' diff --git a/packages/toolbox-core/src/toolbox_core/protocol.py b/packages/toolbox-core/src/toolbox_core/protocol.py index b101c9b8..24e2bcbc 100644 --- a/packages/toolbox-core/src/toolbox_core/protocol.py +++ b/packages/toolbox-core/src/toolbox_core/protocol.py @@ -13,7 +13,7 @@ # limitations under the License. from inspect import Parameter -from typing import Optional, Type +from typing import Any, Optional, Type, Union from pydantic import BaseModel @@ -29,6 +29,7 @@ class ParameterSchema(BaseModel): description: str authSources: Optional[list[str]] = None items: Optional["ParameterSchema"] = None + AdditionalProperties: Optional[Union[bool, "ParameterSchema"]] = None def __get_type(self) -> Type: base_type: Type @@ -44,6 +45,12 @@ def __get_type(self) -> Type: if self.items is None: raise ValueError("Unexpected value: type is 'array' but items is None") base_type = list[self.items.__get_type()] # type: ignore + elif self.type == "object": + if isinstance(self.AdditionalProperties, ParameterSchema): + value_type = self.AdditionalProperties.__get_type() + base_type = dict[str, value_type] # type: ignore + else: + base_type = dict[str, Any] else: raise ValueError(f"Unsupported schema type: {self.type}") diff --git a/packages/toolbox-core/tests/test_e2e.py b/packages/toolbox-core/tests/test_e2e.py index 52f0ba56..b1275486 100644 --- a/packages/toolbox-core/tests/test_e2e.py +++ b/packages/toolbox-core/tests/test_e2e.py @@ -13,7 +13,7 @@ # limitations under the License. from inspect import Parameter, signature -from typing import Optional +from typing import Any, Optional import pytest import pytest_asyncio @@ -68,7 +68,7 @@ async def test_load_toolset_specific( async def test_load_toolset_default(self, toolbox: ToolboxClient): """Load the default toolset, i.e. all tools.""" toolset = await toolbox.load_toolset() - assert len(toolset) == 6 + assert len(toolset) == 7 tool_names = {tool.__name__ for tool in toolset} expected_tools = [ "get-row-by-content-auth", @@ -77,6 +77,7 @@ async def test_load_toolset_default(self, toolbox: ToolboxClient): "get-row-by-id", "get-n-rows", "search-rows", + "process-data", ] assert tool_names == set(expected_tools) @@ -379,3 +380,66 @@ async def test_run_tool_with_different_id(self, toolbox: ToolboxClient): response = await tool(email="twishabansal@google.com", id=4, data="row3") assert isinstance(response, str) assert response == "null" + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("toolbox_server") +class TestMapParams: + """ + End-to-end tests for tools with map parameters. + """ + + async def test_tool_signature_with_map_params(self, toolbox: ToolboxClient): + """Verify the client correctly constructs the signature for a tool with map params.""" + tool = await toolbox.load_tool("process-data") + sig = signature(tool) + + assert "execution_context" in sig.parameters + assert sig.parameters["execution_context"].annotation == dict[str, Any] + assert sig.parameters["execution_context"].default is Parameter.empty + + assert "user_scores" in sig.parameters + assert sig.parameters["user_scores"].annotation == dict[str, int] + assert sig.parameters["user_scores"].default is Parameter.empty + + assert "feature_flags" in sig.parameters + assert sig.parameters["feature_flags"].annotation == Optional[dict[str, bool]] + assert sig.parameters["feature_flags"].default is None + + async def test_run_tool_with_map_params(self, toolbox: ToolboxClient): + """Invoke a tool with valid map parameters.""" + tool = await toolbox.load_tool("process-data") + + response = await tool( + execution_context={"env": "prod", "id": 1234, "user": 1234.5}, + user_scores={"user1": 100, "user2": 200}, + feature_flags={"new_feature": True}, + ) + assert isinstance(response, str) + assert '"execution_context":{"env":"prod","id":1234,"user":1234.5}' in response + assert '"user_scores":{"user1":100,"user2":200}' in response + assert '"feature_flags":{"new_feature":true}' in response + + async def test_run_tool_with_optional_map_param_omitted( + self, toolbox: ToolboxClient + ): + """Invoke a tool without the optional map parameter.""" + tool = await toolbox.load_tool("process-data") + + response = await tool( + execution_context={"env": "dev"}, user_scores={"user3": 300} + ) + assert isinstance(response, str) + assert '"execution_context":{"env":"dev"}' in response + assert '"user_scores":{"user3":300}' in response + assert '"feature_flags":null' in response + + async def test_run_tool_with_wrong_map_value_type(self, toolbox: ToolboxClient): + """Invoke a tool with a map parameter having the wrong value type.""" + tool = await toolbox.load_tool("process-data") + + with pytest.raises(ValidationError): + await tool( + execution_context={"env": "staging"}, + user_scores={"user4": "not-an-integer"}, + ) diff --git a/packages/toolbox-core/tests/test_protocol.py b/packages/toolbox-core/tests/test_protocol.py index c2b4096d..55deb7c9 100644 --- a/packages/toolbox-core/tests/test_protocol.py +++ b/packages/toolbox-core/tests/test_protocol.py @@ -14,7 +14,7 @@ from inspect import Parameter -from typing import Optional +from typing import Any, Optional import pytest @@ -170,3 +170,111 @@ def test_parameter_schema_array_optional(): assert param.annotation == expected_type assert param.kind == Parameter.POSITIONAL_OR_KEYWORD assert param.default is None + + +def test_parameter_schema_map_generic(): + """Tests ParameterSchema with a generic 'object' type.""" + schema = ParameterSchema( + name="metadata", + type="object", + description="Some metadata", + AdditionalProperties=True, + ) + expected_type = dict[str, Any] + assert schema._ParameterSchema__get_type() == expected_type + + param = schema.to_param() + assert isinstance(param, Parameter) + assert param.name == "metadata" + assert param.annotation == expected_type + assert param.kind == Parameter.POSITIONAL_OR_KEYWORD + + +def test_parameter_schema_map_typed_string(): + """Tests ParameterSchema with a typed 'object' type (string values).""" + schema = ParameterSchema( + name="headers", + type="object", + description="HTTP headers", + AdditionalProperties=ParameterSchema(name="", type="string", description=""), + ) + expected_type = dict[str, str] + assert schema._ParameterSchema__get_type() == expected_type + + param = schema.to_param() + assert param.annotation == expected_type + + +def test_parameter_schema_map_typed_integer(): + """Tests ParameterSchema with a typed 'object' type (integer values).""" + schema = ParameterSchema( + name="user_scores", + type="object", + description="User scores", + AdditionalProperties=ParameterSchema(name="", type="integer", description=""), + ) + expected_type = dict[str, int] + assert schema._ParameterSchema__get_type() == expected_type + param = schema.to_param() + assert param.annotation == expected_type + + +def test_parameter_schema_map_typed_float(): + """Tests ParameterSchema with a typed 'object' type (float values).""" + schema = ParameterSchema( + name="item_prices", + type="object", + description="Item prices", + AdditionalProperties=ParameterSchema(name="", type="float", description=""), + ) + expected_type = dict[str, float] + assert schema._ParameterSchema__get_type() == expected_type + param = schema.to_param() + assert param.annotation == expected_type + + +def test_parameter_schema_map_typed_boolean(): + """Tests ParameterSchema with a typed 'object' type (boolean values).""" + schema = ParameterSchema( + name="feature_flags", + type="object", + description="Feature flags", + AdditionalProperties=ParameterSchema(name="", type="boolean", description=""), + ) + expected_type = dict[str, bool] + assert schema._ParameterSchema__get_type() == expected_type + param = schema.to_param() + assert param.annotation == expected_type + + +def test_parameter_schema_map_optional(): + """Tests an optional ParameterSchema with a 'object' type.""" + schema = ParameterSchema( + name="optional_metadata", + type="object", + description="Optional metadata", + required=False, + AdditionalProperties=True, + ) + expected_type = Optional[dict[str, Any]] + assert schema._ParameterSchema__get_type() == expected_type + param = schema.to_param() + assert param.annotation == expected_type + assert param.default is None + + +def test_parameter_schema_map_unsupported_value_type_error(): + """Tests that an unsupported map valueType raises ValueError.""" + unsupported_type = "custom_object" + schema = ParameterSchema( + name="custom_data", + type="object", + description="Custom data map", + valueType=unsupported_type, + AdditionalProperties=ParameterSchema( + name="", type=unsupported_type, description="" + ), + ) + expected_error_msg = f"Unsupported schema type: {unsupported_type}" + with pytest.raises(ValueError, match=expected_error_msg): + schema._ParameterSchema__get_type() diff --git a/packages/toolbox-langchain/integration.cloudbuild.yaml b/packages/toolbox-langchain/integration.cloudbuild.yaml index 0deb3a94..d38a072f 100644 --- a/packages/toolbox-langchain/integration.cloudbuild.yaml +++ b/packages/toolbox-langchain/integration.cloudbuild.yaml @@ -44,4 +44,4 @@ options: logging: CLOUD_LOGGING_ONLY substitutions: _VERSION: '3.13' - _TOOLBOX_VERSION: '0.8.0' + _TOOLBOX_VERSION: '0.10.0' diff --git a/packages/toolbox-langchain/tests/test_e2e.py b/packages/toolbox-langchain/tests/test_e2e.py index 64371ead..ea750c5b 100644 --- a/packages/toolbox-langchain/tests/test_e2e.py +++ b/packages/toolbox-langchain/tests/test_e2e.py @@ -75,7 +75,7 @@ async def test_aload_toolset_specific( async def test_aload_toolset_all(self, toolbox): toolset = await toolbox.aload_toolset() - assert len(toolset) == 6 + assert len(toolset) == 7 tool_names = [ "get-n-rows", "get-row-by-id", @@ -83,6 +83,7 @@ async def test_aload_toolset_all(self, toolbox): "get-row-by-email-auth", "get-row-by-content-auth", "search-rows", + "process-data", ] for tool in toolset: name = tool._ToolboxTool__core_tool.__name__ @@ -221,7 +222,7 @@ def test_load_toolset_specific( def test_aload_toolset_all(self, toolbox): toolset = toolbox.load_toolset() - assert len(toolset) == 6 + assert len(toolset) == 7 tool_names = [ "get-n-rows", "get-row-by-id", @@ -229,6 +230,7 @@ def test_aload_toolset_all(self, toolbox): "get-row-by-email-auth", "get-row-by-content-auth", "search-rows", + "process-data", ] for tool in toolset: name = tool._ToolboxTool__core_tool.__name__ diff --git a/packages/toolbox-llamaindex/integration.cloudbuild.yaml b/packages/toolbox-llamaindex/integration.cloudbuild.yaml index 9b0b4e5d..7854a596 100644 --- a/packages/toolbox-llamaindex/integration.cloudbuild.yaml +++ b/packages/toolbox-llamaindex/integration.cloudbuild.yaml @@ -44,4 +44,4 @@ options: logging: CLOUD_LOGGING_ONLY substitutions: _VERSION: '3.13' - _TOOLBOX_VERSION: '0.8.0' + _TOOLBOX_VERSION: '0.10.0' diff --git a/packages/toolbox-llamaindex/tests/test_e2e.py b/packages/toolbox-llamaindex/tests/test_e2e.py index 580059a7..15973fb3 100644 --- a/packages/toolbox-llamaindex/tests/test_e2e.py +++ b/packages/toolbox-llamaindex/tests/test_e2e.py @@ -75,7 +75,7 @@ async def test_aload_toolset_specific( async def test_aload_toolset_all(self, toolbox): toolset = await toolbox.aload_toolset() - assert len(toolset) == 6 + assert len(toolset) == 7 tool_names = [ "get-n-rows", "get-row-by-id", @@ -83,6 +83,7 @@ async def test_aload_toolset_all(self, toolbox): "get-row-by-email-auth", "get-row-by-content-auth", "search-rows", + "process-data", ] for tool in toolset: name = tool._ToolboxTool__core_tool.__name__ @@ -221,7 +222,7 @@ def test_load_toolset_specific( def test_aload_toolset_all(self, toolbox): toolset = toolbox.load_toolset() - assert len(toolset) == 6 + assert len(toolset) == 7 tool_names = [ "get-n-rows", "get-row-by-id", @@ -229,6 +230,7 @@ def test_aload_toolset_all(self, toolbox): "get-row-by-email-auth", "get-row-by-content-auth", "search-rows", + "process-data", ] for tool in toolset: name = tool._ToolboxTool__core_tool.__name__