Skip to content

Commit 68f264e

Browse files
committed
Implement _find_referenced_defs method for schema reference resolution and enhance _merge_schema_with_precedence to prune unused $defs. Add comprehensive tests for new functionality in test_tool_transform.py.
1 parent fe8056f commit 68f264e

File tree

2 files changed

+274
-1
lines changed

2 files changed

+274
-1
lines changed

src/fastmcp/tools/tool_transform.py

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -787,6 +787,45 @@ def _apply_single_transform(
787787

788788
return new_name, new_schema, is_required
789789

790+
@staticmethod
791+
def _find_referenced_defs(
792+
schema: dict[str, Any], available_defs: dict[str, Any]
793+
) -> dict[str, Any]:
794+
"""Find all $defs that are actually referenced in the schema.
795+
796+
Args:
797+
schema: The schema to search for references
798+
available_defs: All available definitions to check against
799+
800+
Returns:
801+
Dictionary containing only the referenced definitions
802+
"""
803+
referenced = set()
804+
visited = set() # Track visited definitions to prevent infinite recursion
805+
806+
def find_refs(obj):
807+
if isinstance(obj, dict):
808+
if "$ref" in obj:
809+
ref = obj["$ref"]
810+
if ref.startswith("#/$defs/"):
811+
def_name = ref[8:] # Remove "#/$defs/" prefix
812+
if def_name in available_defs and def_name not in visited:
813+
referenced.add(def_name)
814+
visited.add(def_name) # Mark as visited before recursing
815+
# Recursively check the referenced definition
816+
find_refs(available_defs[def_name])
817+
for value in obj.values():
818+
find_refs(value)
819+
elif isinstance(obj, list):
820+
for item in obj:
821+
find_refs(item)
822+
823+
find_refs(schema)
824+
825+
return {
826+
name: available_defs[name] for name in referenced if name in available_defs
827+
}
828+
790829
@staticmethod
791830
def _merge_schema_with_precedence(
792831
base_schema: dict[str, Any], override_schema: dict[str, Any]
@@ -841,12 +880,25 @@ def _merge_schema_with_precedence(
841880
if "default" in param_schema:
842881
final_required.discard(param_name)
843882

844-
return {
883+
# Merge $defs from both schemas, with override taking precedence
884+
merged_defs = base_schema.get("$defs", {}).copy()
885+
override_defs = override_schema.get("$defs", {})
886+
merged_defs.update(override_defs)
887+
888+
result = {
845889
"type": "object",
846890
"properties": merged_props,
847891
"required": list(final_required),
848892
}
849893

894+
# Only include $defs that are actually referenced in the schema
895+
if merged_defs:
896+
referenced_defs = TransformedTool._find_referenced_defs(result, merged_defs)
897+
if referenced_defs:
898+
result["$defs"] = referenced_defs
899+
900+
return result
901+
850902
@staticmethod
851903
def _function_has_kwargs(fn: Callable[..., Any]) -> bool:
852904
"""Check if function accepts **kwargs.

tests/tools/test_tool_transform.py

Lines changed: 221 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1502,3 +1502,224 @@ def test_tool_transform_config_removes_meta(sample_tool):
15021502
config = ToolTransformConfig(name="config_tool", meta=None)
15031503
transformed = config.apply(sample_tool)
15041504
assert transformed.meta is None
1505+
1506+
1507+
# Tests for $defs and _find_referenced_defs functionality
1508+
class TestDefsAndReferences:
1509+
"""Test schema definition handling and reference finding."""
1510+
1511+
def test_find_referenced_defs_simple_reference(self):
1512+
"""Test _find_referenced_defs with a simple reference."""
1513+
schema = {"type": "object", "properties": {"field1": {"$ref": "#/$defs/TypeA"}}}
1514+
available_defs = {
1515+
"TypeA": {"type": "string"},
1516+
"TypeB": {"type": "integer"}, # Not referenced
1517+
}
1518+
1519+
result = TransformedTool._find_referenced_defs(schema, available_defs)
1520+
assert result == {"TypeA": {"type": "string"}}
1521+
assert "TypeB" not in result
1522+
1523+
def test_find_referenced_defs_nested_references(self):
1524+
"""Test _find_referenced_defs with nested references."""
1525+
schema = {"type": "object", "properties": {"field1": {"$ref": "#/$defs/TypeA"}}}
1526+
available_defs = {
1527+
"TypeA": {
1528+
"type": "object",
1529+
"properties": {"nested": {"$ref": "#/$defs/TypeB"}},
1530+
},
1531+
"TypeB": {"type": "string"},
1532+
"TypeC": {"type": "integer"}, # Not referenced
1533+
}
1534+
1535+
result = TransformedTool._find_referenced_defs(schema, available_defs)
1536+
assert result == {
1537+
"TypeA": {
1538+
"type": "object",
1539+
"properties": {"nested": {"$ref": "#/$defs/TypeB"}},
1540+
},
1541+
"TypeB": {"type": "string"},
1542+
}
1543+
assert "TypeC" not in result
1544+
1545+
def test_find_referenced_defs_circular_references(self):
1546+
"""Test _find_referenced_defs handles circular references."""
1547+
schema = {"type": "object", "properties": {"field1": {"$ref": "#/$defs/TypeA"}}}
1548+
available_defs = {
1549+
"TypeA": {
1550+
"type": "object",
1551+
"properties": {"circular": {"$ref": "#/$defs/TypeB"}},
1552+
},
1553+
"TypeB": {
1554+
"type": "object",
1555+
"properties": {"back_ref": {"$ref": "#/$defs/TypeA"}},
1556+
},
1557+
"TypeC": {"type": "string"}, # Not referenced
1558+
}
1559+
1560+
result = TransformedTool._find_referenced_defs(schema, available_defs)
1561+
assert "TypeA" in result
1562+
assert "TypeB" in result
1563+
assert "TypeC" not in result
1564+
1565+
def test_find_referenced_defs_array_references(self):
1566+
"""Test _find_referenced_defs with references in arrays."""
1567+
schema = {
1568+
"type": "object",
1569+
"properties": {
1570+
"field1": {"type": "array", "items": {"$ref": "#/$defs/TypeA"}}
1571+
},
1572+
}
1573+
available_defs = {
1574+
"TypeA": {"type": "string"},
1575+
"TypeB": {"type": "integer"}, # Not referenced
1576+
}
1577+
1578+
result = TransformedTool._find_referenced_defs(schema, available_defs)
1579+
assert result == {"TypeA": {"type": "string"}}
1580+
1581+
def test_find_referenced_defs_no_references(self):
1582+
"""Test _find_referenced_defs with no references."""
1583+
schema = {"type": "object", "properties": {"field1": {"type": "string"}}}
1584+
available_defs = {"TypeA": {"type": "string"}, "TypeB": {"type": "integer"}}
1585+
1586+
result = TransformedTool._find_referenced_defs(schema, available_defs)
1587+
assert result == {}
1588+
1589+
def test_merge_schema_with_defs_precedence(self):
1590+
"""Test _merge_schema_with_precedence merges $defs correctly."""
1591+
base_schema = {
1592+
"type": "object",
1593+
"properties": {"field1": {"$ref": "#/$defs/BaseType"}},
1594+
"$defs": {
1595+
"BaseType": {"type": "string", "description": "base"},
1596+
"SharedType": {"type": "integer", "minimum": 0},
1597+
},
1598+
}
1599+
1600+
override_schema = {
1601+
"type": "object",
1602+
"properties": {"field2": {"$ref": "#/$defs/OverrideType"}},
1603+
"$defs": {
1604+
"OverrideType": {"type": "boolean"},
1605+
"SharedType": {"type": "integer", "minimum": 10}, # Override
1606+
},
1607+
}
1608+
1609+
result = TransformedTool._merge_schema_with_precedence(
1610+
base_schema, override_schema
1611+
)
1612+
1613+
# Should have both field1 and field2
1614+
assert "field1" in result["properties"]
1615+
assert "field2" in result["properties"]
1616+
1617+
# Should only include referenced defs
1618+
defs = result.get("$defs", {})
1619+
assert "BaseType" in defs # Referenced by field1
1620+
assert "OverrideType" in defs # Referenced by field2
1621+
1622+
# SharedType should use override version, but only if referenced
1623+
# Since it's not referenced by either field, it shouldn't be included
1624+
assert "SharedType" not in defs
1625+
1626+
def test_transform_tool_with_complex_defs_pruning(self):
1627+
"""Test that tool transformation properly prunes unused $defs."""
1628+
1629+
class UsedType(BaseModel):
1630+
value: str
1631+
1632+
class UnusedType(BaseModel):
1633+
other: int
1634+
1635+
@Tool.from_function
1636+
def complex_tool(
1637+
used_param: UsedType, unused_param: UnusedType | None = None
1638+
) -> str:
1639+
return used_param.value
1640+
1641+
# Transform to hide unused_param
1642+
transformed = Tool.from_tool(
1643+
complex_tool, transform_args={"unused_param": ArgTransform(hide=True)}
1644+
)
1645+
1646+
# Only UsedType should be in $defs, not UnusedType
1647+
defs = transformed.parameters.get("$defs", {})
1648+
type_names = set(defs.keys())
1649+
1650+
# Should contain UsedType but not UnusedType
1651+
used_type_found = any("UsedType" in name for name in type_names)
1652+
unused_type_found = any("UnusedType" in name for name in type_names)
1653+
1654+
assert used_type_found, f"UsedType not found in defs: {type_names}"
1655+
assert not unused_type_found, f"UnusedType should not be in defs: {type_names}"
1656+
1657+
def test_transform_with_custom_function_preserves_needed_defs(self):
1658+
"""Test that custom transform functions preserve necessary $defs."""
1659+
1660+
class InputType(BaseModel):
1661+
data: str
1662+
1663+
class OutputType(BaseModel):
1664+
result: str
1665+
1666+
@Tool.from_function
1667+
def base_tool(input_data: InputType) -> OutputType:
1668+
return OutputType(result=input_data.data.upper())
1669+
1670+
async def transform_function(renamed_input: InputType):
1671+
return await forward(renamed_input=renamed_input)
1672+
1673+
# Transform with custom function and argument rename
1674+
transformed = Tool.from_tool(
1675+
base_tool,
1676+
transform_fn=transform_function,
1677+
transform_args={"input_data": ArgTransform(name="renamed_input")},
1678+
)
1679+
1680+
# Both InputType and OutputType should be preserved in defs
1681+
defs = transformed.parameters.get("$defs", {})
1682+
type_names = set(defs.keys())
1683+
1684+
input_type_found = any("InputType" in name for name in type_names)
1685+
assert input_type_found, f"InputType not found in defs: {type_names}"
1686+
1687+
def test_chained_transforms_preserve_correct_defs(self):
1688+
"""Test that chained transformations preserve correct $defs."""
1689+
1690+
class TypeA(BaseModel):
1691+
a: str
1692+
1693+
class TypeB(BaseModel):
1694+
b: int
1695+
1696+
class TypeC(BaseModel):
1697+
c: bool
1698+
1699+
@Tool.from_function
1700+
def base_tool(param_a: TypeA, param_b: TypeB, param_c: TypeC) -> str:
1701+
return f"{param_a.a}-{param_b.b}-{param_c.c}"
1702+
1703+
# First transform: hide param_c
1704+
transform1 = Tool.from_tool(
1705+
base_tool,
1706+
transform_args={"param_c": ArgTransform(hide=True, default=TypeC(c=True))},
1707+
)
1708+
1709+
# Second transform: hide param_b
1710+
transform2 = Tool.from_tool(
1711+
transform1,
1712+
transform_args={"param_b": ArgTransform(hide=True, default=TypeB(b=42))},
1713+
)
1714+
1715+
# Final schema should only have TypeA in $defs
1716+
defs = transform2.parameters.get("$defs", {})
1717+
type_names = set(defs.keys())
1718+
1719+
type_a_found = any("TypeA" in name for name in type_names)
1720+
type_b_found = any("TypeB" in name for name in type_names)
1721+
type_c_found = any("TypeC" in name for name in type_names)
1722+
1723+
assert type_a_found, f"TypeA should be in defs: {type_names}"
1724+
assert not type_b_found, f"TypeB should not be in defs: {type_names}"
1725+
assert not type_c_found, f"TypeC should not be in defs: {type_names}"

0 commit comments

Comments
 (0)