Skip to content
75 changes: 54 additions & 21 deletions graphrag_sdk/attribute.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import re
import json
from graphrag_sdk.fixtures.regex import *
import logging
import re
from graphrag_sdk.fixtures.regex import *

logger = logging.getLogger(__name__)

Expand All @@ -15,6 +15,28 @@ class AttributeType:
NUMBER = "number"
BOOLEAN = "boolean"
LIST = "list"
POINT = "point"
MAP = "map"
VECTOR = "vectorf32"
DATE = "date"
DATE_TIME = "datetime"
TIME = "time"
DURATION = "duration"

# Synonyms for attribute types
_SYNONYMS = {
STRING: {"string"},
NUMBER: {"integer", "float", "number"},
BOOLEAN: {"boolean"},
LIST: {"list"},
POINT: {"point"},
MAP: {"map"},
VECTOR: {"vectorf32"},
DATE: {"date"},
DATE_TIME: {"datetime", "local datetime"},
TIME: {"time", "local time"},
DURATION: {"duration"},
}

@staticmethod
def from_string(txt: str):
Expand All @@ -28,18 +50,17 @@ def from_string(txt: str):
AttributeType: The corresponding AttributeType value.

Raises:
Exception: If the provided attribute type is invalid.
ValueError: If the provided attribute type is invalid.
"""
if txt.lower() == AttributeType.STRING:
return AttributeType.STRING
if txt.lower() == AttributeType.NUMBER:
return AttributeType.NUMBER
if txt.lower() == AttributeType.BOOLEAN:
return AttributeType.BOOLEAN
if txt.lower() == AttributeType.LIST:
return AttributeType.LIST
raise Exception(f"Invalid attribute type: {txt}")

# Graph representation of the attribute type
normalized_txt = txt.lower()

# Find the matching attribute type
for attr_type, synonyms in AttributeType._SYNONYMS.items():
if normalized_txt in synonyms:
return attr_type

raise ValueError(f"Invalid attribute type: {txt}")

class Attribute:
""" Represents an attribute of an entity or relation in the ontology.
Expand Down Expand Up @@ -120,14 +141,6 @@ def from_string(txt: str):
unique = "!" in txt
required = "*" in txt

if attr_type not in [
AttributeType.STRING,
AttributeType.NUMBER,
AttributeType.BOOLEAN,
AttributeType.LIST,
]:
raise Exception(f"Invalid attribute type: {attr_type}")

return Attribute(name, AttributeType.from_string(attr_type), unique, required)

def to_json(self):
Expand Down Expand Up @@ -161,3 +174,23 @@ def __str__(self) -> str:
str: A string representation of the Attribute object.
"""
return f"{self.name}: \"{self.type}{'!' if self.unique else ''}{'*' if self.required else ''}\""

def process_attributes_from_graph(attributes: list[list[str]]) -> list[Attribute]:
"""
Processes the attributes extracted from the graph and converts them into the SDK convention.

Args:
attributes (list[list[str]]): The attributes extracted from the graph.

Returns:
processed_attributes (list[Attribute]): The processed attributes.
"""
processed_attributes = []
for attr in attributes:
try:
type = AttributeType.from_string(attr[0][1])
processed_attributes.append(Attribute(attr[0][0],type))
except:
continue

return processed_attributes
13 changes: 7 additions & 6 deletions graphrag_sdk/ontology.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from falkordb import Graph
from typing import Optional
from .relation import Relation
from .attribute import Attribute
from .attribute import process_attributes_from_graph
from graphrag_sdk.source import AbstractSource
from graphrag_sdk.models import GenerativeModel

Expand Down Expand Up @@ -130,17 +130,18 @@ def from_kg_graph(graph: Graph, sample_size: int = 100,):
for lbls in n_labels:
l = lbls[0]
attributes = graph.query(
f"""MATCH (a:{l}) call {{ with a return [k in keys(a) | [k, typeof(a[k])]] as types }}
f"""MATCH (a:{l}) call {{ with a return [k in keys(a) | [k, typeof(a[k])]] as types }}
WITH types limit {sample_size} unwind types as kt RETURN kt, count(1) ORDER BY kt[0]""").result_set
ontology.add_entity(Entity(l, [Attribute(attr[0][0], 'number' if attr[0][1] == 'Integer' or attr[0][1] == 'Float' else attr[0][1].lower()) for attr in attributes]))
attributes = process_attributes_from_graph(attributes)
ontology.add_entity(Entity(l, attributes))

# Extract attributes for each edge type, limited by the specified sample size.
for e_type in e_types:
e_t = e_type[0]
attributes = graph.query(
f"""MATCH ()-[a:{e_t}]->() call {{ with a return [k in keys(a) | [k, typeof(a[k])]] as types }}
f"""MATCH ()-[a:{e_t}]->() call {{ with a return [k in keys(a) | [k, typeof(a[k])]] as types }}
WITH types limit {sample_size} unwind types as kt RETURN kt, count(1) ORDER BY kt[0]""").result_set
attributes = [Attribute(attr[0][0], 'number' if attr[0][1] == 'Integer' or attr[0][1] == 'Float' else attr[0][1].lower()) for attr in attributes]
attributes = process_attributes_from_graph(attributes)
for s_lbls in n_labels:
for t_lbls in n_labels:
s_l = s_lbls[0]
Expand Down Expand Up @@ -380,4 +381,4 @@ def save_to_graph(self, graph: Graph):
for relation in self.relations:
query = relation.to_graph_query()
logger.debug(f"Query: {query}")
graph.query(query)
graph.query(query)
Loading