Skip to content

Commit a9efb6a

Browse files
committed
Add graph traversal
1 parent e377618 commit a9efb6a

File tree

3 files changed

+174
-7
lines changed

3 files changed

+174
-7
lines changed

deploy_ai_search/text_2_sql_schema_store.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -133,29 +133,29 @@ def get_index_fields(self) -> list[SearchableField]:
133133
# This is needed to enable semantic searching against the column names as complex field types are not used.
134134
),
135135
ComplexField(
136-
name="ImmediateEntityRelationships",
136+
name="EntityRelationships",
137137
collection=True,
138138
fields=[
139139
SearchableField(
140-
name="Name",
140+
name="ForeignEntity",
141141
type=SearchFieldDataType.String,
142142
),
143143
ComplexField(
144144
name="ForeignKeys",
145145
collection=True,
146146
fields=[
147147
SearchableField(
148-
name="SourceColumnName", type=SearchFieldDataType.String
148+
name="Column", type=SearchFieldDataType.String
149149
),
150150
SearchableField(
151-
name="TargetColumnName", type=SearchFieldDataType.String
151+
name="ForeignColumn", type=SearchFieldDataType.String
152152
),
153153
],
154154
),
155155
],
156156
),
157157
SimpleField(
158-
name="CompleteEntityRelationships",
158+
name="CompleteEntityRelationshipGraph",
159159
type=SearchFieldDataType.String,
160160
collection=True,
161161
),

text_2_sql/data_dictionary/data_dictionary_creator.py

Lines changed: 144 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,52 @@
1414
from azure.identity import DefaultAzureCredential, get_bearer_token_provider
1515
import random
1616
import re
17+
import networkx as nx
1718

1819
logging.basicConfig(level=logging.INFO)
1920

2021

22+
class ForeignKeyRelationship(BaseModel):
23+
source_column: str = Field(..., alias="Column")
24+
foreign_column: str = Field(..., alias="ForeignColumn")
25+
26+
27+
class EntityRelationship(BaseModel):
28+
foreign_entity: str = Field(..., alias="ForeignEntity")
29+
foreign_keys: list[ForeignKeyRelationship] = Field(..., alias="ForeignKeys")
30+
31+
def pivot(self, entity: str):
32+
"""A method to pivot the entity relationship."""
33+
return EntityRelationship(
34+
foreign_entity=entity,
35+
foreign_keys=[
36+
ForeignKeyRelationship(
37+
source_column=foreign_key.foreign_column,
38+
foreign_column=foreign_key.source_column,
39+
)
40+
for foreign_key in self.foreign_keys
41+
],
42+
)
43+
44+
def add_foreign_key(self, foreign_key: ForeignKeyRelationship):
45+
"""A method to add a foreign key to the entity relationship."""
46+
self.foreign_keys.append(foreign_key)
47+
48+
@classmethod
49+
def from_sql_row(cls, row, columns):
50+
"""A method to create an EntityRelationship from a SQL row."""
51+
result = dict(zip(columns, row))
52+
return cls(
53+
foreign_entity=result["ForeignEntity"],
54+
foreign_keys=[
55+
ForeignKeyRelationship(
56+
source_column=result["Column"],
57+
foreign_column=result["ForeignColumn"],
58+
)
59+
],
60+
)
61+
62+
2163
class ColumnItem(BaseModel):
2264
"""A class to represent a column item."""
2365

@@ -38,7 +80,7 @@ def from_sql_row(cls, row, columns):
3880
result = dict(zip(columns, row))
3981
return cls(
4082
name=result["Name"],
41-
type=result["DataType"],
83+
data_type=result["DataType"],
4284
definition=result["Definition"],
4385
)
4486

@@ -51,6 +93,16 @@ class EntityItem(BaseModel):
5193
name: str = Field(..., alias="Name", exclude=True)
5294
entity_schema: str = Field(..., alias="Schema", exclude=True)
5395
entity_name: Optional[str] = Field(default=None, alias="EntityName")
96+
database: Optional[str] = Field(default=None, alias="Database")
97+
warehouse: Optional[str] = Field(default=None, alias="Warehouse")
98+
99+
entity_relationships: Optional[list[EntityRelationship]] = Field(
100+
None, alias="EntityRelationships"
101+
)
102+
103+
complete_entity_relationship_graph = Optional[str] = Field(
104+
None, alias="CompleteEntityRelationshipGraph"
105+
)
54106

