Skip to content

Commit 2f7fb79

Browse files
committed
feat: add table relation
1 parent 34c2ae5 commit 2f7fb79

File tree

2 files changed

+58
-10
lines changed

2 files changed

+58
-10
lines changed

backend/apps/datasource/crud/datasource.py

Lines changed: 53 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,7 @@ def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDat
354354
db_name = table_objs[0].schema
355355
schema_str += f"【DB_ID】 {db_name}\n【Schema】\n"
356356
tables = []
357+
all_tables = [] # temp save all tables
357358
for obj in table_objs:
358359
schema_table = ''
359360
schema_table += f"# Table: {db_name}.{obj.table.table_name}" if ds.type != "mysql" and ds.type != "es" else f"# Table: {obj.table.table_name}"
@@ -376,13 +377,60 @@ def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDat
376377
field_list.append(f"({field.field_name}:{field.field_type}, {field_comment})")
377378
schema_table += ",\n".join(field_list)
378379
schema_table += '\n]\n'
379-
tables.append(schema_table)
380+
t_obj = {"id": obj.table.id, "schema_table": schema_table}
381+
tables.append(t_obj)
382+
all_tables.append(t_obj)
380383

381384
# do table embedding
382-
if embedding:
385+
if embedding and tables:
383386
tables = get_table_embedding(session, current_user, tables, question)
387+
# splice schema
388+
if tables:
389+
for s in tables:
390+
schema_str += s.get('schema_table')
391+
392+
# field relation
393+
if tables and ds.table_relation:
394+
relations = list(filter(lambda x: x.get('shape') == 'edge', ds.table_relation))
395+
# Complete the missing table
396+
# get tables in relation, remove irrelevant relation
397+
embedding_table_ids = [s.get('id') for s in tables]
398+
all_relations = list(filter(lambda x: x.get('source').get('cell') in embedding_table_ids or x.get('target').get(
399+
'cell') in embedding_table_ids, relations))
400+
401+
# get relation table ids, sub embedding table ids
402+
relation_table_ids = []
403+
for r in all_relations:
404+
relation_table_ids.append(r.get('source').get('cell'))
405+
relation_table_ids.append(r.get('target').get('cell'))
406+
relation_table_ids = list(set(relation_table_ids))
407+
# get table dict
408+
table_records = session.query(CoreTable).filter(CoreTable.id.in_(list(map(int, relation_table_ids)))).all()
409+
table_dict = {}
410+
for ele in table_records:
411+
table_dict[ele.id] = ele.table_name
412+
413+
# get lost table ids
414+
lost_table_ids = list(set(relation_table_ids) - set(embedding_table_ids))
415+
# get lost table schema and splice it
416+
lost_tables = list(filter(lambda x: x.get('id') in lost_table_ids, all_tables))
417+
if lost_tables:
418+
for s in lost_tables:
419+
schema_str += s.get('schema_table')
420+
421+
# get field dict
422+
relation_field_ids = []
423+
for relation in all_relations:
424+
relation_field_ids.append(relation.get('source').get('port'))
425+
relation_field_ids.append(relation.get('target').get('port'))
426+
relation_field_ids = list(set(relation_field_ids))
427+
field_records = session.query(CoreField).filter(CoreField.id.in_(list(map(int, relation_field_ids)))).all()
428+
field_dict = {}
429+
for ele in field_records:
430+
field_dict[ele.id] = ele.field_name
431+
432+
schema_str += '【Foreign keys】\n'
433+
for ele in all_relations:
434+
schema_str += f"{table_dict.get(int(ele.get('source').get('cell')))}.{field_dict.get(int(ele.get('source').get('port')))}={table_dict.get(int(ele.get('target').get('cell')))}.{field_dict.get(int(ele.get('target').get('port')))}\n"
384435

385-
# todo 外键
386-
for s in tables:
387-
schema_str += s
388436
return schema_str

backend/apps/datasource/embedding/table_embedding.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,14 @@
1010
from common.utils.utils import SQLBotLogUtil
1111

1212

13-
def get_table_embedding(session: SessionDep, current_user: CurrentUser, tables: list[str], question: str):
13+
def get_table_embedding(session: SessionDep, current_user: CurrentUser, tables: list[dict], question: str):
1414
_list = []
15-
for table_schema in tables:
16-
_list.append({"table_schema": table_schema, "cosine_similarity": 0.0})
15+
for table in tables:
16+
_list.append({"id": table.get('id'), "schema_table": table.get('schema_table'), "cosine_similarity": 0.0})
1717

1818
if _list:
1919
try:
20-
text = [s.get('table_schema') for s in _list]
20+
text = [s.get('schema_table') for s in _list]
2121

2222
model = EmbeddingModelCache.get_model()
2323
results = model.embed_documents(text)
@@ -31,7 +31,7 @@ def get_table_embedding(session: SessionDep, current_user: CurrentUser, tables:
3131
_list = _list[:settings.TABLE_EMBEDDING_COUNT]
3232
# print(len(_list))
3333
SQLBotLogUtil.info(json.dumps(_list))
34-
return [t.get("table_schema") for t in _list]
34+
return _list
3535
except Exception:
3636
traceback.print_exc()
3737
return _list

0 commit comments

Comments
 (0)