Skip to content

feat(core): Add support for map parameter type #324

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
Open
2 changes: 1 addition & 1 deletion packages/toolbox-core/integration.cloudbuild.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,4 @@ options:
logging: CLOUD_LOGGING_ONLY
substitutions:
_VERSION: '3.13'
_TOOLBOX_VERSION: '0.8.0'
_TOOLBOX_VERSION: '0.10.0'
9 changes: 8 additions & 1 deletion packages/toolbox-core/src/toolbox_core/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

from inspect import Parameter
from typing import Optional, Type
from typing import Any, Optional, Type, Union

from pydantic import BaseModel

Expand All @@ -29,6 +29,7 @@ class ParameterSchema(BaseModel):
description: str
authSources: Optional[list[str]] = None
items: Optional["ParameterSchema"] = None
AdditionalProperties: Optional[Union[bool, "ParameterSchema"]] = None

def __get_type(self) -> Type:
base_type: Type
Expand All @@ -44,6 +45,12 @@ def __get_type(self) -> Type:
if self.items is None:
raise ValueError("Unexpected value: type is 'array' but items is None")
base_type = list[self.items.__get_type()] # type: ignore
elif self.type == "object":
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Should this be a map type parameter?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the manifest the type is returned as a object.

CC: @duwenxin99

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@duwenxin99 are there specific reasons for using different manifest and object types in the server and client?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(Unresolving this comment for now)

Copy link

@duwenxin99 duwenxin99 Aug 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are two reasons:

  1. To keep it consistent as the MCP manifest. MCP uses JSON schema and it doesn't have a "map" type. We are basically using "object" type to achieve the map functionality.
  2. If we decide to add in the "object" type parameter in the future, we can reuse the current "map" code as they are basically the same structure.
    Open to discussion on this @twishabansal , @anubhav756 we can change it for the server if that's easier on the SDK side.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm just worried that in the config file, the user defines a map. However, during usage with SDK or debugging, the user would see an object. Seems like this might be a little confusing for the user. What do you folks think?

if isinstance(self.AdditionalProperties, ParameterSchema):
value_type = self.AdditionalProperties.__get_type()
base_type = dict[str, value_type] # type: ignore
else:
base_type = dict[str, Any]
else:
raise ValueError(f"Unsupported schema type: {self.type}")

Expand Down
68 changes: 66 additions & 2 deletions packages/toolbox-core/tests/test_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

from inspect import Parameter, signature
from typing import Optional
from typing import Any, Optional

