Skip to content

Commit f59bc96

Browse files
committed
refactor: Service initiation vectorization task
1 parent c6810d7 commit f59bc96

File tree

6 files changed

+70
-57
lines changed

6 files changed

+70
-57
lines changed

backend/apps/data_training/curd/data_training.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import dicttoxml
88
from sqlalchemy import and_, select, func, delete, update, or_
99
from sqlalchemy import text
10-
from sqlalchemy.orm.session import Session
1110

1211
from apps.ai_model.embedding import EmbeddingModelCache
1312
from apps.data_training.models.data_training_model import DataTrainingInfo, DataTraining
@@ -160,24 +159,30 @@ def delete_training(session: SessionDep, ids: list[int]):
160159
# executor.submit(run_fill_empty_embeddings)
161160

162161

163-
def run_fill_empty_embeddings(session: Session):
164-
if not settings.EMBEDDING_ENABLED:
165-
return
162+
def run_fill_empty_embeddings(session_maker):
163+
try:
164+
if not settings.EMBEDDING_ENABLED:
165+
return
166166

167-
stmt = select(DataTraining.id).where(and_(DataTraining.embedding.is_(None)))
168-
results = session.execute(stmt).scalars().all()
167+
session = session_maker()
168+
stmt = select(DataTraining.id).where(and_(DataTraining.embedding.is_(None)))
169+
results = session.execute(stmt).scalars().all()
169170

170-
save_embeddings(session, results)
171+
save_embeddings(session_maker, results)
172+
except Exception:
173+
traceback.print_exc()
174+
finally:
175+
session_maker.remove()
171176

172177

173-
def save_embeddings(session: Session, ids: List[int]):
178+
def save_embeddings(session_maker, ids: List[int]):
174179
if not settings.EMBEDDING_ENABLED:
175180
return
176181

177182
if not ids or len(ids) == 0:
178183
return
179184
try:
180-
185+
session = session_maker()
181186
_list = session.query(DataTraining).filter(and_(DataTraining.id.in_(ids))).all()
182187

183188
_question_list = [item.question for item in _list]
@@ -194,6 +199,8 @@ def save_embeddings(session: Session, ids: List[int]):
194199

195200
except Exception:
196201
traceback.print_exc()
202+
finally:
203+
session_maker.remove()
197204

198205

