1414from azure .identity import DefaultAzureCredential , get_bearer_token_provider
1515import random
1616import re
17+ import networkx as nx
1718
1819logging .basicConfig (level = logging .INFO )
1920
2021
class ForeignKeyRelationship(BaseModel):
    """A class to represent one pair of columns that joins two entities."""

    # Column on the owning side of the relationship; read from the "Column" alias.
    source_column: str = Field(..., alias="Column")
    # Column on the foreign entity it is matched against; read from the
    # "ForeignColumn" alias.
    foreign_column: str = Field(..., alias="ForeignColumn")
25+
26+
class EntityRelationship(BaseModel):
    """A class to represent the relationship between an entity and a foreign entity.

    A relationship carries one or more foreign key column pairs. Instances are
    always constructed through the field aliases (``Entity``, ``ForeignEntity``,
    ``ForeignKeys``) because aliased fields cannot be populated by their Python
    names unless the model is explicitly configured to allow it.
    """

    # Owning side of the relationship. Optional so that callers which only
    # supply the foreign side keep working, and excluded from serialisation so
    # the dumped shape is unchanged.
    entity: Optional[str] = Field(default=None, alias="Entity", exclude=True)
    foreign_entity: str = Field(..., alias="ForeignEntity")
    foreign_keys: list[ForeignKeyRelationship] = Field(..., alias="ForeignKeys")

    def pivot(self, entity: Optional[str] = None):
        """A method to pivot the entity relationship.

        The pivoted relationship describes the same join seen from the foreign
        entity: the two entities swap roles and every key pair is reversed.

        Args:
            entity: Name of the entity that owns this relationship. Falls back
                to ``self.entity`` when omitted.

        Returns:
            EntityRelationship: A new, mirrored relationship.

        Raises:
            ValueError: If the owning entity is neither passed nor stored.
        """
        owner = self.entity if entity is None else entity
        if owner is None:
            raise ValueError(
                "Cannot pivot a relationship without knowing its owning entity."
            )

        return EntityRelationship(
            Entity=self.foreign_entity,
            ForeignEntity=owner,
            ForeignKeys=[
                ForeignKeyRelationship(
                    Column=foreign_key.foreign_column,
                    ForeignColumn=foreign_key.source_column,
                )
                for foreign_key in self.foreign_keys
            ],
        )

    def add_foreign_key(self, foreign_key: ForeignKeyRelationship):
        """A method to add a foreign key to the entity relationship."""
        self.foreign_keys.append(foreign_key)

    @classmethod
    def from_sql_row(cls, row, columns):
        """A method to create an EntityRelationship from a SQL row.

        Args:
            row: The values of one result row.
            columns: The column names, in the same order as ``row``. Must
                include ForeignEntity, Column and ForeignColumn; Entity is
                picked up when present.

        Returns:
            EntityRelationship: A relationship holding a single foreign key.
        """
        result = dict(zip(columns, row))
        return cls(
            Entity=result.get("Entity"),
            ForeignEntity=result["ForeignEntity"],
            ForeignKeys=[
                ForeignKeyRelationship(
                    Column=result["Column"],
                    ForeignColumn=result["ForeignColumn"],
                )
            ],
        )
