Skip to content
8 changes: 5 additions & 3 deletions graphrag_sdk/attribute.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,15 +57,15 @@ class Attribute:
"""

def __init__(
self, name: str, attr_type: AttributeType, unique: bool, required: bool = False
self, name: str, attr_type: AttributeType, unique: bool = False, required: bool = False
):
"""
Initialize a new Attribute object.

Args:
name (str): The name of the attribute.
attr_type (AttributeType): The type of the attribute.
unique (bool): Indicates whether the attribute should be unique.
unique (bool, optional): Indicates whether the attribute should be unique. Defaults to False.
required (bool, optional): Indicates whether the attribute is required. Defaults to False.
"""
self.name = re.sub(r"([^a-zA-Z0-9_])", "_", name)
Expand Down Expand Up @@ -142,13 +142,15 @@ def to_json(self):
- "unique": A boolean indicating whether the attribute is unique.
- "required": A boolean indicating whether the attribute is required.
"""
return {
json_data = {
"name": self.name,
"type": self.type,
"unique": self.unique,
"required": self.required,
}

return json_data

def __str__(self) -> str:
"""
Returns a string representation of the Attribute object.
Expand Down
37 changes: 34 additions & 3 deletions graphrag_sdk/chat_session.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
from falkordb import Graph
from graphrag_sdk.ontology import Ontology
from graphrag_sdk.steps.qa_step import QAStep
Expand Down Expand Up @@ -45,8 +46,11 @@ def __init__(self, model_config: KnowledgeGraphModelConfig, ontology: Ontology,
self.model_config = model_config
self.graph = graph
self.ontology = ontology
cypher_system_instruction = cypher_system_instruction.format(ontology=str(ontology.to_json()))


# Filter the ontology to remove unique and required attributes that are not needed for Q&A.
ontology_prompt = self.clean_ontology_for_prompt(ontology)

cypher_system_instruction = cypher_system_instruction.format(ontology=ontology_prompt)

self.cypher_prompt = cypher_gen_prompt
self.qa_prompt = qa_prompt
Expand Down Expand Up @@ -108,4 +112,31 @@ def send_message(self, message: str):
"response": answer,
"context": context,
"cypher": cypher
}
}

def clean_ontology_for_prompt(self, ontology: dict) -> str:
"""
Cleans the ontology by removing 'unique' and 'required' keys and prepares it for use in a prompt.

Args:
ontology (dict): The ontology to clean and transform.

Returns:
str: The cleaned ontology as a JSON string.
"""
# Convert the ontology object to a JSON.
ontology = ontology.to_json()

# Remove unique and required attributes from the ontology.
for entity in ontology["entities"]:
for attribute in entity["attributes"]:
del attribute['unique']
del attribute['required']

for relation in ontology["relations"]:
for attribute in relation["attributes"]:
del attribute['unique']
del attribute['required']

# Return the transformed ontology as a JSON string
return json.dumps(ontology)
2 changes: 1 addition & 1 deletion graphrag_sdk/entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def from_json(txt: dict | str):
txt.get("description", ""),
)

def to_json(self):
def to_json(self) -> dict:
"""
Convert the entity object to a JSON representation.

Expand Down
14 changes: 7 additions & 7 deletions graphrag_sdk/kg.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,21 +58,21 @@ def __init__(
if not isinstance(name, str) or name == "":
raise Exception("name should be a non empty string")

# connect to database
# Connect to database
self.db = FalkorDB(host=host, port=port, username=username, password=password)
self.graph = self.db.select_graph(name)
ontology_graph = self.db.select_graph("{" + name + "}" + "_schema")

# Load / Save ontology to database
if ontology is None:
# load ontology from DB
ontology = Ontology.from_graph(ontology_graph)
# Load ontology from DB
ontology = Ontology.from_schema_graph(ontology_graph)

if len(ontology.entities) == 0 and len(ontology.relations) == 0:
raise Exception("The ontology is empty. Load a valid ontology or create one using the ontology module.")
else:
# save ontology to DB
# Save ontology to DB
ontology.save_to_graph(ontology_graph)

if ontology is None:
raise Exception("Ontology is not defined")

self._ontology = ontology
self._name = name
Expand Down
90 changes: 68 additions & 22 deletions graphrag_sdk/ontology.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
import json
from falkordb import Graph
from graphrag_sdk.source import AbstractSource
from graphrag_sdk.models import GenerativeModel
import graphrag_sdk
import logging
from .relation import Relation
import graphrag_sdk
from .entity import Entity
from falkordb import Graph
from typing import Optional
from .relation import Relation
from .attribute import Attribute
from graphrag_sdk.source import AbstractSource
from graphrag_sdk.models import GenerativeModel


logger = logging.getLogger(__name__)


class Ontology(object):
"""
Represents an ontology, which is a collection of entities and relations.
Expand All @@ -25,7 +25,7 @@ def __init__(self, entities: list[Entity] = None, relations: list[Relation] = No
"""
Initialize the Ontology class.

Parameters:
Args:
entities (list[Entity], optional): List of Entity objects. Defaults to None.
relations (list[Relation], optional): List of Relation objects. Defaults to None.
"""
Expand All @@ -42,9 +42,9 @@ def from_sources(
"""
Create an Ontology object from a list of sources.

Parameters:
Args:
sources (list[AbstractSource]): A list of AbstractSource objects representing the sources.
boundaries (Optinal[str]): The boundaries for the ontology.
boundaries (Optional[str]): The boundaries for the ontology.
model (GenerativeModel): The generative model to use.
hide_progress (bool): Whether to hide the progress bar.

Expand All @@ -65,7 +65,7 @@ def from_json(txt: dict | str):
"""
Creates an Ontology object from a JSON representation.

Parameters:
Args:
txt (dict | str): The JSON representation of the ontology. It can be either a dictionary or a string.

Returns:
Expand All @@ -81,11 +81,11 @@ def from_json(txt: dict | str):
)

@staticmethod
def from_graph(graph: Graph):
def from_schema_graph(graph: Graph):
"""
Creates an Ontology object from a given graph.
Creates an Ontology object from a given schema graph.

Parameters:
Args:
graph (Graph): The graph object representing the ontology.

Returns:
Expand All @@ -103,12 +103,58 @@ def from_graph(graph: Graph):
)

return ontology

@staticmethod
def from_kg_graph(graph: Graph, sample_size: int = 100,):
"""
Constructs an Ontology object from a given Knowledge Graph.

This function queries the provided knowledge graph to extract:
1. Entities and their attributes.
2. Relationships between entities and their attributes.

Args:
graph (Graph): The graph object representing the knowledge graph.
sample_size (int): The sample size for the attribute extraction.

Returns:
Ontology: The Ontology object constructed from the Knowledge Graph.
"""
ontology = Ontology()

# Retrieve all node labels and edge types from the graph.
n_labels = graph.call_procedure("db.labels").result_set
e_types = graph.call_procedure("db.relationshipTypes").result_set

# Extract attributes for each node label, limited by the specified sample size.
for lbls in n_labels:
l = lbls[0]
attributes = graph.query(
f"""MATCH (a:{l}) call {{ with a return [k in keys(a) | [k, typeof(a[k])]] as types }}
WITH types limit {sample_size} unwind types as kt RETURN kt, count(1)""").result_set
ontology.add_entity(Entity(l, [Attribute(attr[0][0], attr[0][1]) for attr in attributes]))

# Extract attributes for each edge type, limited by the specified sample size.
for e_type in e_types:
for s_lbls in n_labels:
for t_lbls in n_labels:
e_t = e_type[0]
s_l = s_lbls[0]
t_l = t_lbls[0]
# Check if a relationship exists between the source and target entity labels
if graph.query(f"MATCH (s:{s_l})-[a:{e_t}]->(t:{t_l}) return a limit 1").result_set:
attributes = graph.query(
f"""MATCH ()-[a:{e_t}]->() call {{ with a return [k in keys(a) | [k, typeof(a[k])]] as types }}
WITH types limit {sample_size} unwind types as kt RETURN kt, count(1)""").result_set
ontology.add_relation(Relation(e_t, s_l, t_l, [Attribute(attr[0][0], attr[0][1]) for attr in attributes]))

return ontology

def add_entity(self, entity: Entity):
"""
Adds an entity to the ontology.

Parameters:
Args:
entity: The entity object to be added.
"""
self.entities.append(entity)
Expand All @@ -117,12 +163,12 @@ def add_relation(self, relation: Relation):
"""
Adds a relation to the ontology.

Parameters:
Args:
relation (Relation): The relation to be added.
"""
self.relations.append(relation)

def to_json(self):
def to_json(self) -> dict:
"""
Converts the ontology object to a JSON representation.

Expand All @@ -138,7 +184,7 @@ def merge_with(self, o: "Ontology"):
"""
Merges the given ontology `o` with the current ontology.

Parameters:
Args:
o (Ontology): The ontology to merge with.

Returns:
Expand Down Expand Up @@ -259,7 +305,7 @@ def get_entity_with_label(self, label: str):
"""
Retrieves the entity with the specified label.

Parameters:
Args:
label (str): The label of the entity to retrieve.

Returns:
Expand All @@ -271,7 +317,7 @@ def get_relations_with_label(self, label: str):
"""
Returns a list of relations with the specified label.

Parameters:
Args:
label (str): The label to search for.

Returns:
Expand All @@ -283,7 +329,7 @@ def has_entity_with_label(self, label: str):
"""
Checks if the ontology has an entity with the given label.

Parameters:
Args:
label (str): The label to search for.

Returns:
Expand All @@ -295,7 +341,7 @@ def has_relation_with_label(self, label: str):
"""
Checks if the ontology has a relation with the given label.

Parameters:
Args:
label (str): The label of the relation to check.

Returns:
Expand All @@ -321,7 +367,7 @@ def save_to_graph(self, graph: Graph):
"""
Saves the entities and relations to the specified graph.

Parameters:
Args:
graph (Graph): The graph to save the entities and relations to.
"""
for entity in self.entities:
Expand Down
2 changes: 1 addition & 1 deletion graphrag_sdk/relation.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def from_string(txt: str):
[Attribute.from_string(attr) for attr in attributes],
)

def to_json(self):
def to_json(self) -> dict:
"""
Converts the Relation object to a JSON dictionary.

Expand Down