Skip to content

Commit 090df35

Browse files
committed
feat: Vector retrieval matches tables
1 parent 86b4d38 commit 090df35

File tree

1 file changed

+14
-22
lines changed

1 file changed

+14
-22
lines changed

backend/apps/datasource/crud/table.py

Lines changed: 14 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -48,18 +48,18 @@ def save_table_embedding(session: SessionDep, ids: List[int]):
4848
if not ids or len(ids) == 0:
4949
return
5050
try:
51-
52-
_list = session.query(CoreTable).filter(and_(CoreTable.id.in_(ids))).all()
53-
54-
table_schema = []
55-
for item in _list:
56-
fields = session.query(CoreField).filter(CoreField.table_id == item.id).all()
51+
SQLBotLogUtil.info('start table embedding')
52+
start_time = time.time()
53+
model = EmbeddingModelCache.get_model()
54+
for id in ids:
55+
table = session.query(CoreTable).filter(CoreTable.id == id).first()
56+
fields = session.query(CoreField).filter(CoreField.table_id == table.id).all()
5757

5858
schema_table = ''
59-
schema_table += f"# Table: {item.table_name}"
59+
schema_table += f"# Table: {table.table_name}"
6060
table_comment = ''
61-
if item.custom_comment:
62-
table_comment = item.custom_comment.strip()
61+
if table.custom_comment:
62+
table_comment = table.custom_comment.strip()
6363
if table_comment == '':
6464
schema_table += '\n[\n'
6565
else:
@@ -77,22 +77,14 @@ def save_table_embedding(session: SessionDep, ids: List[int]):
7777
field_list.append(f"({field.field_name}:{field.field_type}, {field_comment})")
7878
schema_table += ",\n".join(field_list)
7979
schema_table += '\n]\n'
80-
table_schema.append(schema_table)
81-
82-
model = EmbeddingModelCache.get_model()
80+
# table_schema.append(schema_table)
81+
emb = model.embed_query(schema_table)
8382

84-
SQLBotLogUtil.info(json.dumps(table_schema))
85-
SQLBotLogUtil.info('start table embedding')
86-
start_time = time.time()
87-
results = model.embed_documents(table_schema)
88-
end_time = time.time()
89-
SQLBotLogUtil.info('table embedding finished in:' + str(end_time - start_time) + 'seconds')
90-
91-
for index in range(len(results)):
92-
item = results[index]
93-
stmt = update(CoreTable).where(and_(CoreTable.id == _list[index].id)).values(embedding=item)
83+
stmt = update(CoreTable).where(and_(CoreTable.id == id)).values(embedding=emb)
9484
session.execute(stmt)
9585
session.commit()
9686

87+
end_time = time.time()
88+
SQLBotLogUtil.info('table embedding finished in:' + str(end_time - start_time) + 'seconds')
9789
except Exception:
9890
traceback.print_exc()

0 commit comments

Comments
 (0)