Skip to content
19 changes: 13 additions & 6 deletions graphrag_sdk/attribute.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,24 +130,31 @@ def from_string(txt: str):

return Attribute(name, AttributeType.from_string(attr_type), unique, required)

def to_json(self):
def to_json(self, include_all: bool = True):
"""
Converts the attribute object to a JSON representation.

Args:
include_all (bool): Whether to include both "unique" and "required" fields in the output. Default is True.

Returns:
dict: A dictionary representing the attribute object in JSON format.
The dictionary contains the following keys:
- "name": The name of the attribute.
- "type": The type of the attribute.
- "unique": A boolean indicating whether the attribute is unique.
- "required": A boolean indicating whether the attribute is required.
Optionally includes:
- "unique": A boolean indicating whether the attribute is unique (if include_all is True).
- "required": A boolean indicating whether the attribute is required (if include_all is True).
"""
return {
json_data = {
"name": self.name,
"type": self.type,
"unique": self.unique,
"required": self.required,
}
if include_all:
json_data["unique"] = self.unique
json_data["required"] = self.required

return json_data

def __str__(self) -> str:
"""
Expand Down
2 changes: 1 addition & 1 deletion graphrag_sdk/chat_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def __init__(self, model_config: KnowledgeGraphModelConfig, ontology: Ontology,
self.model_config = model_config
self.graph = graph
self.ontology = ontology
cypher_system_instruction = cypher_system_instruction.format(ontology=str(ontology.to_json()))
cypher_system_instruction = cypher_system_instruction.format(ontology=str(ontology.to_json(include_all=False)))


self.cypher_prompt = cypher_gen_prompt
Expand Down
4 changes: 2 additions & 2 deletions graphrag_sdk/entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def from_json(txt: dict | str):
txt.get("description", ""),
)

def to_json(self):
def to_json(self, include_all: bool = True) -> dict:
"""
Convert the entity object to a JSON representation.

Expand All @@ -95,7 +95,7 @@ def to_json(self):
"""
return {
"label": self.label,
"attributes": [attr.to_json() for attr in self.attributes],
"attributes": [attr.to_json(include_all=include_all) for attr in self.attributes],
"description": self.description,
}

Expand Down
2 changes: 1 addition & 1 deletion graphrag_sdk/kg.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def __init__(
# Load / Save ontology to database
if ontology is None:
# load ontology from DB
ontology = Ontology.from_graph(ontology_graph)
ontology = Ontology.from_schema_graph(ontology_graph)
else:
# save ontology to DB
ontology.save_to_graph(ontology_graph)
Expand Down
77 changes: 53 additions & 24 deletions graphrag_sdk/ontology.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
import json
from falkordb import Graph
from graphrag_sdk.source import AbstractSource
from graphrag_sdk.models import GenerativeModel
import graphrag_sdk
import logging
from .relation import Relation
import graphrag_sdk
from .entity import Entity
from falkordb import Graph
from typing import Optional
from falkordb import FalkorDB
from .relation import Relation
from .attribute import Attribute
from graphrag_sdk.source import AbstractSource
from graphrag_sdk.models import GenerativeModel


logger = logging.getLogger(__name__)


class Ontology(object):
"""
Represents an ontology, which is a collection of entities and relations.
Expand All @@ -25,7 +26,7 @@ def __init__(self, entities: list[Entity] = None, relations: list[Relation] = No
"""
Initialize the Ontology class.

Parameters:
Args:
entities (list[Entity], optional): List of Entity objects. Defaults to None.
relations (list[Relation], optional): List of Relation objects. Defaults to None.
"""
Expand All @@ -42,9 +43,9 @@ def from_sources(
"""
Create an Ontology object from a list of sources.

Parameters:
Args:
sources (list[AbstractSource]): A list of AbstractSource objects representing the sources.
boundaries (Optinal[str]): The boundaries for the ontology.
boundaries (Optional[str]): The boundaries for the ontology.
model (GenerativeModel): The generative model to use.
hide_progress (bool): Whether to hide the progress bar.

Expand All @@ -65,7 +66,7 @@ def from_json(txt: dict | str):
"""
Creates an Ontology object from a JSON representation.

Parameters:
Args:
txt (dict | str): The JSON representation of the ontology. It can be either a dictionary or a string.

Returns:
Expand All @@ -81,11 +82,11 @@ def from_json(txt: dict | str):
)

@staticmethod
def from_graph(graph: Graph):
def from_schema_graph(graph: Graph):
"""
Creates an Ontology object from a given graph.
Creates an Ontology object from a given schema graph.

Parameters:
Args:
graph (Graph): The graph object representing the ontology.

Returns:
Expand All @@ -103,12 +104,40 @@ def from_graph(graph: Graph):
)

return ontology

@staticmethod
def from_kg_graph(graph: Graph, node_limit: int = 100,):
"""
Creates an Ontology object from a given Knowledge Graph.

Args:
graph (Graph): The graph object representing the ontology.

Returns:
The Ontology object created from the Knowledge Graph.
"""
ontology = Ontology()

e_labels = graph.query("call db.labels()").result_set

for label in e_labels:
attributes = graph.query(f"MATCH (a:{label[0]}) call {{ with a return [k in keys(a) | [k, typeof(a[k])]] as types }} with types limit {node_limit} unwind types as kt return kt, count(1)").result_set
ontology.add_entity(Entity(label[0], [Attribute(attr[0][0], attr[0][1], False, False) for attr in attributes]))

r_labels = graph.query("call db.relationshipTypes()").result_set
for label in r_labels:
for label_s in e_labels:
for label_t in e_labels:
if graph.query(f"MATCH (s:{label_s[0]})-[a:{label[0]}]->(t:{label_t[0]}) return a limit 1").result_set:
attributes = graph.query(f"MATCH ()-[a:{label[0]}]->() call {{ with a return [k in keys(a) | [k, typeof(a[k])]] as types }} with types limit {node_limit} unwind types as kt return kt, count(1)").result_set
ontology.add_relation(Relation(label[0], label_s[0], label_t[0], [Attribute(attr[0][0], attr[0][1], False, False) for attr in attributes]))
return ontology

def add_entity(self, entity: Entity):
"""
Adds an entity to the ontology.

Parameters:
Args:
entity: The entity object to be added.
"""
self.entities.append(entity)
Expand All @@ -117,28 +146,28 @@ def add_relation(self, relation: Relation):
"""
Adds a relation to the ontology.

Parameters:
Args:
relation (Relation): The relation to be added.
"""
self.relations.append(relation)

def to_json(self):
def to_json(self, include_all: bool = True) -> dict:
"""
Converts the ontology object to a JSON representation.

Returns:
A dictionary representing the ontology object in JSON format.
"""
return {
"entities": [entity.to_json() for entity in self.entities],
"relations": [relation.to_json() for relation in self.relations],
"entities": [entity.to_json(include_all) for entity in self.entities],
"relations": [relation.to_json(include_all) for relation in self.relations],
}

def merge_with(self, o: "Ontology"):
"""
Merges the given ontology `o` with the current ontology.

Parameters:
Args:
o (Ontology): The ontology to merge with.

Returns:
Expand Down Expand Up @@ -259,7 +288,7 @@ def get_entity_with_label(self, label: str):
"""
Retrieves the entity with the specified label.

Parameters:
Args:
label (str): The label of the entity to retrieve.

Returns:
Expand All @@ -271,7 +300,7 @@ def get_relations_with_label(self, label: str):
"""
Returns a list of relations with the specified label.

Parameters:
Args:
label (str): The label to search for.

Returns:
Expand All @@ -283,7 +312,7 @@ def has_entity_with_label(self, label: str):
"""
Checks if the ontology has an entity with the given label.

Parameters:
Args:
label (str): The label to search for.

Returns:
Expand All @@ -295,7 +324,7 @@ def has_relation_with_label(self, label: str):
"""
Checks if the ontology has a relation with the given label.

Parameters:
Args:
label (str): The label of the relation to check.

Returns:
Expand All @@ -321,7 +350,7 @@ def save_to_graph(self, graph: Graph):
"""
Saves the entities and relations to the specified graph.

Parameters:
Args:
graph (Graph): The graph to save the entities and relations to.
"""
for entity in self.entities:
Expand Down
4 changes: 2 additions & 2 deletions graphrag_sdk/relation.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def from_string(txt: str):
[Attribute.from_string(attr) for attr in attributes],
)

def to_json(self):
def to_json(self, include_all: bool = True) -> dict:
"""
Converts the Relation object to a JSON dictionary.

Expand All @@ -216,7 +216,7 @@ def to_json(self):
"label": self.label,
"source": self.source.to_json(),
"target": self.target.to_json(),
"attributes": [attr.to_json() for attr in self.attributes],
"attributes": [attr.to_json(include_all=include_all) for attr in self.attributes],
}

def combine(self, relation2: "Relation"):
Expand Down