diff --git a/CHANGELOG.md b/CHANGELOG.md index cf0076495..f82e76854 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,7 @@ ### Added - Introduced optional lexical graph configuration for SimpleKGPipeline, enhancing flexibility in customizing node labels and relationship types in the lexical graph. +- Ability to provide description and list of properties for entities and relations in the SimpleKGPipeline constructor. ## 1.2.0 diff --git a/examples/build_graph/simple_kg_builder_from_text.py b/examples/build_graph/simple_kg_builder_from_text.py index ced91ae90..d8d83ed72 100644 --- a/examples/build_graph/simple_kg_builder_from_text.py +++ b/examples/build_graph/simple_kg_builder_from_text.py @@ -11,6 +11,10 @@ from neo4j_graphrag.embeddings import OpenAIEmbeddings from neo4j_graphrag.experimental.pipeline.kg_builder import SimpleKGPipeline from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult +from neo4j_graphrag.experimental.pipeline.types import ( + EntityInputType, + RelationInputType, +) from neo4j_graphrag.llm import LLMInterface from neo4j_graphrag.llm.openai_llm import OpenAILLM @@ -21,12 +25,28 @@ # Text to process TEXT = """The son of Duke Leto Atreides and the Lady Jessica, Paul is the heir of House Atreides, -an aristocratic family that rules the planet Caladan.""" +an aristocratic family that rules the planet Caladan, the rainy planet, since 10191.""" # Instantiate Entity and Relation objects. This defines the # entities and relations the LLM will be looking for in the text. -ENTITIES = ["Person", "House", "Planet"] -RELATIONS = ["PARENT_OF", "HEIR_OF", "RULES"] +ENTITIES: list[EntityInputType] = [ + # entities can be defined with a simple label... + "Person", + # ... or with a dict if more details are needed, + # such as a description: + {"label": "House", "description": "Family the person belongs to"}, + # or a list of properties the LLM will try to attach to the entity: + {"label": "Planet", "properties": [{"name": "weather", "type": "STRING"}]}, +] +# same thing for relationships: +RELATIONS: list[RelationInputType] = [ + "PARENT_OF", + { + "label": "HEIR_OF", + "description": "Used for inheritor relationship between father and sons", + }, + {"label": "RULES", "properties": [{"name": "fromYear", "type": "INTEGER"}]}, +] POTENTIAL_SCHEMA = [ ("Person", "PARENT_OF", "Person"), ("Person", "HEIR_OF", "House"), diff --git a/src/neo4j_graphrag/experimental/pipeline/kg_builder.py b/src/neo4j_graphrag/experimental/pipeline/kg_builder.py index 70ab0f0a7..33bf6849f 100644 --- a/src/neo4j_graphrag/experimental/pipeline/kg_builder.py +++ b/src/neo4j_graphrag/experimental/pipeline/kg_builder.py @@ -15,7 +15,7 @@ from __future__ import annotations -from typing import Any, List, Optional, Union +from typing import Any, List, Optional, Sequence, Union import neo4j from pydantic import BaseModel, ConfigDict, Field @@ -42,6 +42,10 @@ from neo4j_graphrag.experimental.components.types import LexicalGraphConfig from neo4j_graphrag.experimental.pipeline.exceptions import PipelineDefinitionError from neo4j_graphrag.experimental.pipeline.pipeline import Pipeline, PipelineResult +from neo4j_graphrag.experimental.pipeline.types import ( + EntityInputType, + RelationInputType, +) from neo4j_graphrag.generation.prompts import ERExtractionTemplate from neo4j_graphrag.llm.base import LLMInterface @@ -74,8 +78,16 @@ class SimpleKGPipeline: llm (LLMInterface): An instance of an LLM to use for entity and relation extraction. driver (neo4j.Driver): A Neo4j driver instance for database connection. embedder (Embedder): An instance of an embedder used to generate chunk embeddings from text chunks. - entities (Optional[List[str]]): A list of entity labels as strings. - relations (Optional[List[str]]): A list of relation labels as strings. + entities (Optional[List[Union[str, dict[str, str], SchemaEntity]]]): A list of either: + + - str: entity labels + - dict: following the SchemaEntity schema, ie with label, description and properties keys + + relations (Optional[List[Union[str, dict[str, str], SchemaRelation]]]): A list of either: + + - str: relation label + - dict: following the SchemaRelation schema, ie with label, description and properties keys + potential_schema (Optional[List[tuple]]): A list of potential schema relationships. from_pdf (bool): Determines whether to include the PdfLoader in the pipeline. If True, expects `file_path` input in `run` methods. @@ -94,8 +106,8 @@ def __init__( llm: LLMInterface, driver: neo4j.Driver, embedder: Embedder, - entities: Optional[List[str]] = None, - relations: Optional[List[str]] = None, + entities: Optional[Sequence[EntityInputType]] = None, + relations: Optional[Sequence[RelationInputType]] = None, potential_schema: Optional[List[tuple[str, str, str]]] = None, from_pdf: bool = True, text_splitter: Optional[Any] = None, @@ -106,9 +118,9 @@ def __init__( perform_entity_resolution: bool = True, lexical_graph_config: Optional[LexicalGraphConfig] = None, ): - self.entities = [SchemaEntity(label=label) for label in entities or []] - self.relations = [SchemaRelation(label=label) for label in relations or []] - self.potential_schema = potential_schema if potential_schema is not None else [] + self.potential_schema = potential_schema or [] + self.entities = [self.to_schema_entity(e) for e in entities or []] + self.relations = [self.to_schema_relation(r) for r in relations or []] try: on_error_enum = OnError(on_error) @@ -150,6 +162,18 @@ def __init__( self.pipeline = self._build_pipeline() + @staticmethod + def to_schema_entity(entity: EntityInputType) -> SchemaEntity: + if isinstance(entity, dict): + return SchemaEntity.model_validate(entity) + return SchemaEntity(label=entity) + + @staticmethod + def to_schema_relation(relation: RelationInputType) -> SchemaRelation: + if isinstance(relation, dict): + return SchemaRelation.model_validate(relation) + return SchemaRelation(label=relation) + def _build_pipeline(self) -> Pipeline: pipe = Pipeline() diff --git a/src/neo4j_graphrag/experimental/pipeline/types.py b/src/neo4j_graphrag/experimental/pipeline/types.py index e5ec66b9b..ebdf141d7 100644 --- a/src/neo4j_graphrag/experimental/pipeline/types.py +++ b/src/neo4j_graphrag/experimental/pipeline/types.py @@ -14,6 +14,8 @@ # limitations under the License. from __future__ import annotations +from typing import Union + from pydantic import BaseModel, ConfigDict from neo4j_graphrag.experimental.pipeline.component import Component @@ -35,3 +37,12 @@ class ConnectionConfig(BaseModel): class PipelineConfig(BaseModel): components: list[ComponentConfig] connections: list[ConnectionConfig] + + +EntityInputType = Union[str, dict[str, Union[str, list[dict[str, str]]]]] +RelationInputType = Union[str, dict[str, Union[str, list[dict[str, str]]]]] +"""Types derived from the SchemaEntity and SchemaRelation types, + so the possible types for dict values are: +- str (for label and description) +- list[dict[str, str]] (for properties) +""" diff --git a/tests/unit/experimental/pipeline/test_kg_builder.py b/tests/unit/experimental/pipeline/test_kg_builder.py index 7c0788eaa..b1b29151d 100644 --- a/tests/unit/experimental/pipeline/test_kg_builder.py +++ b/tests/unit/experimental/pipeline/test_kg_builder.py @@ -19,7 +19,11 @@ import pytest from neo4j_graphrag.embeddings import Embedder from neo4j_graphrag.experimental.components.entity_relation_extractor import OnError -from neo4j_graphrag.experimental.components.schema import SchemaEntity, SchemaRelation +from neo4j_graphrag.experimental.components.schema import ( + SchemaEntity, + SchemaProperty, + SchemaRelation, +) from neo4j_graphrag.experimental.components.types import LexicalGraphConfig from neo4j_graphrag.experimental.pipeline.exceptions import PipelineDefinitionError from neo4j_graphrag.experimental.pipeline.kg_builder import SimpleKGPipeline @@ -379,3 +383,47 @@ async def test_knowledge_graph_builder_with_lexical_graph_config(_: Mock) -> Non assert pipe_inputs["extractor"] == { "lexical_graph_config": lexical_graph_config } + + +def test_knowledge_graph_builder_to_schema_entity_method() -> None: + assert SimpleKGPipeline.to_schema_entity("EntityType") == SchemaEntity( + label="EntityType" + ) + assert SimpleKGPipeline.to_schema_entity({"label": "EntityType"}) == SchemaEntity( + label="EntityType" + ) + assert SimpleKGPipeline.to_schema_entity( + {"label": "EntityType", "description": "A special entity"} + ) == SchemaEntity(label="EntityType", description="A special entity") + assert SimpleKGPipeline.to_schema_entity( + {"label": "EntityType", "properties": []} + ) == SchemaEntity(label="EntityType") + assert SimpleKGPipeline.to_schema_entity( + { + "label": "EntityType", + "properties": [{"name": "entityProperty", "type": "DATE"}], + } + ) == SchemaEntity( + label="EntityType", + properties=[SchemaProperty(name="entityProperty", type="DATE")], + ) + + +def test_knowledge_graph_builder_to_schema_relation_method() -> None: + assert SimpleKGPipeline.to_schema_relation("REL_TYPE") == SchemaRelation( + label="REL_TYPE" + ) + assert SimpleKGPipeline.to_schema_relation({"label": "REL_TYPE"}) == SchemaRelation( + label="REL_TYPE" + ) + assert SimpleKGPipeline.to_schema_relation( + {"label": "REL_TYPE", "description": "A rel type"} + ) == SchemaRelation(label="REL_TYPE", description="A rel type") + assert SimpleKGPipeline.to_schema_relation( + {"label": "REL_TYPE", "properties": []} + ) == SchemaRelation(label="REL_TYPE") + assert SimpleKGPipeline.to_schema_relation( + {"label": "REL_TYPE", "properties": [{"name": "relProperty", "type": "DATE"}]} + ) == SchemaRelation( + label="REL_TYPE", properties=[SchemaProperty(name="relProperty", type="DATE")] + )