Skip to content

Commit 3f24e40

Browse files
authored
Fix $defs being discarded in input schema of transformed tool
Fix $defs being discarded in input schema of transformed tool
2 parents 94b1eb9 + 0114f0a commit 3f24e40

File tree

2 files changed

+256
-33
lines changed

2 files changed

+256
-33
lines changed

src/fastmcp/tools/tool_transform.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -841,12 +841,30 @@ def _merge_schema_with_precedence(
841841
if "default" in param_schema:
842842
final_required.discard(param_name)
843843

844-
return {
844+
# Merge $defs from both schemas, with override taking precedence
845+
merged_defs = base_schema.get("$defs", {}).copy()
846+
override_defs = override_schema.get("$defs", {})
847+
848+
for def_name, def_schema in override_defs.items():
849+
if def_name in merged_defs:
850+
base_def = merged_defs[def_name].copy()
851+
base_def.update(def_schema)
852+
merged_defs[def_name] = base_def
853+
else:
854+
merged_defs[def_name] = def_schema.copy()
855+
856+
result = {
845857
"type": "object",
846858
"properties": merged_props,
847859
"required": list(final_required),
848860
}
849861

862+
if merged_defs:
863+
result["$defs"] = merged_defs
864+
result = compress_schema(result, prune_defs=True)
865+
866+
return result
867+
850868
@staticmethod
851869
def _function_has_kwargs(fn: Callable[..., Any]) -> bool:
852870
"""Check if function accepts **kwargs.

tests/tools/test_tool_transform.py

Lines changed: 237 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import pytest
66
from dirty_equals import IsList
7+
from inline_snapshot import snapshot
78
from mcp.types import TextContent
89
from pydantic import BaseModel, Field, TypeAdapter
910
from typing_extensions import TypedDict
@@ -1055,38 +1056,6 @@ def add(x: int, y: int = 10) -> int:
10551056
await client.call_tool("new_add", {"x": 1, "y": 2})
10561057

10571058

1058-
def test_arg_transform_examples_in_schema(add_tool):
1059-
# Simple example
1060-
new_tool = Tool.from_tool(
1061-
add_tool,
1062-
transform_args={
1063-
"old_x": ArgTransform(examples=[1, 2, 3]),
1064-
},
1065-
)
1066-
prop = get_property(new_tool, "old_x")
1067-
assert prop["examples"] == [1, 2, 3]
1068-
1069-
# Nested example (e.g., for array type)
1070-
new_tool2 = Tool.from_tool(
1071-
add_tool,
1072-
transform_args={
1073-
"old_x": ArgTransform(examples=[["a", "b"], ["c", "d"]]),
1074-
},
1075-
)
1076-
prop2 = get_property(new_tool2, "old_x")
1077-
assert prop2["examples"] == [["a", "b"], ["c", "d"]]
1078-
1079-
# If not set, should not be present
1080-
new_tool3 = Tool.from_tool(
1081-
add_tool,
1082-
transform_args={
1083-
"old_x": ArgTransform(),
1084-
},
1085-
)
1086-
prop3 = get_property(new_tool3, "old_x")
1087-
assert "examples" not in prop3
1088-
1089-
10901059
class TestTransformToolOutputSchema:
10911060
"""Test output schema handling in transformed tools."""
10921061

@@ -1502,3 +1471,239 @@ def test_tool_transform_config_removes_meta(sample_tool):
15021471
config = ToolTransformConfig(name="config_tool", meta=None)
15031472
transformed = config.apply(sample_tool)
15041473
assert transformed.meta is None
1474+
1475+
1476+
class TestInputSchema:
1477+
"""Test schema definition handling and reference finding."""
1478+
1479+
def test_arg_transform_examples_in_schema(self, add_tool: Tool):
1480+
# Simple example
1481+
new_tool = Tool.from_tool(
1482+
add_tool,
1483+
transform_args={
1484+
"old_x": ArgTransform(examples=[1, 2, 3]),
1485+
},
1486+
)
1487+
prop = get_property(new_tool, "old_x")
1488+
assert prop["examples"] == [1, 2, 3]
1489+
1490+
# Nested example (e.g., for array type)
1491+
new_tool2 = Tool.from_tool(
1492+
add_tool,
1493+
transform_args={
1494+
"old_x": ArgTransform(examples=[["a", "b"], ["c", "d"]]),
1495+
},
1496+
)
1497+
prop2 = get_property(new_tool2, "old_x")
1498+
assert prop2["examples"] == [["a", "b"], ["c", "d"]]
1499+
1500+
# If not set, should not be present
1501+
new_tool3 = Tool.from_tool(
1502+
add_tool,
1503+
transform_args={
1504+
"old_x": ArgTransform(),
1505+
},
1506+
)
1507+
prop3 = get_property(new_tool3, "old_x")
1508+
assert "examples" not in prop3
1509+
1510+
def test_merge_schema_with_defs_precedence(self):
1511+
"""Test _merge_schema_with_precedence merges $defs correctly."""
1512+
base_schema = {
1513+
"type": "object",
1514+
"properties": {"field1": {"$ref": "#/$defs/BaseType"}},
1515+
"$defs": {
1516+
"BaseType": {"type": "string", "description": "base"},
1517+
"SharedType": {"type": "integer", "minimum": 0},
1518+
},
1519+
}
1520+
1521+
override_schema = {
1522+
"type": "object",
1523+
"properties": {"field2": {"$ref": "#/$defs/OverrideType"}},
1524+
"$defs": {
1525+
"OverrideType": {"type": "boolean"},
1526+
"SharedType": {"type": "integer", "minimum": 10}, # Override
1527+
},
1528+
}
1529+
1530+
transformed_tool_schema = TransformedTool._merge_schema_with_precedence(
1531+
base_schema, override_schema
1532+
)
1533+
1534+
# SharedType should no longer be present on the schema
1535+
assert "SharedType" not in transformed_tool_schema["$defs"]
1536+
1537+
assert transformed_tool_schema == snapshot(
1538+
{
1539+
"type": "object",
1540+
"properties": {
1541+
"field1": {"$ref": "#/$defs/BaseType"},
1542+
"field2": {"$ref": "#/$defs/OverrideType"},
1543+
},
1544+
"required": [],
1545+
"$defs": {
1546+
"BaseType": {"type": "string", "description": "base"},
1547+
"OverrideType": {"type": "boolean"},
1548+
},
1549+
}
1550+
)
1551+
1552+
def test_transform_tool_with_complex_defs_pruning(self):
1553+
"""Test that tool transformation properly prunes unused $defs."""
1554+
1555+
class UsedType(BaseModel):
1556+
value: str
1557+
1558+
class UnusedType(BaseModel):
1559+
other: int
1560+
1561+
@Tool.from_function
1562+
def complex_tool(
1563+
used_param: UsedType, unused_param: UnusedType | None = None
1564+
) -> str:
1565+
return used_param.value
1566+
1567+
# Transform to hide unused_param
1568+
transformed_tool: TransformedTool = Tool.from_tool(
1569+
complex_tool, transform_args={"unused_param": ArgTransform(hide=True)}
1570+
)
1571+
1572+
assert "UnusedType" not in transformed_tool.parameters["$defs"]
1573+
1574+
assert transformed_tool.parameters == snapshot(
1575+
{
1576+
"type": "object",
1577+
"properties": {
1578+
"used_param": {"$ref": "#/$defs/UsedType", "title": "Used Param"}
1579+
},
1580+
"required": ["used_param"],
1581+
"$defs": {
1582+
"UsedType": {
1583+
"properties": {"value": {"title": "Value", "type": "string"}},
1584+
"required": ["value"],
1585+
"title": "UsedType",
1586+
"type": "object",
1587+
}
1588+
},
1589+
}
1590+
)
1591+
1592+
def test_transform_with_custom_function_preserves_needed_defs(self):
1593+
"""Test that custom transform functions preserve necessary $defs."""
1594+
1595+
class InputType(BaseModel):
1596+
data: str
1597+
1598+
class OutputType(BaseModel):
1599+
result: str
1600+
1601+
@Tool.from_function
1602+
def base_tool(input_data: InputType) -> OutputType:
1603+
return OutputType(result=input_data.data.upper())
1604+
1605+
async def transform_function(renamed_input: InputType):
1606+
return await forward(renamed_input=renamed_input)
1607+
1608+
# Transform with custom function and argument rename
1609+
transformed = Tool.from_tool(
1610+
base_tool,
1611+
transform_fn=transform_function,
1612+
transform_args={"input_data": ArgTransform(name="renamed_input")},
1613+
)
1614+
1615+
assert transformed.parameters == snapshot(
1616+
{
1617+
"type": "object",
1618+
"properties": {
1619+
"renamed_input": {
1620+
"$ref": "#/$defs/InputType",
1621+
"title": "Input Data",
1622+
}
1623+
},
1624+
"required": ["renamed_input"],
1625+
"$defs": {
1626+
"InputType": {
1627+
"properties": {"data": {"title": "Data", "type": "string"}},
1628+
"required": ["data"],
1629+
"title": "InputType",
1630+
"type": "object",
1631+
}
1632+
},
1633+
}
1634+
)
1635+
1636+
def test_chained_transforms_preserve_correct_defs(self):
1637+
"""Test that chained transformations preserve correct $defs."""
1638+
1639+
class TypeA(BaseModel):
1640+
a: str
1641+
1642+
class TypeB(BaseModel):
1643+
b: int
1644+
1645+
class TypeC(BaseModel):
1646+
c: bool
1647+
1648+
@Tool.from_function
1649+
def base_tool(param_a: TypeA, param_b: TypeB, param_c: TypeC) -> str:
1650+
return f"{param_a.a}-{param_b.b}-{param_c.c}"
1651+
1652+
# First transform: hide param_c
1653+
transform1 = Tool.from_tool(
1654+
base_tool,
1655+
transform_args={"param_c": ArgTransform(hide=True, default=TypeC(c=True))},
1656+
)
1657+
1658+
assert transform1.parameters == snapshot(
1659+
{
1660+
"type": "object",
1661+
"properties": {
1662+
"param_a": {"$ref": "#/$defs/TypeA", "title": "Param A"},
1663+
"param_b": {"$ref": "#/$defs/TypeB", "title": "Param B"},
1664+
},
1665+
"required": IsList("param_b", "param_a", check_order=False),
1666+
"$defs": {
1667+
"TypeA": {
1668+
"properties": {"a": {"title": "A", "type": "string"}},
1669+
"required": ["a"],
1670+
"title": "TypeA",
1671+
"type": "object",
1672+
},
1673+
"TypeB": {
1674+
"properties": {"b": {"title": "B", "type": "integer"}},
1675+
"required": ["b"],
1676+
"title": "TypeB",
1677+
"type": "object",
1678+
},
1679+
},
1680+
}
1681+
)
1682+
1683+
assert "TypeA" in transform1.parameters["$defs"]
1684+
1685+
# Second transform: hide param_b
1686+
transform2 = Tool.from_tool(
1687+
transform1,
1688+
transform_args={"param_b": ArgTransform(hide=True, default=TypeB(b=42))},
1689+
)
1690+
1691+
assert "TypeB" not in transform2.parameters["$defs"]
1692+
1693+
assert transform2.parameters == snapshot(
1694+
{
1695+
"type": "object",
1696+
"properties": {
1697+
"param_a": {"$ref": "#/$defs/TypeA", "title": "Param A"}
1698+
},
1699+
"required": ["param_a"],
1700+
"$defs": {
1701+
"TypeA": {
1702+
"properties": {"a": {"title": "A", "type": "string"}},
1703+
"required": ["a"],
1704+
"title": "TypeA",
1705+
"type": "object",
1706+
}
1707+
},
1708+
}
1709+
)

0 commit comments

Comments
 (0)