import pytest
import pytest_asyncio
Expand Down Expand Up @@ -68,7 +68,7 @@ async def test_load_toolset_specific(
async def test_load_toolset_default(self, toolbox: ToolboxClient):
"""Load the default toolset, i.e. all tools."""
toolset = await toolbox.load_toolset()
assert len(toolset) == 6
assert len(toolset) == 7
tool_names = {tool.__name__ for tool in toolset}
expected_tools = [
"get-row-by-content-auth",
Expand All @@ -77,6 +77,7 @@ async def test_load_toolset_default(self, toolbox: ToolboxClient):
"get-row-by-id",
"get-n-rows",
"search-rows",
"process-data",
]
assert tool_names == set(expected_tools)

Expand Down Expand Up @@ -379,3 +380,66 @@ async def test_run_tool_with_different_id(self, toolbox: ToolboxClient):
response = await tool(email="[email protected]", id=4, data="row3")
assert isinstance(response, str)
assert response == "null"


@pytest.mark.asyncio
@pytest.mark.usefixtures("toolbox_server")
class TestMapParams:
"""
End-to-end tests for tools with map parameters.
"""

async def test_tool_signature_with_map_params(self, toolbox: ToolboxClient):
"""Verify the client correctly constructs the signature for a tool with map params."""
tool = await toolbox.load_tool("process-data")
sig = signature(tool)

assert "execution_context" in sig.parameters
assert sig.parameters["execution_context"].annotation == dict[str, Any]
assert sig.parameters["execution_context"].default is Parameter.empty

assert "user_scores" in sig.parameters
assert sig.parameters["user_scores"].annotation == dict[str, int]
assert sig.parameters["user_scores"].default is Parameter.empty

assert "feature_flags" in sig.parameters
assert sig.parameters["feature_flags"].annotation == Optional[dict[str, bool]]
assert sig.parameters["feature_flags"].default is None

async def test_run_tool_with_map_params(self, toolbox: ToolboxClient):
"""Invoke a tool with valid map parameters."""
tool = await toolbox.load_tool("process-data")

response = await tool(
execution_context={"env": "prod", "id": 1234, "user": 1234.5},
user_scores={"user1": 100, "user2": 200},
feature_flags={"new_feature": True},
)
assert isinstance(response, str)
assert '"execution_context":{"env":"prod","id":1234,"user":1234.5}' in response
assert '"user_scores":{"user1":100,"user2":200}' in response
assert '"feature_flags":{"new_feature":true}' in response

async def test_run_tool_with_optional_map_param_omitted(
self, toolbox: ToolboxClient
):
"""Invoke a tool without the optional map parameter."""
tool = await toolbox.load_tool("process-data")

response = await tool(
execution_context={"env": "dev"}, user_scores={"user3": 300}
)
assert isinstance(response, str)
assert '"execution_context":{"env":"dev"}' in response
assert '"user_scores":{"user3":300}' in response
assert '"feature_flags":null' in response

async def test_run_tool_with_wrong_map_value_type(self, toolbox: ToolboxClient):
"""Invoke a tool with a map parameter having the wrong value type."""
tool = await toolbox.load_tool("process-data")

with pytest.raises(ValidationError):
await tool(
execution_context={"env": "staging"},
user_scores={"user4": "not-an-integer"},
)
110 changes: 109 additions & 1 deletion packages/toolbox-core/tests/test_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@


from inspect import Parameter
from typing import Optional
from typing import Any, Optional

import pytest

Expand Down Expand Up @@ -170,3 +170,111 @@ def test_parameter_schema_array_optional():
assert param.annotation == expected_type
assert param.kind == Parameter.POSITIONAL_OR_KEYWORD
assert param.default is None


def test_parameter_schema_map_generic():
"""Tests ParameterSchema with a generic 'object' type."""
schema = ParameterSchema(
name="metadata",
type="object",
description="Some metadata",
AdditionalProperties=True,
)
expected_type = dict[str, Any]
assert schema._ParameterSchema__get_type() == expected_type

param = schema.to_param()
assert isinstance(param, Parameter)
assert param.name == "metadata"
assert param.annotation == expected_type
assert param.kind == Parameter.POSITIONAL_OR_KEYWORD


def test_parameter_schema_map_typed_string():
"""Tests ParameterSchema with a typed 'object' type (string values)."""
schema = ParameterSchema(
name="headers",
type="object",
description="HTTP headers",
AdditionalProperties=ParameterSchema(name="", type="string", description=""),
)
expected_type = dict[str, str]
assert schema._ParameterSchema__get_type() == expected_type

param = schema.to_param()
assert param.annotation == expected_type


def test_parameter_schema_map_typed_integer():
"""Tests ParameterSchema with a typed 'object' type (integer values)."""
schema = ParameterSchema(
name="user_scores",
type="object",
description="User scores",
AdditionalProperties=ParameterSchema(name="", type="integer", description=""),
)
expected_type = dict[str, int]
assert schema._ParameterSchema__get_type() == expected_type
param = schema.to_param()
assert param.annotation == expected_type


def test_parameter_schema_map_typed_float():
"""Tests ParameterSchema with a typed 'object' type (float values)."""
schema = ParameterSchema(
name="item_prices",
type="object",
description="Item prices",
AdditionalProperties=ParameterSchema(name="", type="float", description=""),
)
expected_type = dict[str, float]
assert schema._ParameterSchema__get_type() == expected_type
param = schema.to_param()
assert param.annotation == expected_type


def test_parameter_schema_map_typed_boolean():
"""Tests ParameterSchema with a typed 'object' type (boolean values)."""
schema = ParameterSchema(
name="feature_flags",
type="object",
description="Feature flags",
AdditionalProperties=ParameterSchema(name="", type="boolean", description=""),
)
expected_type = dict[str, bool]
assert schema._ParameterSchema__get_type() == expected_type
param = schema.to_param()
assert param.annotation == expected_type


def test_parameter_schema_map_optional():
"""Tests an optional ParameterSchema with a 'object' type."""
schema = ParameterSchema(
name="optional_metadata",
type="object",
description="Optional metadata",
required=False,
AdditionalProperties=True,
)
expected_type = Optional[dict[str, Any]]
assert schema._ParameterSchema__get_type() == expected_type
param = schema.to_param()
assert param.annotation == expected_type
assert param.default is None


def test_parameter_schema_map_unsupported_value_type_error():
"""Tests that an unsupported map valueType raises ValueError."""
unsupported_type = "custom_object"
schema = ParameterSchema(
name="custom_data",
type="object",
description="Custom data map",
valueType=unsupported_type,
AdditionalProperties=ParameterSchema(
name="", type=unsupported_type, description=""
),
)
expected_error_msg = f"Unsupported schema type: {unsupported_type}"
with pytest.raises(ValueError, match=expected_error_msg):
schema._ParameterSchema__get_type()
2 changes: 1 addition & 1 deletion packages/toolbox-langchain/integration.cloudbuild.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,4 @@ options:
logging: CLOUD_LOGGING_ONLY
substitutions:
_VERSION: '3.13'
_TOOLBOX_VERSION: '0.8.0'
_TOOLBOX_VERSION: '0.10.0'
6 changes: 4 additions & 2 deletions packages/toolbox-langchain/tests/test_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,14 +75,15 @@ async def test_aload_toolset_specific(

async def test_aload_toolset_all(self, toolbox):
toolset = await toolbox.aload_toolset()
assert len(toolset) == 6
assert len(toolset) == 7
tool_names = [
"get-n-rows",
"get-row-by-id",
"get-row-by-id-auth",
"get-row-by-email-auth",
"get-row-by-content-auth",
"search-rows",
"process-data",
]
for tool in toolset:
name = tool._ToolboxTool__core_tool.__name__
Expand Down Expand Up @@ -221,14 +222,15 @@ def test_load_toolset_specific(

def test_aload_toolset_all(self, toolbox):
toolset = toolbox.load_toolset()
assert len(toolset) == 6
assert len(toolset) == 7
tool_names = [
"get-n-rows",
"get-row-by-id",
"get-row-by-id-auth",
"get-row-by-email-auth",
"get-row-by-content-auth",
"search-rows",
"process-data",
]
for tool in toolset:
name = tool._ToolboxTool__core_tool.__name__
Expand Down
2 changes: 1 addition & 1 deletion packages/toolbox-llamaindex/integration.cloudbuild.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,4 @@ options:
logging: CLOUD_LOGGING_ONLY
substitutions:
_VERSION: '3.13'
_TOOLBOX_VERSION: '0.8.0'
_TOOLBOX_VERSION: '0.10.0'
6 changes: 4 additions & 2 deletions packages/toolbox-llamaindex/tests/test_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,14 +75,15 @@ async def test_aload_toolset_specific(

async def test_aload_toolset_all(self, toolbox):
toolset = await toolbox.aload_toolset()
assert len(toolset) == 6
assert len(toolset) == 7
tool_names = [
"get-n-rows",
"get-row-by-id",
"get-row-by-id-auth",
"get-row-by-email-auth",
"get-row-by-content-auth",
"search-rows",
"process-data",
]
for tool in toolset:
name = tool._ToolboxTool__core_tool.__name__
Expand Down Expand Up @@ -221,14 +222,15 @@ def test_load_toolset_specific(

def test_aload_toolset_all(self, toolbox):
toolset = toolbox.load_toolset()
assert len(toolset) == 6
assert len(toolset) == 7
tool_names = [
"get-n-rows",
"get-row-by-id",
"get-row-by-id-auth",
"get-row-by-email-auth",
"get-row-by-content-auth",
"search-rows",
"process-data",
]
for tool in toolset:
name = tool._ToolboxTool__core_tool.__name__
Expand Down