Skip to content

Commit b365e59

Browse files
committed
Update schema store and graph creation
1 parent 5448a0e commit b365e59

File tree

2 files changed

+37
-30
lines changed

2 files changed

+37
-30
lines changed

deploy_ai_search/text_2_sql_schema_store.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -154,7 +154,7 @@ def get_index_fields(self) -> list[SearchableField]:
154154
),
155155
],
156156
),
157-
SimpleField(
157+
SearchableField(
158158
name="CompleteEntityRelationshipsGraph",
159159
type=SearchFieldDataType.String,
160160
collection=True,
@@ -303,7 +303,7 @@ def get_indexer(self) -> SearchIndexer:
303303
target_field_name="EntityRelationships",
304304
),
305305
FieldMapping(
306-
source_field_name="/document/CompleteEntityRelationshipsGraph",
306+
source_field_name="/document/CompleteEntityRelationshipsGraph/*",
307307
target_field_name="CompleteEntityRelationshipsGraph",
308308
),
309309
FieldMapping(

text_2_sql/data_dictionary/data_dictionary_creator.py

Lines changed: 35 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -23,21 +23,17 @@ class ForeignKeyRelationship(BaseModel):
2323
column: str = Field(..., alias="Column")
2424
foreign_column: str = Field(..., alias="ForeignColumn")
2525

26-
model_config = ConfigDict(populate_by_name=True,
27-
arbitrary_types_allowed=True)
26+
model_config = ConfigDict(populate_by_name=True, arbitrary_types_allowed=True)
2827

2928

3029
class EntityRelationship(BaseModel):
3130
entity: str = Field(..., alias="Entity", exclude=True)
3231
entity_schema: str = Field(..., alias="Schema", exclude=True)
3332
foreign_entity: str = Field(..., alias="ForeignEntity")
34-
foreign_entity_schema: str = Field(...,
35-
alias="ForeignSchema", exclude=True)
36-
foreign_keys: list[ForeignKeyRelationship] = Field(
37-
..., alias="ForeignKeys")
33+
foreign_entity_schema: str = Field(..., alias="ForeignSchema", exclude=True)
34+
foreign_keys: list[ForeignKeyRelationship] = Field(..., alias="ForeignKeys")
3835

39-
model_config = ConfigDict(populate_by_name=True,
40-
arbitrary_types_allowed=True)
36+
model_config = ConfigDict(populate_by_name=True, arbitrary_types_allowed=True)
4137

4238
def pivot(self):
4339
"""A method to pivot the entity relationship."""
@@ -65,9 +61,11 @@ def from_sql_row(cls, row, columns):
6561
result = dict(zip(columns, row))
6662

6763
entity = "{EntitySchema}.{Entity}".format(
68-
EntitySchema=result['EntitySchema'], Entity=result['Entity'])
64+
EntitySchema=result["EntitySchema"], Entity=result["Entity"]
65+
)
6966
foreign_entity = "{ForeignEntitySchema}.{ForeignEntity}".format(
70-
ForeignEntitySchema=result['ForeignEntitySchema'], ForeignEntity=result['ForeignEntity']
67+
ForeignEntitySchema=result["ForeignEntitySchema"],
68+
ForeignEntity=result["ForeignEntity"],
7169
)
7270
return cls(
7371
entity=entity,
@@ -95,8 +93,7 @@ class ColumnItem(BaseModel):
9593
allowed_values: Optional[list[any]] = Field(None, alias="AllowedValues")
9694
sample_values: Optional[list[any]] = Field(None, alias="SampleValues")
9795

98-
model_config = ConfigDict(populate_by_name=True,
99-
arbitrary_types_allowed=True)
96+
model_config = ConfigDict(populate_by_name=True, arbitrary_types_allowed=True)
10097

10198
@classmethod
10299
def from_sql_row(cls, row, columns):
@@ -121,11 +118,11 @@ class EntityItem(BaseModel):
121118
warehouse: Optional[str] = Field(default=None, alias="Warehouse")
122119

123120
entity_relationships: Optional[list[EntityRelationship]] = Field(
124-
None, alias="EntityRelationships"
121+
alias="EntityRelationships", default_factory=list
125122
)
126123

127124
complete_entity_relationships_graph: Optional[list[str]] = Field(
128-
None, alias="CompleteEntityRelationshipsGraph"
125+
alias="CompleteEntityRelationshipsGraph", default_factory=list
129126
)
130127

131128
columns: Optional[list[ColumnItem]] = Field(
@@ -206,7 +203,8 @@ def extract_columns_sql_query(self, entity: EntityItem) -> str:
206203
def extract_entity_relationships_sql_query(self) -> str:
207204
"""An abstract method to extract entity relationships from a database.
208205
209-
Must return 6 columns: EntitySchema, Entity, ForeignEntitySchema, ForeignEntity, Column, ForeignColumn."""
206+
Must return 6 columns: EntitySchema, Entity, ForeignEntitySchema, ForeignEntity, Column, ForeignColumn.
207+
"""
210208

211209
def extract_distinct_values_sql_query(
212210
self, entity: EntityItem, column: ColumnItem
@@ -272,7 +270,10 @@ async def extract_entity_relationships(self) -> list[EntityRelationship]:
272270
relationship.foreign_entity: relationship
273271
}
274272
else:
275-
if relationship.foreign_entity not in self.entity_relationships[relationship.entity]:
273+
if (
274+
relationship.foreign_entity
275+
not in self.entity_relationships[relationship.entity]
276+
):
276277
self.entity_relationships[relationship.entity][
277278
relationship.foreign_entity
278279
] = relationship
@@ -286,7 +287,10 @@ async def extract_entity_relationships(self) -> list[EntityRelationship]:
286287
relationship.entity: relationship.pivot()
287288
}
288289
else:
289-
if relationship.entity not in self.entity_relationships[relationship.foreign_entity]:
290+
if (
291+
relationship.entity
292+
not in self.entity_relationships[relationship.foreign_entity]
293+
):
290294
self.entity_relationships[relationship.foreign_entity][
291295
relationship.entity
292296
] = relationship.pivot()
@@ -308,7 +312,7 @@ def get_entity_relationships_from_graph(
308312
self, entity: str, path=None, result=None, visited=None
309313
) -> nx.DiGraph:
310314
if entity not in self.relationship_graph:
311-
return None
315+
return []
312316

313317
if path is None:
314318
path = [entity]
@@ -320,14 +324,20 @@ def get_entity_relationships_from_graph(
320324
# Mark the current node as visited
321325
visited.add(entity)
322326

323-
# For each successor (neighbor in the directed path)
324-
for successor in self.relationship_graph.successors(entity):
325-
if successor not in visited:
327+
successors = list(self.relationship_graph.successors(entity))
328+
successors_not_visited = [
329+
successor for successor in successors if successor not in visited
330+
]
331+
if len(successors_not_visited) == 0 and len(path) > 1:
332+
# Add the complete path to the result as a string
333+
result.append(" -> ".join(path))
334+
else:
335+
# For each successor (neighbor in the directed path)
336+
for successor in successors_not_visited:
326337
new_path = path + [successor]
327338
# Add the path as a string
328-
result.append(" -> ".join(new_path))
329339
self.get_entity_relationships_from_graph(
330-
successor, new_path, result, visited
340+
successor, new_path, result, visited.copy()
331341
)
332342

333343
return result
@@ -423,8 +433,7 @@ async def generate_column_definition(self, entity: EntityItem, column: ColumnIte
423433
If you think the sample values belong to a specific standard, you can mention it in the definition. e.g. The column contains a list of country codes in the ISO 3166-1 alpha-2 format. 'US' for United States, 'GB' for United Kingdom, 'FR' for France. Including the specific standard format code can help the user understand the data better.
424434
425435
If you think the sample values are not representative of the column as a whole, you can provide a more general definition of the column without mentioning the sample values."""
426-
stringifed_sample_values = [str(value)
427-
for value in column.sample_values]
436+
stringifed_sample_values = [str(value) for value in column.sample_values]
428437

429438
column_definition_input = f"""Describe the {column.name} column in the {entity.entity} entity. The following sample values are provided from {
430439
column.name}: {', '.join(stringifed_sample_values)}."""
@@ -469,9 +478,7 @@ async def extract_columns_with_definitions(
469478
)
470479

471480
if self.generate_definitions:
472-
definition_tasks.append(
473-
self.generate_column_definition(entity, column)
474-
)
481+
definition_tasks.append(self.generate_column_definition(entity, column))
475482

476483
await asyncio.gather(*distinct_value_tasks)
477484

0 commit comments

Comments
 (0)