@@ -392,45 +392,48 @@ def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDat
392392 # field relation
393393 if tables and ds .table_relation :
394394 relations = list (filter (lambda x : x .get ('shape' ) == 'edge' , ds .table_relation ))
395- # Complete the missing table
396- # get tables in relation, remove irrelevant relation
397- embedding_table_ids = [s .get ('id' ) for s in tables ]
398- all_relations = list (filter (lambda x : x .get ('source' ).get ('cell' ) in embedding_table_ids or x .get ('target' ).get (
399- 'cell' ) in embedding_table_ids , relations ))
400-
401- # get relation table ids, sub embedding table ids
402- relation_table_ids = []
403- for r in all_relations :
404- relation_table_ids .append (r .get ('source' ).get ('cell' ))
405- relation_table_ids .append (r .get ('target' ).get ('cell' ))
406- relation_table_ids = list (set (relation_table_ids ))
407- # get table dict
408- table_records = session .query (CoreTable ).filter (CoreTable .id .in_ (list (map (int , relation_table_ids )))).all ()
409- table_dict = {}
410- for ele in table_records :
411- table_dict [ele .id ] = ele .table_name
412-
413- # get lost table ids
414- lost_table_ids = list (set (relation_table_ids ) - set (embedding_table_ids ))
415- # get lost table schema and splice it
416- lost_tables = list (filter (lambda x : x .get ('id' ) in lost_table_ids , all_tables ))
417- if lost_tables :
418- for s in lost_tables :
419- schema_str += s .get ('schema_table' )
420-
421- # get field dict
422- relation_field_ids = []
423- for relation in all_relations :
424- relation_field_ids .append (relation .get ('source' ).get ('port' ))
425- relation_field_ids .append (relation .get ('target' ).get ('port' ))
426- relation_field_ids = list (set (relation_field_ids ))
427- field_records = session .query (CoreField ).filter (CoreField .id .in_ (list (map (int , relation_field_ids )))).all ()
428- field_dict = {}
429- for ele in field_records :
430- field_dict [ele .id ] = ele .field_name
431-
432- schema_str += '【Foreign keys】\n '
433- for ele in all_relations :
434- schema_str += f"{ table_dict .get (int (ele .get ('source' ).get ('cell' )))} .{ field_dict .get (int (ele .get ('source' ).get ('port' )))} ={ table_dict .get (int (ele .get ('target' ).get ('cell' )))} .{ field_dict .get (int (ele .get ('target' ).get ('port' )))} \n "
395+ if relations :
396+ # Complete the missing table
397+ # get tables in relation, remove irrelevant relation
398+ embedding_table_ids = [s .get ('id' ) for s in tables ]
399+ all_relations = list (
400+ filter (lambda x : x .get ('source' ).get ('cell' ) in embedding_table_ids or x .get ('target' ).get (
401+ 'cell' ) in embedding_table_ids , relations ))
402+
403+ # get relation table ids, sub embedding table ids
404+ relation_table_ids = []
405+ for r in all_relations :
406+ relation_table_ids .append (r .get ('source' ).get ('cell' ))
407+ relation_table_ids .append (r .get ('target' ).get ('cell' ))
408+ relation_table_ids = list (set (relation_table_ids ))
409+ # get table dict
410+ table_records = session .query (CoreTable ).filter (CoreTable .id .in_ (list (map (int , relation_table_ids )))).all ()
411+ table_dict = {}
412+ for ele in table_records :
413+ table_dict [ele .id ] = ele .table_name
414+
415+ # get lost table ids
416+ lost_table_ids = list (set (relation_table_ids ) - set (embedding_table_ids ))
417+ # get lost table schema and splice it
418+ lost_tables = list (filter (lambda x : x .get ('id' ) in lost_table_ids , all_tables ))
419+ if lost_tables :
420+ for s in lost_tables :
421+ schema_str += s .get ('schema_table' )
422+
423+ # get field dict
424+ relation_field_ids = []
425+ for relation in all_relations :
426+ relation_field_ids .append (relation .get ('source' ).get ('port' ))
427+ relation_field_ids .append (relation .get ('target' ).get ('port' ))
428+ relation_field_ids = list (set (relation_field_ids ))
429+ field_records = session .query (CoreField ).filter (CoreField .id .in_ (list (map (int , relation_field_ids )))).all ()
430+ field_dict = {}
431+ for ele in field_records :
432+ field_dict [ele .id ] = ele .field_name
433+
434+ if all_relations :
435+ schema_str += '【Foreign keys】\n '
436+ for ele in all_relations :
437+ schema_str += f"{ table_dict .get (int (ele .get ('source' ).get ('cell' )))} .{ field_dict .get (int (ele .get ('source' ).get ('port' )))} ={ table_dict .get (int (ele .get ('target' ).get ('cell' )))} .{ field_dict .get (int (ele .get ('target' ).get ('port' )))} \n "
435438
436439 return schema_str
0 commit comments