199206
embedding_sql = f"""

backend/apps/datasource/crud/datasource.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from apps.db.engine import get_engine_config, get_engine_conn
1616
from common.core.config import settings
1717
from common.core.deps import SessionDep, CurrentUser, Trans
18-
from apps.datasource.crud.table import run_save_table_embeddings
18+
from common.utils.embedding_threads import run_save_table_embeddings
1919
from common.utils.utils import deepcopy_ignore_extra
2020
from .table import get_tables_by_ds_id
2121
from ..crud.field import delete_field_by_ds_id, update_field

backend/apps/datasource/crud/table.py

Lines changed: 8 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,16 @@
11
import json
22
import time
33
import traceback
4-
from concurrent.futures import ThreadPoolExecutor
54
from typing import List
65

76
from sqlalchemy import and_, select, update
8-
from sqlalchemy.orm import sessionmaker
9-
from sqlalchemy.orm.session import Session
107

118
from apps.ai_model.embedding import EmbeddingModelCache
129
from common.core.config import settings
1310
from common.core.deps import SessionDep
1411
from common.utils.utils import SQLBotLogUtil
1512
from ..models.datasource import CoreTable, CoreField
1613

17-
executor = ThreadPoolExecutor(max_workers=200)
18-
19-
from common.core.db import engine
20-
21-
session_maker = sessionmaker(bind=engine)
22-
session = session_maker()
23-
2414

2515
def delete_table_by_ds_id(session: SessionDep, id: int):
2616
session.query(CoreTable).filter(CoreTable.ds_id == id).delete(synchronize_session=False)
@@ -40,22 +30,25 @@ def update_table(session: SessionDep, item: CoreTable):
4030
session.commit()
4131

4232

43-
def run_fill_empty_table_embedding(session: Session):
33+
def run_fill_empty_table_embedding(session_maker):
4434
try:
4535
if not settings.TABLE_EMBEDDING_ENABLED:
4636
return
4737

4838
SQLBotLogUtil.info('get tables')
39+
session = session_maker()
4940
stmt = select(CoreTable.id).where(and_(CoreTable.embedding.is_(None)))
5041
results = session.execute(stmt).scalars().all()
5142
SQLBotLogUtil.info('result: ' + str(len(results)))
5243

5344
save_table_embedding(session, results)
5445
except Exception:
5546
traceback.print_exc()
47+
finally:
48+
session_maker.remove()
5649

5750

58-
def save_table_embedding(session: Session, ids: List[int]):
51+
def save_table_embedding(session_maker, ids: List[int]):
5952
if not settings.TABLE_EMBEDDING_ENABLED:
6053
return
6154

@@ -65,6 +58,7 @@ def save_table_embedding(session: Session, ids: List[int]):
6558
SQLBotLogUtil.info('start table embedding')
6659
start_time = time.time()
6760
model = EmbeddingModelCache.get_model()
61+
session = session_maker()
6862
for _id in ids:
6963
table = session.query(CoreTable).filter(CoreTable.id == _id).first()
7064
fields = session.query(CoreField).filter(CoreField.table_id == table.id).all()
@@ -102,14 +96,5 @@ def save_table_embedding(session: Session, ids: List[int]):
10296
SQLBotLogUtil.info('table embedding finished in: ' + str(end_time - start_time) + ' seconds')
10397
except Exception:
10498
traceback.print_exc()
105-
106-
107-
def run_save_table_embeddings(ids: List[int]):
108-
executor.submit(save_table_embedding, session, ids)
109-
110-
111-
def fill_empty_table_embeddings():
112-
try:
113-
executor.submit(run_fill_empty_table_embedding, session)
114-
except Exception:
115-
traceback.print_exc()
99+
finally:
100+
session_maker.remove()

backend/apps/terminology/curd/terminology.py

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import dicttoxml
88
from sqlalchemy import and_, or_, select, func, delete, update, union, text, BigInteger
99
from sqlalchemy.orm import aliased
10-
from sqlalchemy.orm.session import Session
1110

1211
from apps.ai_model.embedding import EmbeddingModelCache
1312
from apps.datasource.models.datasource import CoreDatasource
@@ -407,26 +406,36 @@ def delete_terminology(session: SessionDep, ids: list[int]):
407406
#
408407
# def fill_empty_embeddings():
409408
# executor.submit(run_fill_empty_embeddings)
409+
# from sqlalchemy import create_engine
410+
# from sqlalchemy.orm import sessionmaker,scoped_session
411+
# engine = create_engine(str(settings.SQLALCHEMY_DATABASE_URI))
412+
# session_maker = scoped_session(sessionmaker(bind=engine))
410413

411-
412-
def run_fill_empty_embeddings(session: Session):
413-
if not settings.EMBEDDING_ENABLED:
414-
return
415-
stmt1 = select(Terminology.id).where(and_(Terminology.embedding.is_(None), Terminology.pid.is_(None)))
416-
stmt2 = select(Terminology.pid).where(and_(Terminology.embedding.is_(None), Terminology.pid.isnot(None))).distinct()
417-
combined_stmt = union(stmt1, stmt2)
418-
results = session.execute(combined_stmt).scalars().all()
419-
save_embeddings(session, results)
414+
def run_fill_empty_embeddings(session_maker):
415+
try:
416+
if not settings.EMBEDDING_ENABLED:
417+
return
418+
session = session_maker()
419+
stmt1 = select(Terminology.id).where(and_(Terminology.embedding.is_(None), Terminology.pid.is_(None)))
420+
stmt2 = select(Terminology.pid).where(
421+
and_(Terminology.embedding.is_(None), Terminology.pid.isnot(None))).distinct()
422+
combined_stmt = union(stmt1, stmt2)
423+
results = session.execute(combined_stmt).scalars().all()
424+
save_embeddings(session_maker, results)
425+
except Exception:
426+
traceback.print_exc()
427+
finally:
428+
session_maker.remove()
420429

421430

422-
def save_embeddings(session: Session, ids: List[int]):
431+
def save_embeddings(session_maker, ids: List[int]):
423432
if not settings.EMBEDDING_ENABLED:
424433
return
425434

426435
if not ids or len(ids) == 0:
427436
return
428437
try:
429-
438+
session = session_maker()
430439
_list = session.query(Terminology).filter(or_(Terminology.id.in_(ids), Terminology.pid.in_(ids))).all()
431440

432441
_words_list = [item.word for item in _list]
@@ -443,6 +452,8 @@ def save_embeddings(session: Session, ids: List[int]):
443452

444453
except Exception:
445454
traceback.print_exc()
455+
finally:
456+
session_maker.remove()
446457

447458

448459
embedding_sql = f"""
Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,43 @@
11
from concurrent.futures import ThreadPoolExecutor
22
from typing import List
33