61+
62+
2163class ColumnItem (BaseModel ):
2264 """A class to represent a column item."""
2365
@@ -38,7 +80,7 @@ def from_sql_row(cls, row, columns):
3880 result = dict (zip (columns , row ))
3981 return cls (
4082 name = result ["Name" ],
41- type = result ["DataType" ],
83+ data_type = result ["DataType" ],
4284 definition = result ["Definition" ],
4385 )
4486
@@ -51,6 +93,16 @@ class EntityItem(BaseModel):
5193 name : str = Field (..., alias = "Name" , exclude = True )
5294 entity_schema : str = Field (..., alias = "Schema" , exclude = True )
5395 entity_name : Optional [str ] = Field (default = None , alias = "EntityName" )
96+ database : Optional [str ] = Field (default = None , alias = "Database" )
97+ warehouse : Optional [str ] = Field (default = None , alias = "Warehouse" )
98+
99+ entity_relationships : Optional [list [EntityRelationship ]] = Field (
100+ None , alias = "EntityRelationships"
101+ )
102+
# Traversal paths through the relationship graph starting at this entity, as
# produced by get_entity_relationships_from_graph (a list of path strings, or
# None when the entity is not part of the graph).
complete_entity_relationship_graph: Optional[list[str]] = Field(
    None, alias="CompleteEntityRelationshipGraph"
)
54106
55107 columns : Optional [list [ColumnItem ]] = Field (
56108 ..., alias = "Columns" , default_factory = list
@@ -97,6 +149,9 @@ def __init__(
97149 self .single_file = single_file
98150 self .generate_descriptions = generate_descriptions
99151
152+ self .entity_relationships = {}
153+ self .relationship_graph = nx .DiGraph ()
154+
100155 load_dotenv (find_dotenv ())
101156
102157 @property
@@ -119,6 +174,13 @@ def extract_columns_sql_query(self, entity: EntityItem) -> str:
119174
120175 Must return 3 columns: Name, DataType, Definition."""
121176
@property
@abstractmethod
def extract_entity_relationships_sql_query(self) -> str:
    """An abstract property that provides the SQL query used to extract entity relationships from a database.

    Each result row describes one foreign key column pair.

    Must return 4 columns: Entity, ForeignEntity, Column, ForeignColumn."""
183+
122184 def extract_distinct_values_sql_query (
123185 self , entity : EntityItem , column : ColumnItem
124186 ) -> str :
@@ -165,6 +227,72 @@ async def query_entities(
165227
166228 return results
167229
async def extract_entity_relationships(self) -> list[EntityRelationship]:
    """A method to extract entity relationships from a database.

    Every relationship row is registered in ``self.entity_relationships`` in
    both directions (entity -> foreign entity and the pivoted foreign entity ->
    entity). When several rows describe the same pair of entities, their
    foreign keys are accumulated on the relationship already registered.

    Returns:
        list[EntityRelationship]: The relationships produced by the query."""

    relationships = await self.query_entities(
        self.extract_entity_relationships_sql_query, cast_to=EntityRelationship
    )

    for relationship in relationships:
        # NOTE(review): this relies on each relationship exposing the owning
        # entity as `.entity` - confirm EntityRelationship provides it.
        entity = relationship.entity
        foreign_entity = relationship.foreign_entity
        mirrored = relationship.pivot(entity)

        # Register both directions so the graph can be walked from either side.
        for owner, other, candidate in (
            (entity, foreign_entity, relationship),
            (foreign_entity, entity, mirrored),
        ):
            # setdefault covers an owner that is already known but has not
            # been linked to this particular entity yet.
            related = self.entity_relationships.setdefault(owner, {})
            existing = related.get(other)
            if existing is None:
                related[other] = candidate
            else:
                existing.add_foreign_key(candidate.foreign_keys[0])

    return relationships
260+
async def build_entity_relationship_graph(self) -> nx.DiGraph:
    """A method to build a complete entity relationship graph.

    Adds one directed edge per entry in ``self.entity_relationships``, with the
    relationship object attached as the ``relationship`` edge attribute.

    Returns:
        nx.DiGraph: The populated ``self.relationship_graph``."""

    for entity, foreign_entities in self.entity_relationships.items():
        for foreign_entity, relationship in foreign_entities.items():
            self.relationship_graph.add_edge(
                entity, foreign_entity, relationship=relationship
            )

    return self.relationship_graph
269+
def get_entity_relationships_from_graph(
    self, entity: str, path=None, result=None, visited=None
) -> Optional[list[str]]:
    """A method to list the relationship paths reachable from an entity.

    Performs a depth-first walk over ``self.relationship_graph``. Each entity is
    expanded at most once, so every reachable entity contributes one path (the
    first one found) and cycles in the graph cannot cause endless recursion.

    Args:
        entity: The entity to start the traversal from.
        path: Hops taken so far; used by the recursion, leave unset.
        result: Accumulator for the path strings; used by the recursion.
        visited: Entities already expanded; used by the recursion.

    Returns:
        The accumulated path strings, or None if the entity is not in the graph."""
    if entity not in self.relationship_graph:
        return None

    if path is None:
        path = []
    if result is None:
        result = []
    if visited is None:
        visited = set()

    # Mark the current node as visited
    visited.add(entity)

    # For each successor (neighbor in the directed path)
    for successor in self.relationship_graph.successors(entity):
        if successor in visited:
            # Already expanded; following it again would loop forever.
            continue

        # NOTE(review): each hop is stored as "a -> b", so joined paths repeat
        # the shared entity (e.g. "A -> B -> B -> C") - confirm this format.
        new_path = path + [f"{entity} -> {successor}"]
        result.append(" -> ".join(new_path))  # Add the path as a string
        self.get_entity_relationships_from_graph(
            successor, new_path, result, visited
        )

    return result
295+
168296 async def extract_entities_with_descriptions (self ) -> list [EntityItem ]:
169297 """A method to extract entities with descriptions from a database.
170298
@@ -420,12 +548,27 @@ async def build_entity_entry(self, entity: EntityItem) -> EntityItem:
420548 if self .generate_descriptions :
421549 await self .generate_entity_description (entity )
422550
551+ # add in relationships
552+ if entity .entity in self .entity_relationships :
553+ entity .entity_relationships = list (
554+ self .entity_relationships [entity .entity ].values ()
555+ )
556+
557+ # add in the graph traversal
558+ entity .complete_entity_relationship_graph = (
559+ self .get_entity_relationships_from_graph (entity .entity )
560+ )
561+
423562 return entity
424563
425564 async def create_data_dictionary (self ):
426565 """A method to build a data dictionary from a database. Writes to file."""
427566 entities = await self .extract_entities_with_descriptions ()
428567
568+ await self .extract_entity_relationships ()
569+
570+ await self .build_entity_relationship_graph ()
571+
429572 entity_tasks = []
430573 for entity in entities :
431574 entity_tasks .append (self .build_entity_entry (entity ))
0 commit comments