Skip to content

Commit 2502455

Browse files
committed
feat: add table relation
1 parent 2f7fb79 commit 2502455

File tree

1 file changed

+43
-40
lines changed

1 file changed

+43
-40
lines changed

backend/apps/datasource/crud/datasource.py

Lines changed: 43 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -392,45 +392,48 @@ def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDat
392392
# field relation
393393
if tables and ds.table_relation:
394394
relations = list(filter(lambda x: x.get('shape') == 'edge', ds.table_relation))
395-
# Complete the missing table
396-
# get tables in relation, remove irrelevant relation
397-
embedding_table_ids = [s.get('id') for s in tables]
398-
all_relations = list(filter(lambda x: x.get('source').get('cell') in embedding_table_ids or x.get('target').get(
399-
'cell') in embedding_table_ids, relations))
400-
401-
# get relation table ids, sub embedding table ids
402-
relation_table_ids = []
403-
for r in all_relations:
404-
relation_table_ids.append(r.get('source').get('cell'))
405-
relation_table_ids.append(r.get('target').get('cell'))
406-
relation_table_ids = list(set(relation_table_ids))
407-
# get table dict
408-
table_records = session.query(CoreTable).filter(CoreTable.id.in_(list(map(int, relation_table_ids)))).all()
409-
table_dict = {}
410-
for ele in table_records:
411-
table_dict[ele.id] = ele.table_name
412-
413-
# get lost table ids
414-
lost_table_ids = list(set(relation_table_ids) - set(embedding_table_ids))
415-
# get lost table schema and splice it
416-
lost_tables = list(filter(lambda x: x.get('id') in lost_table_ids, all_tables))
417-
if lost_tables:
418-
for s in lost_tables:
419-
schema_str += s.get('schema_table')
420-
421-
# get field dict
422-
relation_field_ids = []
423-
for relation in all_relations:
424-
relation_field_ids.append(relation.get('source').get('port'))
425-
relation_field_ids.append(relation.get('target').get('port'))
426-
relation_field_ids = list(set(relation_field_ids))
427-
field_records = session.query(CoreField).filter(CoreField.id.in_(list(map(int, relation_field_ids)))).all()
428-
field_dict = {}
429-
for ele in field_records:
430-
field_dict[ele.id] = ele.field_name
431-
432-
schema_str += '【Foreign keys】\n'
433-
for ele in all_relations:
434-
schema_str += f"{table_dict.get(int(ele.get('source').get('cell')))}.{field_dict.get(int(ele.get('source').get('port')))}={table_dict.get(int(ele.get('target').get('cell')))}.{field_dict.get(int(ele.get('target').get('port')))}\n"
395+
if relations:
396+
# Complete the missing table
397+
# get tables in relation, remove irrelevant relation
398+
embedding_table_ids = [s.get('id') for s in tables]
399+
all_relations = list(
400+
filter(lambda x: x.get('source').get('cell') in embedding_table_ids or x.get('target').get(
401+
'cell') in embedding_table_ids, relations))
402+
403+
# get relation table ids, sub embedding table ids
404+
relation_table_ids = []
405+
for r in all_relations:
406+
relation_table_ids.append(r.get('source').get('cell'))
407+
relation_table_ids.append(r.get('target').get('cell'))
408+
relation_table_ids = list(set(relation_table_ids))
409+
# get table dict
410+
table_records = session.query(CoreTable).filter(CoreTable.id.in_(list(map(int, relation_table_ids)))).all()
411+
table_dict = {}
412+
for ele in table_records:
413+
table_dict[ele.id] = ele.table_name
414+
415+
# get lost table ids
416+
lost_table_ids = list(set(relation_table_ids) - set(embedding_table_ids))
417+
# get lost table schema and splice it
418+
lost_tables = list(filter(lambda x: x.get('id') in lost_table_ids, all_tables))
419+
if lost_tables:
420+
for s in lost_tables:
421+
schema_str += s.get('schema_table')
422+
423+
# get field dict
424+
relation_field_ids = []
425+
for relation in all_relations:
426+
relation_field_ids.append(relation.get('source').get('port'))
427+
relation_field_ids.append(relation.get('target').get('port'))
428+
relation_field_ids = list(set(relation_field_ids))
429+
field_records = session.query(CoreField).filter(CoreField.id.in_(list(map(int, relation_field_ids)))).all()
430+
field_dict = {}
431+
for ele in field_records:
432+
field_dict[ele.id] = ele.field_name
433+
434+
if all_relations:
435+
schema_str += '【Foreign keys】\n'
436+
for ele in all_relations:
437+
schema_str += f"{table_dict.get(int(ele.get('source').get('cell')))}.{field_dict.get(int(ele.get('source').get('port')))}={table_dict.get(int(ele.get('target').get('cell')))}.{field_dict.get(int(ele.get('target').get('port')))}\n"
435438

436439
return schema_str

0 commit comments

Comments
 (0)