55107
columns: Optional[list[ColumnItem]] = Field(
56108
..., alias="Columns", default_factory=list
@@ -97,6 +149,9 @@ def __init__(
97149
self.single_file = single_file
98150
self.generate_descriptions = generate_descriptions
99151

152+
self.entity_relationships = {}
153+
self.relationship_graph = nx.DiGraph()
154+
100155
load_dotenv(find_dotenv())
101156

102157
@property
@@ -119,6 +174,13 @@ def extract_columns_sql_query(self, entity: EntityItem) -> str:
119174
120175
Must return 3 columns: Name, DataType, Definition."""
121176

177+
@property
178+
@abstractmethod
179+
def extract_entity_relationships_sql_query(self) -> str:
180+
"""An abstract method to extract entity relationships from a database.
181+
182+
Must return 4 columns: Entity, ForeignEntity, Column, ForeignColumn."""
183+
122184
def extract_distinct_values_sql_query(
123185
self, entity: EntityItem, column: ColumnItem
124186
) -> str:
@@ -165,6 +227,72 @@ async def query_entities(
165227

166228
return results
167229

230+
async def extract_entity_relationships(self) -> list[EntityRelationship]:
231+
"""A method to extract entity relationships from a database.
232+
233+
Returns:
234+
list[EntityRelationships]: The list of entity relationships."""
235+
236+
relationships = await self.query_entities(
237+
self.extract_entity_relationships_sql_query, cast_to=EntityRelationship
238+
)
239+
240+
# Duplicate relationships to create a complete graph
241+
242+
for relationship in relationships:
243+
if relationship.entity not in self.entity_relationships:
244+
self.entity_relationships[relationship.entity] = {
245+
relationship.foreign_entity: relationship
246+
}
247+
else:
248+
self.entity_relationships[relationship.entity][
249+
relationship.foreign_entity
250+
].add_foreign_key(relationship.foreign_keys[0])
251+
252+
if relationship.foreign_entity not in self.entity_relationships:
253+
self.entity_relationships[relationship.foreign_entity] = {
254+
relationship.entity: relationship.pivot()
255+
}
256+
else:
257+
self.entity_relationships[relationship.foreign_entity][
258+
relationship.entity
259+
].add_foreign_key(relationship.pivot().foreign_keys[0])
260+
261+
async def build_entity_relationship_graph(self) -> nx.DiGraph:
262+
"""A method to build a complete entity relationship graph."""
263+
264+
for entity, foreign_entities in self.entity_relationships.items():
265+
for foreign_entity, relationship in foreign_entities.items():
266+
self.relationship_graph.add_edge(
267+
entity, foreign_entity, relationship=relationship
268+
)
269+
270+
def get_entity_relationships_from_graph(
271+
self, entity: str, path=None, result=None, visited=None
272+
) -> nx.DiGraph:
273+
if entity not in self.relationship_graph:
274+
return None
275+
276+
if path is None:
277+
path = []
278+
if result is None:
279+
result = []
280+
if visited is None:
281+
visited = set()
282+
283+
# Mark the current node as visited
284+
visited.add(entity)
285+
286+
# For each successor (neighbor in the directed path)
287+
for successor in self.relationship_graph.successors(entity):
288+
new_path = path + [f"{entity} -> {successor}"]
289+
result.append(" -> ".join(new_path)) # Add the path as a string
290+
self.get_entity_relationships_from_graph(
291+
self.relationship_graph, successor, new_path, result, visited
292+
)
293+
294+
return result
295+
168296
async def extract_entities_with_descriptions(self) -> list[EntityItem]:
169297
"""A method to extract entities with descriptions from a database.
170298
@@ -420,12 +548,27 @@ async def build_entity_entry(self, entity: EntityItem) -> EntityItem:
420548
if self.generate_descriptions:
421549
await self.generate_entity_description(entity)
422550

551+
# add in relationships
552+
if entity.entity in self.entity_relationships:
553+
entity.entity_relationships = list(
554+
self.entity_relationships[entity.entity].values()
555+
)
556+
557+
# add in the graph traversal
558+
entity.complete_entity_relationship_graph = (
559+
self.get_entity_relationships_from_graph(entity.entity)
560+
)
561+
423562
return entity
424563

425564
async def create_data_dictionary(self):
426565
"""A method to build a data dictionary from a database. Writes to file."""
427566
entities = await self.extract_entities_with_descriptions()
428567

568+
await self.extract_entity_relationships()
569+
570+
await self.build_entity_relationship_graph()
571+
429572
entity_tasks = []
430573
for entity in entities:
431574
entity_tasks.append(self.build_entity_entry(entity))

text_2_sql/data_dictionary/sql_sever_data_dictionary_creator.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def extract_table_entities_sql_query(self) -> str:
3434
return """SELECT
3535
t.TABLE_NAME AS Entity,
3636
t.TABLE_SCHEMA AS EntitySchema,
37-
CAST(ep.value AS NVARCHAR(500)) AS Description
37+
CAST(ep.value AS NVARCHAR(500)) AS Definition
3838
FROM
3939
INFORMATION_SCHEMA.TABLES t
4040
LEFT JOIN
@@ -80,6 +80,30 @@ def extract_columns_sql_query(self, entity: EntityItem) -> str:
8080
c.TABLE_SCHEMA = '{entity.entity_schema}'
8181
AND c.TABLE_NAME = '{entity.name}';"""
8282

83+
@property
84+
def extract_entity_relationships_sql_query() -> str:
85+
"""A property to extract entity relationships from a SQL Server database."""
86+
return """SELECT
87+
fk_tab.name AS Entity,
88+
pk_tab.name AS ForeignEntity,
89+
fk_col.name AS Column,
90+
pk_col.name AS ForeignColumn
91+
FROM
92+
sys.foreign_keys AS fk
93+
INNER JOIN
94+
sys.foreign_key_columns AS fkc ON fk.object_id = fkc.constraint_object_id
95+
INNER JOIN
96+
sys.tables AS fk_tab ON fk_tab.object_id = fk.parent_object_id
97+
INNER JOIN
98+
sys.tables AS pk_tab ON pk_tab.object_id = fk.referenced_object_id
99+
INNER JOIN
100+
sys.columns AS fk_col ON fkc.parent_object_id = fk_col.object_id AND fkc.parent_column_id = fk_col.column_id
101+
INNER JOIN
102+
sys.columns AS pk_col ON fkc.referenced_object_id = pk_col.object_id AND fkc.referenced_column_id = pk_col.column_id
103+
ORDER BY
104+
Entity, ForeignEntity;
105+
"""
106+
83107

84108
if __name__ == "__main__":
85109
data_dictionary_creator = SqlServerDataDictionaryCreator()

0 commit comments

Comments
 (0)