Skip to content

Commit 4ea0f19

Browse files
committed
add required support to SchemaFromTextExtractor
1 parent 0c56db9 commit 4ea0f19

File tree

3 files changed

+201
-0
lines changed

3 files changed

+201
-0
lines changed

src/neo4j_graphrag/experimental/components/schema.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -725,6 +725,40 @@ def _filter_properties_required_field(
725725

726726
return node_types
727727

728+
def _enforce_required_for_constraint_properties(
    self,
    node_types: List[Dict[str, Any]],
    constraints: List[Dict[str, Any]],
) -> None:
    """Ensure properties with UNIQUENESS constraints are marked as required.

    For every constraint of type ``"UNIQUENESS"`` that names a node label and
    a property, the matching property dict inside ``node_types`` gets
    ``"required": True``. An info-level log line is emitted whenever the flag
    is changed or added.

    The caller invokes this before invalid constraints are filtered out, so
    the input is treated as unvalidated: entries that are not dicts, labels
    or property names that are not strings, and a missing or non-list
    ``"properties"`` value are skipped instead of raising.

    Args:
        node_types: Node type dicts, each expected to carry a ``"label"`` and
            a ``"properties"`` list of property dicts. Mutated in place.
        constraints: Constraint dicts, each expected to carry ``"type"``,
            ``"node_type"`` and ``"property_name"``.

    Returns:
        None. ``node_types`` is updated in place.
    """
    if not constraints:
        return

    # Map each node label to the property names under a UNIQUENESS constraint.
    constraint_props: Dict[str, set[str]] = {}
    for constraint in constraints:
        if not isinstance(constraint, dict):
            continue
        if constraint.get("type") != "UNIQUENESS":
            continue
        label = constraint.get("node_type")
        prop_name = constraint.get("property_name")
        # Only non-empty strings are usable (and hashable) lookup keys.
        if not (isinstance(label, str) and isinstance(prop_name, str)):
            continue
        if label and prop_name:
            constraint_props.setdefault(label, set()).add(prop_name)

    if not constraint_props:
        return

    # Skip node types without constraints.
    for node_type in node_types:
        if not isinstance(node_type, dict):
            continue
        label = node_type.get("label")
        if not isinstance(label, str) or label not in constraint_props:
            continue

        properties = node_type.get("properties")
        if not isinstance(properties, (list, tuple)):
            # Missing or null "properties": nothing to fix on this node type.
            continue

        props_to_fix = constraint_props[label]
        for prop in properties:
            if not isinstance(prop, dict):
                continue
            name = prop.get("name")
            if not isinstance(name, str) or name not in props_to_fix:
                continue
            if prop.get("required") is not True:
                logging.info(
                    f"Auto-setting 'required' as True for property '{name}' "
                    f"on node '{label}' (has UNIQUENESS constraint)."
                )
                prop["required"] = True
761+
728762
def _clean_json_content(self, content: str) -> str:
729763
content = content.strip()
730764

@@ -815,6 +849,12 @@ async def run(self, text: str, examples: str = "", **kwargs: Any) -> GraphSchema
815849
extracted_patterns, extracted_node_types, extracted_relationship_types
816850
)
817851

