Skip to content
51 changes: 29 additions & 22 deletions graphrag_sdk/attribute.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import re
import json
from graphrag_sdk.fixtures.regex import *
import logging
import re
from graphrag_sdk.fixtures.regex import *

logger = logging.getLogger(__name__)

Expand All @@ -15,6 +15,23 @@ class AttributeType:
NUMBER = "number"
BOOLEAN = "boolean"
LIST = "list"
POINT = "point"
MAP = "map"
VECTOR = "vectorf32"


# Lookup table from a lower-cased type name to its canonical AttributeType
# value. from_string() lower-cases its input before consulting this table,
# so every key here must be lowercase.
# "integer" and "float" both collapse to NUMBER, which has no separate
# integer/float variants in this table.
_SYNONYMS = {
"string": STRING,
"integer": NUMBER,
"float": NUMBER,
"number": NUMBER,
"boolean": BOOLEAN,
"list": LIST,
"point": POINT,
"map": MAP,
"vectorf32": VECTOR,
}

@staticmethod
def from_string(txt: str):
Expand All @@ -25,21 +42,19 @@ def from_string(txt: str):
txt (str): The string representation of the attribute type.

Returns:
AttributeType: The corresponding AttributeType value.
str: The corresponding AttributeType value.

Raises:
Exception: If the provided attribute type is invalid.
ValueError: If the provided attribute type is invalid.
"""
if txt.lower() == AttributeType.STRING:
return AttributeType.STRING
if txt.lower() == AttributeType.NUMBER:
return AttributeType.NUMBER
if txt.lower() == AttributeType.BOOLEAN:
return AttributeType.BOOLEAN
if txt.lower() == AttributeType.LIST:
return AttributeType.LIST
raise Exception(f"Invalid attribute type: {txt}")

# Graph representation of the attribute type
normalized_txt = txt.lower()

# Find the matching attribute type
if normalized_txt in AttributeType._SYNONYMS:
return AttributeType._SYNONYMS[normalized_txt]

raise ValueError(f"Invalid attribute type: {txt}")

class Attribute:
""" Represents an attribute of an entity or relation in the ontology.
Expand Down Expand Up @@ -120,14 +135,6 @@ def from_string(txt: str):
unique = "!" in txt
required = "*" in txt

if attr_type not in [
AttributeType.STRING,
AttributeType.NUMBER,
AttributeType.BOOLEAN,
AttributeType.LIST,
]:
raise Exception(f"Invalid attribute type: {attr_type}")

return Attribute(name, AttributeType.from_string(attr_type), unique, required)

def to_json(self):
Expand Down
37 changes: 31 additions & 6 deletions graphrag_sdk/ontology.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
from falkordb import Graph
from typing import Optional
from .relation import Relation
from .attribute import Attribute
from graphrag_sdk.source import AbstractSource
from graphrag_sdk.models import GenerativeModel
from .attribute import Attribute, AttributeType



logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -130,17 +131,18 @@ def from_kg_graph(graph: Graph, sample_size: int = 100,):
for lbls in n_labels:
l = lbls[0]
attributes = graph.query(
f"""MATCH (a:{l}) call {{ with a return [k in keys(a) | [k, typeof(a[k])]] as types }}
f"""MATCH (a:{l}) call {{ with a return [k in keys(a) | [k, typeof(a[k])]] as types }}
WITH types limit {sample_size} unwind types as kt RETURN kt, count(1) ORDER BY kt[0]""").result_set
ontology.add_entity(Entity(l, [Attribute(attr[0][0], 'number' if attr[0][1] == 'Integer' or attr[0][1] == 'Float' else attr[0][1].lower()) for attr in attributes]))
attributes = ontology.process_attributes_from_graph(attributes)
ontology.add_entity(Entity(l, attributes))

# Extract attributes for each edge type, limited by the specified sample size.
for e_type in e_types:
e_t = e_type[0]
attributes = graph.query(
f"""MATCH ()-[a:{e_t}]->() call {{ with a return [k in keys(a) | [k, typeof(a[k])]] as types }}
f"""MATCH ()-[a:{e_t}]->() call {{ with a return [k in keys(a) | [k, typeof(a[k])]] as types }}
WITH types limit {sample_size} unwind types as kt RETURN kt, count(1) ORDER BY kt[0]""").result_set
attributes = [Attribute(attr[0][0], 'number' if attr[0][1] == 'Integer' or attr[0][1] == 'Float' else attr[0][1].lower()) for attr in attributes]
attributes = ontology.process_attributes_from_graph(attributes)
for s_lbls in n_labels:
for t_lbls in n_labels:
s_l = s_lbls[0]
Expand All @@ -152,6 +154,29 @@ def from_kg_graph(graph: Graph, sample_size: int = 100,):

return ontology

@staticmethod
def process_attributes_from_graph(attributes: list[list[list[str]]]) -> list[Attribute]:
    """
    Processes the attributes extracted from the graph and converts them into the SDK convention.

    Each row's first column is unpacked as a ``[name, type_name]`` pair; any
    further columns (such as the count returned by the sampling query) are
    ignored. The type name is normalized through ``AttributeType.from_string``.
    Rows whose type cannot be mapped are skipped rather than aborting the
    whole import, so a graph holding an unsupported property type still
    yields an ontology for the remaining attributes.

    Args:
        attributes (list[list[list[str]]]): The attributes extracted from the graph.

    Returns:
        processed_attributes (list[Attribute]): The processed attributes, in input order.
    """
    processed_attributes = []
    for attr in attributes:
        attr_name, raw_type = attr[0]
        try:
            attr_type = AttributeType.from_string(raw_type)
        except Exception as exc:
            # Deliberately best-effort: skip types the SDK does not know.
            # `Exception` (instead of a bare `except`) keeps KeyboardInterrupt
            # and SystemExit propagating, and the skip is logged so dropped
            # attributes can be diagnosed.
            logger.debug(
                "Skipping attribute %r with unsupported type %r: %s",
                attr_name,
                raw_type,
                exc,
            )
            continue

        processed_attributes.append(Attribute(attr_name, attr_type))

    return processed_attributes

def add_entity(self, entity: Entity):
"""
Adds an entity to the ontology.
Expand Down Expand Up @@ -380,4 +405,4 @@ def save_to_graph(self, graph: Graph):
for relation in self.relations:
query = relation.to_graph_query()
logger.debug(f"Query: {query}")
graph.query(query)
graph.query(query)
Loading