4-
from sqlalchemy import create_engine
5-
from sqlalchemy.orm import sessionmaker
6-
7-
from common.core.config import settings
4+
from sqlalchemy.orm import sessionmaker, scoped_session
85

96
executor = ThreadPoolExecutor(max_workers=200)
107

11-
engine = create_engine(str(settings.SQLALCHEMY_DATABASE_URI))
12-
session_maker = sessionmaker(bind=engine)
13-
session = session_maker()
8+
from common.core.db import engine
9+
10+
session_maker = scoped_session(sessionmaker(bind=engine))
11+
12+
13+
# session = session_maker()
1414

1515

1616
def run_save_terminology_embeddings(ids: List[int]):
1717
from apps.terminology.curd.terminology import save_embeddings
18-
executor.submit(save_embeddings, session, ids)
18+
executor.submit(save_embeddings, session_maker, ids)
1919

2020

2121
def fill_empty_terminology_embeddings():
2222
from apps.terminology.curd.terminology import run_fill_empty_embeddings
23-
executor.submit(run_fill_empty_embeddings, session)
23+
executor.submit(run_fill_empty_embeddings, session_maker)
2424

2525

2626
def run_save_data_training_embeddings(ids: List[int]):
2727
from apps.data_training.curd.data_training import save_embeddings
28-
executor.submit(save_embeddings, session, ids)
28+
executor.submit(save_embeddings, session_maker, ids)
2929

3030

3131
def fill_empty_data_training_embeddings():
3232
from apps.data_training.curd.data_training import run_fill_empty_embeddings
33-
executor.submit(run_fill_empty_embeddings, session)
33+
executor.submit(run_fill_empty_embeddings, session_maker)
34+
35+
36+
def run_save_table_embeddings(ids: List[int]):
37+
from apps.datasource.crud.table import save_table_embedding
38+
executor.submit(save_table_embedding, session_maker, ids)
39+
40+
41+
def fill_empty_table_embeddings():
42+
from apps.datasource.crud.table import run_fill_empty_table_embedding
43+
executor.submit(run_fill_empty_table_embedding, session_maker)

backend/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
from alembic import command
1414
from apps.api import api_router
15-
from apps.datasource.crud.table import fill_empty_table_embeddings
15+
from common.utils.embedding_threads import fill_empty_table_embeddings
1616
from apps.system.crud.aimodel_manage import async_model_info
1717
from apps.system.crud.assistant import init_dynamic_cors
1818
from apps.system.middleware.auth import TokenMiddleware

0 commit comments

Comments
 (0)