852+
# Enforce required=true for properties with UNIQUENESS constraints
853+
if extracted_constraints:
854+
self._enforce_required_for_constraint_properties(
855+
extracted_node_types, extracted_constraints
856+
)
857+
818858
# Filter out invalid constraints
819859
if extracted_constraints:
820860
extracted_constraints = self._filter_invalid_constraints(

src/neo4j_graphrag/generation/prompts.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,7 @@ class SchemaExtractionTemplate(PromptTemplate):
229229
9.3 Properties that are identifiers, names, or essential characteristics are typically required.
230230
9.4 Properties that are supplementary information (phone numbers, descriptions, metadata) are typically optional.
231231
9.5 When uncertain, default to "required": false.
232+
9.6 If a property has a UNIQUENESS constraint, it MUST be marked as "required": true.
232233
233234
Accepted property types are: BOOLEAN, DATE, DURATION, FLOAT, INTEGER, LIST,
234235
LOCAL_DATETIME, LOCAL_TIME, POINT, STRING, ZONED_DATETIME, ZONED_TIME.

tests/unit/experimental/components/test_schema.py

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1567,6 +1567,127 @@ def test_filter_properties_required_field_missing(
15671567
assert "required" not in result[0]["properties"][0]
15681568

15691569

1570+
def test_enforce_required_for_constraint_properties_sets_required_true(
    schema_from_text: SchemaFromTextExtractor,
) -> None:
    """A UNIQUENESS-constrained property becomes required; its siblings do not."""
    name_prop: dict[str, Any] = {"name": "name", "type": "STRING", "required": False}
    email_prop: dict[str, Any] = {"name": "email", "type": "STRING", "required": False}
    node_types: list[dict[str, Any]] = [
        {"label": "Person", "properties": [name_prop, email_prop]}
    ]
    unique_name = {
        "type": "UNIQUENESS",
        "node_type": "Person",
        "property_name": "name",
    }

    schema_from_text._enforce_required_for_constraint_properties(
        node_types, [unique_name]
    )

    person_properties = node_types[0]["properties"]
    # The constrained property (name) is flipped to required.
    assert person_properties[0]["required"] is True
    # The unconstrained property (email) keeps its original flag.
    assert person_properties[1]["required"] is False
1594+
1595+
1596+
def test_enforce_required_for_constraint_properties_already_true(
    schema_from_text: SchemaFromTextExtractor,
) -> None:
    """A constrained property that is already required stays required."""
    node_types: list[dict[str, Any]] = [
        dict(
            label="Person",
            properties=[dict(name="name", type="STRING", required=True)],
        )
    ]
    unique_name = dict(type="UNIQUENESS", node_type="Person", property_name="name")

    schema_from_text._enforce_required_for_constraint_properties(
        node_types, [unique_name]
    )

    (only_property,) = node_types[0]["properties"]
    assert only_property["required"] is True
1616+
1617+
1618+
def test_enforce_required_for_constraint_properties_missing_required_field(
    schema_from_text: SchemaFromTextExtractor,
) -> None:
    """A constrained property with no 'required' key gets required=True added."""
    # Deliberately no "required" key on the property.
    bare_property: dict[str, Any] = {"name": "name", "type": "STRING"}
    node_types: list[dict[str, Any]] = [
        {"label": "Person", "properties": [bare_property]}
    ]
    constraints = [
        dict(type="UNIQUENESS", node_type="Person", property_name="name"),
    ]

    schema_from_text._enforce_required_for_constraint_properties(
        node_types, constraints
    )

    first_property = node_types[0]["properties"][0]
    assert first_property["required"] is True
1638+
1639+
1640+
def test_enforce_required_for_constraint_properties_no_constraints(
    schema_from_text: SchemaFromTextExtractor,
) -> None:
    """With an empty constraint list, no property flag is modified."""
    no_constraints: list[dict[str, Any]] = []
    node_types: list[dict[str, Any]] = [
        dict(
            label="Person",
            properties=[dict(name="name", type="STRING", required=False)],
        )
    ]

    schema_from_text._enforce_required_for_constraint_properties(
        node_types, no_constraints
    )

    (only_property,) = node_types[0]["properties"]
    assert only_property["required"] is False
1658+
1659+
1660+
def test_enforce_required_for_constraint_properties_skips_unconstrained_nodes(
    schema_from_text: SchemaFromTextExtractor,
) -> None:
    """Only the node type named by the constraint is touched."""
    person: dict[str, Any] = {
        "label": "Person",
        "properties": [{"name": "name", "type": "STRING", "required": False}],
    }
    company: dict[str, Any] = {
        "label": "Company",
        "properties": [{"name": "name", "type": "STRING", "required": False}],
    }
    node_types: list[dict[str, Any]] = [person, company]
    person_name_unique = {
        "type": "UNIQUENESS",
        "node_type": "Person",
        "property_name": "name",
    }

    schema_from_text._enforce_required_for_constraint_properties(
        node_types, [person_name_unique]
    )

    person_name, company_name = (nt["properties"][0] for nt in node_types)
    # Person.name carries the constraint, so it is now required.
    assert person_name["required"] is True
    # Company.name shares the property name but has no constraint on Company.
    assert company_name["required"] is False
1689+
1690+
15701691
@pytest.mark.asyncio
15711692
async def test_schema_from_text_with_required_properties(
15721693
schema_from_text: SchemaFromTextExtractor,
@@ -1638,6 +1759,45 @@ async def test_schema_from_text_handles_missing_required_field(
16381759
assert prop.required is False
16391760

16401761

1762+
@pytest.mark.asyncio
async def test_schema_from_text_enforces_required_for_constrained_properties(
    schema_from_text: SchemaFromTextExtractor,
    mock_llm: AsyncMock,
) -> None:
    """End-to-end: run() promotes a UNIQUENESS-constrained property to required."""
    # LLM reply marks both properties optional while constraining Person.name.
    llm_reply = """
    {
        "node_types": [
            {
                "label": "Person",
                "properties": [
                    {"name": "name", "type": "STRING", "required": false},
                    {"name": "email", "type": "STRING", "required": false}
                ]
            }
        ],
        "relationship_types": [],
        "patterns": [],
        "constraints": [
            {"type": "UNIQUENESS", "node_type": "Person", "property_name": "name"}
        ]
    }
    """
    mock_llm.ainvoke.return_value = LLMResponse(content=llm_reply)

    extracted = await schema_from_text.run(text="Sample text")

    person_type = extracted.node_type_from_label("Person")
    assert person_type is not None

    def first_property_named(wanted: str) -> Any:
        # First property with the given name, or None when absent.
        for candidate in person_type.properties:
            if candidate.name == wanted:
                return candidate
        return None

    name_property = first_property_named("name")
    email_property = first_property_named("email")

    # name should be auto-fixed to required=true
    assert name_property is not None
    assert name_property.required is True
    # email should remain required=false
    assert email_property is not None
    assert email_property.required is False
1799+
1800+
16411801
@pytest.mark.asyncio
16421802
@patch("neo4j_graphrag.experimental.components.schema.get_structured_schema")
16431803
async def test_schema_from_existing_graph(mock_get_structured_schema: Mock) -> None:

0 commit comments

Comments
 (0)