Skip to content
Merged
59 changes: 34 additions & 25 deletions graphrag_sdk/attribute.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import re
import json
from graphrag_sdk.fixtures.regex import *
import logging
import re
from graphrag_sdk.fixtures.regex import *

logger = logging.getLogger(__name__)

Expand All @@ -15,6 +15,23 @@ class AttributeType:
NUMBER = "number"
BOOLEAN = "boolean"
LIST = "list"
POINT = "point"
MAP = "map"
VECTOR = "vectorf32"


# Synonyms for attribute types
_SYNONYMS = {
"string": STRING,
"integer": NUMBER,
"float": NUMBER,
"number": NUMBER,
"boolean": BOOLEAN,
"list": LIST,
"point": POINT,
"map": MAP,
"vectorf32": VECTOR,
}

@staticmethod
def from_string(txt: str):
Expand All @@ -25,21 +42,19 @@ def from_string(txt: str):
txt (str): The string representation of the attribute type.

Returns:
AttributeType: The corresponding AttributeType value.
str: The corresponding AttributeType value.

Raises:
Exception: If the provided attribute type is invalid.
ValueError: If the provided attribute type is invalid.
"""
if txt.lower() == AttributeType.STRING:
return AttributeType.STRING
if txt.lower() == AttributeType.NUMBER:
return AttributeType.NUMBER
if txt.lower() == AttributeType.BOOLEAN:
return AttributeType.BOOLEAN
if txt.lower() == AttributeType.LIST:
return AttributeType.LIST
raise Exception(f"Invalid attribute type: {txt}")

# Graph representation of the attribute type
normalized_txt = txt.lower()

# Find the matching attribute type
if normalized_txt in AttributeType._SYNONYMS:
return AttributeType._SYNONYMS[normalized_txt]

raise ValueError(f"Invalid attribute type: {txt}")

class Attribute:
""" Represents an attribute of an entity or relation in the ontology.
Expand All @@ -57,15 +72,15 @@ class Attribute:
"""

def __init__(
self, name: str, attr_type: AttributeType, unique: bool, required: bool = False
self, name: str, attr_type: AttributeType, unique: bool = False, required: bool = False
):
"""
Initialize a new Attribute object.

Args:
name (str): The name of the attribute.
attr_type (AttributeType): The type of the attribute.
unique (bool): Indicates whether the attribute should be unique.
unique (bool, optional): Indicates whether the attribute should be unique. Defaults to False.
required (bool, optional): Indicates whether the attribute is required. Defaults to False.
"""
self.name = re.sub(r"([^a-zA-Z0-9_])", "_", name)
Expand Down Expand Up @@ -120,14 +135,6 @@ def from_string(txt: str):
unique = "!" in txt
required = "*" in txt

if attr_type not in [
AttributeType.STRING,
AttributeType.NUMBER,
AttributeType.BOOLEAN,
AttributeType.LIST,
]:
raise Exception(f"Invalid attribute type: {attr_type}")

return Attribute(name, AttributeType.from_string(attr_type), unique, required)

def to_json(self):
Expand All @@ -142,13 +149,15 @@ def to_json(self):
- "unique": A boolean indicating whether the attribute is unique.
- "required": A boolean indicating whether the attribute is required.
"""
return {
json_data = {
"name": self.name,
"type": self.type,
"unique": self.unique,
"required": self.required,
}

return json_data

def __str__(self) -> str:
"""
Returns a string representation of the Attribute object.
Expand Down
37 changes: 34 additions & 3 deletions graphrag_sdk/chat_session.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
from falkordb import Graph
from graphrag_sdk.ontology import Ontology
from graphrag_sdk.steps.qa_step import QAStep
Expand Down Expand Up @@ -45,8 +46,11 @@ def __init__(self, model_config: KnowledgeGraphModelConfig, ontology: Ontology,
self.model_config = model_config
self.graph = graph
self.ontology = ontology
cypher_system_instruction = cypher_system_instruction.format(ontology=str(ontology.to_json()))


# Filter the ontology to remove unique and required attributes that are not needed for Q&A.
ontology_prompt = self.clean_ontology_for_prompt(ontology)

cypher_system_instruction = cypher_system_instruction.format(ontology=ontology_prompt)

self.cypher_prompt = cypher_gen_prompt
self.qa_prompt = qa_prompt
Expand Down Expand Up @@ -108,4 +112,31 @@ def send_message(self, message: str):
"response": answer,
"context": context,
"cypher": cypher
}
}

def clean_ontology_for_prompt(self, ontology: dict) -> str:
"""
Cleans the ontology by removing 'unique' and 'required' keys and prepares it for use in a prompt.

Args:
ontology (dict): The ontology to clean and transform.

Returns:
str: The cleaned ontology as a JSON string.
"""
# Convert the ontology object to a JSON.
ontology = ontology.to_json()

# Remove unique and required attributes from the ontology.
for entity in ontology["entities"]:
for attribute in entity["attributes"]:
del attribute['unique']
del attribute['required']

for relation in ontology["relations"]:
for attribute in relation["attributes"]:
del attribute['unique']
del attribute['required']

# Return the transformed ontology as a JSON string
return json.dumps(ontology)
2 changes: 1 addition & 1 deletion graphrag_sdk/entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def from_json(txt: dict | str):
txt.get("description", ""),
)

def to_json(self):
def to_json(self) -> dict:
"""
Convert the entity object to a JSON representation.

Expand Down
14 changes: 7 additions & 7 deletions graphrag_sdk/kg.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,21 +58,21 @@ def __init__(
if not isinstance(name, str) or name == "":
raise Exception("name should be a non empty string")

# connect to database
# Connect to database
self.db = FalkorDB(host=host, port=port, username=username, password=password)
self.graph = self.db.select_graph(name)
ontology_graph = self.db.select_graph("{" + name + "}" + "_schema")

# Load / Save ontology to database
if ontology is None:
# load ontology from DB
ontology = Ontology.from_graph(ontology_graph)
# Load ontology from DB
ontology = Ontology.from_schema_graph(ontology_graph)

if len(ontology.entities) == 0:
raise Exception("The ontology is empty. Load a valid ontology or create one using the ontology module.")
else:
# save ontology to DB
# Save ontology to DB
ontology.save_to_graph(ontology_graph)

if ontology is None:
raise Exception("Ontology is not defined")

self._ontology = ontology
self._name = name
Expand Down
Loading
Loading