@@ -354,6 +354,7 @@ def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDat
354354 db_name = table_objs [0 ].schema
355355 schema_str += f"【DB_ID】 { db_name } \n 【Schema】\n "
356356 tables = []
357+ all_tables = [] # temp save all tables
357358 for obj in table_objs :
358359 schema_table = ''
359360 schema_table += f"# Table: { db_name } .{ obj .table .table_name } " if ds .type != "mysql" and ds .type != "es" else f"# Table: { obj .table .table_name } "
@@ -376,13 +377,60 @@ def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDat
376377 field_list .append (f"({ field .field_name } :{ field .field_type } , { field_comment } )" )
377378 schema_table += ",\n " .join (field_list )
378379 schema_table += '\n ]\n '
379- tables .append (schema_table )
380+ t_obj = {"id" : obj .table .id , "schema_table" : schema_table }
381+ tables .append (t_obj )
382+ all_tables .append (t_obj )
380383
381384 # do table embedding
382- if embedding :
385+ if embedding and tables :
383386 tables = get_table_embedding (session , current_user , tables , question )
387+ # splice schema
388+ if tables :
389+ for s in tables :
390+ schema_str += s .get ('schema_table' )
391+
392+ # field relation
393+ if tables and ds .table_relation :
394+ relations = list (filter (lambda x : x .get ('shape' ) == 'edge' , ds .table_relation ))
395+ # Complete the missing table
396+ # get tables in relation, remove irrelevant relation
397+ embedding_table_ids = [s .get ('id' ) for s in tables ]
398+ all_relations = list (filter (lambda x : x .get ('source' ).get ('cell' ) in embedding_table_ids or x .get ('target' ).get (
399+ 'cell' ) in embedding_table_ids , relations ))
400+
401+ # get relation table ids, sub embedding table ids
402+ relation_table_ids = []
403+ for r in all_relations :
404+ relation_table_ids .append (r .get ('source' ).get ('cell' ))
405+ relation_table_ids .append (r .get ('target' ).get ('cell' ))
406+ relation_table_ids = list (set (relation_table_ids ))
407+ # get table dict
408+ table_records = session .query (CoreTable ).filter (CoreTable .id .in_ (list (map (int , relation_table_ids )))).all ()
409+ table_dict = {}
410+ for ele in table_records :
411+ table_dict [ele .id ] = ele .table_name
412+
413+ # get lost table ids
414+ lost_table_ids = list (set (relation_table_ids ) - set (embedding_table_ids ))
415+ # get lost table schema and splice it
416+ lost_tables = list (filter (lambda x : x .get ('id' ) in lost_table_ids , all_tables ))
417+ if lost_tables :
418+ for s in lost_tables :
419+ schema_str += s .get ('schema_table' )
420+
421+ # get field dict
422+ relation_field_ids = []
423+ for relation in all_relations :
424+ relation_field_ids .append (relation .get ('source' ).get ('port' ))
425+ relation_field_ids .append (relation .get ('target' ).get ('port' ))
426+ relation_field_ids = list (set (relation_field_ids ))
427+ field_records = session .query (CoreField ).filter (CoreField .id .in_ (list (map (int , relation_field_ids )))).all ()
428+ field_dict = {}
429+ for ele in field_records :
430+ field_dict [ele .id ] = ele .field_name
431+
432+ schema_str += '【Foreign keys】\n '
433+ for ele in all_relations :
434+ schema_str += f"{ table_dict .get (int (ele .get ('source' ).get ('cell' )))} .{ field_dict .get (int (ele .get ('source' ).get ('port' )))} ={ table_dict .get (int (ele .get ('target' ).get ('cell' )))} .{ field_dict .get (int (ele .get ('target' ).get ('port' )))} \n "
384435
385- # todo 外键
386- for s in tables :
387- schema_str += s
388436 return schema_str
0 commit comments