@@ -23,21 +23,17 @@ class ForeignKeyRelationship(BaseModel):
2323 column : str = Field (..., alias = "Column" )
2424 foreign_column : str = Field (..., alias = "ForeignColumn" )
2525
26- model_config = ConfigDict (populate_by_name = True ,
27- arbitrary_types_allowed = True )
26+ model_config = ConfigDict (populate_by_name = True , arbitrary_types_allowed = True )
2827
2928
3029class EntityRelationship (BaseModel ):
3130 entity : str = Field (..., alias = "Entity" , exclude = True )
3231 entity_schema : str = Field (..., alias = "Schema" , exclude = True )
3332 foreign_entity : str = Field (..., alias = "ForeignEntity" )
34- foreign_entity_schema : str = Field (...,
35- alias = "ForeignSchema" , exclude = True )
36- foreign_keys : list [ForeignKeyRelationship ] = Field (
37- ..., alias = "ForeignKeys" )
33+ foreign_entity_schema : str = Field (..., alias = "ForeignSchema" , exclude = True )
34+ foreign_keys : list [ForeignKeyRelationship ] = Field (..., alias = "ForeignKeys" )
3835
39- model_config = ConfigDict (populate_by_name = True ,
40- arbitrary_types_allowed = True )
36+ model_config = ConfigDict (populate_by_name = True , arbitrary_types_allowed = True )
4137
4238 def pivot (self ):
4339 """A method to pivot the entity relationship."""
@@ -65,9 +61,11 @@ def from_sql_row(cls, row, columns):
6561 result = dict (zip (columns , row ))
6662
6763 entity = "{EntitySchema}.{Entity}" .format (
68- EntitySchema = result ['EntitySchema' ], Entity = result ['Entity' ])
64+ EntitySchema = result ["EntitySchema" ], Entity = result ["Entity" ]
65+ )
6966 foreign_entity = "{ForeignEntitySchema}.{ForeignEntity}" .format (
70- ForeignEntitySchema = result ['ForeignEntitySchema' ], ForeignEntity = result ['ForeignEntity' ]
67+ ForeignEntitySchema = result ["ForeignEntitySchema" ],
68+ ForeignEntity = result ["ForeignEntity" ],
7169 )
7270 return cls (
7371 entity = entity ,
@@ -95,8 +93,7 @@ class ColumnItem(BaseModel):
9593 allowed_values : Optional [list [any ]] = Field (None , alias = "AllowedValues" )
9694 sample_values : Optional [list [any ]] = Field (None , alias = "SampleValues" )
9795
98- model_config = ConfigDict (populate_by_name = True ,
99- arbitrary_types_allowed = True )
96+ model_config = ConfigDict (populate_by_name = True , arbitrary_types_allowed = True )
10097
10198 @classmethod
10299 def from_sql_row (cls , row , columns ):
@@ -121,11 +118,11 @@ class EntityItem(BaseModel):
121118 warehouse : Optional [str ] = Field (default = None , alias = "Warehouse" )
122119
123120 entity_relationships : Optional [list [EntityRelationship ]] = Field (
124- None , alias = "EntityRelationships"
121+ alias = "EntityRelationships" , default_factory = list
125122 )
126123
127124 complete_entity_relationships_graph : Optional [list [str ]] = Field (
128- None , alias = "CompleteEntityRelationshipsGraph"
125+ alias = "CompleteEntityRelationshipsGraph" , default_factory = list
129126 )
130127
131128 columns : Optional [list [ColumnItem ]] = Field (
@@ -206,7 +203,8 @@ def extract_columns_sql_query(self, entity: EntityItem) -> str:
206203 def extract_entity_relationships_sql_query (self ) -> str :
207204 """An abstract method to extract entity relationships from a database.
208205
209- Must return 6 columns: EntitySchema, Entity, ForeignEntitySchema, ForeignEntity, Column, ForeignColumn."""
206+ Must return 6 columns: EntitySchema, Entity, ForeignEntitySchema, ForeignEntity, Column, ForeignColumn.
207+ """
210208
211209 def extract_distinct_values_sql_query (
212210 self , entity : EntityItem , column : ColumnItem
@@ -272,7 +270,10 @@ async def extract_entity_relationships(self) -> list[EntityRelationship]:
272270 relationship .foreign_entity : relationship
273271 }
274272 else :
275- if relationship .foreign_entity not in self .entity_relationships [relationship .entity ]:
273+ if (
274+ relationship .foreign_entity
275+ not in self .entity_relationships [relationship .entity ]
276+ ):
276277 self .entity_relationships [relationship .entity ][
277278 relationship .foreign_entity
278279 ] = relationship
@@ -286,7 +287,10 @@ async def extract_entity_relationships(self) -> list[EntityRelationship]:
286287 relationship .entity : relationship .pivot ()
287288 }
288289 else :
289- if relationship .entity not in self .entity_relationships [relationship .foreign_entity ]:
290+ if (
291+ relationship .entity
292+ not in self .entity_relationships [relationship .foreign_entity ]
293+ ):
290294 self .entity_relationships [relationship .foreign_entity ][
291295 relationship .entity
292296 ] = relationship .pivot ()
@@ -308,7 +312,7 @@ def get_entity_relationships_from_graph(
308312 self , entity : str , path = None , result = None , visited = None
309313 ) -> nx .DiGraph :
310314 if entity not in self .relationship_graph :
311- return None
315+ return []
312316
313317 if path is None :
314318 path = [entity ]
@@ -320,14 +324,20 @@ def get_entity_relationships_from_graph(
320324 # Mark the current node as visited
321325 visited .add (entity )
322326
323- # For each successor (neighbor in the directed path)
324- for successor in self .relationship_graph .successors (entity ):
325- if successor not in visited :
327+ successors = list (self .relationship_graph .successors (entity ))
328+ successors_not_visited = [
329+ successor for successor in successors if successor not in visited
330+ ]
331+ if len (successors_not_visited ) == 0 and len (path ) > 1 :
332+ # Add the complete path to the result as a string
333+ result .append (" -> " .join (path ))
334+ else :
335+ # For each successor (neighbor in the directed path)
336+ for successor in successors_not_visited :
326337 new_path = path + [successor ]
327338 # Add the path as a string
328- result .append (" -> " .join (new_path ))
329339 self .get_entity_relationships_from_graph (
330- successor , new_path , result , visited
340+ successor , new_path , result , visited . copy ()
331341 )
332342
333343 return result
@@ -423,8 +433,7 @@ async def generate_column_definition(self, entity: EntityItem, column: ColumnIte
423433 If you think the sample values belong to a specific standard, you can mention it in the definition. e.g. The column contains a list of country codes in the ISO 3166-1 alpha-2 format. 'US' for United States, 'GB' for United Kingdom, 'FR' for France. Including the specific standard format code can help the user understand the data better.
424434
425435 If you think the sample values are not representative of the column as a whole, you can provide a more general definition of the column without mentioning the sample values."""
426- stringifed_sample_values = [str (value )
427- for value in column .sample_values ]
436+ stringifed_sample_values = [str (value ) for value in column .sample_values ]
428437
429438 column_definition_input = f"""Describe the { column .name } column in the { entity .entity } entity. The following sample values are provided from {
430439 column .name } : { ', ' .join (stringifed_sample_values )} ."""
@@ -469,9 +478,7 @@ async def extract_columns_with_definitions(
469478 )
470479
471480 if self .generate_definitions :
472- definition_tasks .append (
473- self .generate_column_definition (entity , column )
474- )
481+ definition_tasks .append (self .generate_column_definition (entity , column ))
475482
476483 await asyncio .gather (* distinct_value_tasks )
477484
0 commit comments