Skip to content

Commit 530f5dc

Browse files
committed
fix: session in thread
1 parent 1d98de3 commit 530f5dc

File tree

2 files changed

+33
-27
lines changed

2 files changed

+33
-27
lines changed

backend/apps/chat/api/chat.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ async def recommend_questions(session: SessionDep, current_user: CurrentUser, ch
107107
)
108108
request_question = ChatQuestion(chat_id=record.chat_id, question=record.question if record.question else '')
109109

110-
llm_service = LLMService(session, current_user, request_question, current_assistant)
110+
llm_service = LLMService(current_user, request_question, current_assistant)
111111
llm_service.set_record(record)
112112
llm_service.run_recommend_questions_task_async()
113113
except Exception as e:
@@ -135,7 +135,7 @@ async def stream_sql(session: SessionDep, current_user: CurrentUser, request_que
135135
"""
136136

137137
try:
138-
llm_service = LLMService(session, current_user, request_question, current_assistant)
138+
llm_service = LLMService(current_user, request_question, current_assistant)
139139
llm_service.init_record()
140140
llm_service.run_task_async()
141141
except Exception as e:
@@ -173,7 +173,7 @@ async def analysis_or_predict(session: SessionDep, current_user: CurrentUser, ch
173173
request_question = ChatQuestion(chat_id=record.chat_id, question='')
174174

175175
try:
176-
llm_service = LLMService(session, current_user, request_question, current_assistant)
176+
llm_service = LLMService(current_user, request_question, current_assistant)
177177
llm_service.run_analysis_or_predict_task_async(action_type, record)
178178
except Exception as e:
179179
traceback.print_exc()

backend/apps/chat/task/llm.py

Lines changed: 30 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,11 @@
1616
from sqlalchemy import and_, cast, or_
1717
from sqlalchemy import select
1818
from sqlalchemy.dialects.postgresql import JSONB
19+
from sqlalchemy.orm import sessionmaker
1920
from sqlbot_xpack.permissions.api.permission import transRecord2DTO
2021
from sqlbot_xpack.permissions.models.ds_permission import DsPermission, PermissionDTO
2122
from sqlbot_xpack.permissions.models.ds_rules import DsRules
23+
from sqlmodel import create_engine, Session
2224

2325
from apps.ai_model.model_factory import LLMConfig, LLMFactory, get_default_config
2426
from apps.chat.curd.chat import save_question, save_full_sql_message, save_full_sql_message_and_answer, save_sql, \
@@ -35,7 +37,7 @@
3537
from apps.system.crud.assistant import AssistantOutDs, AssistantOutDsFactory, get_assistant_ds
3638
from apps.system.schemas.system_schema import AssistantOutDsSchema
3739
from common.core.config import settings
38-
from common.core.deps import CurrentAssistant, SessionDep, CurrentUser
40+
from common.core.deps import CurrentAssistant, CurrentUser
3941
from common.utils.utils import SQLBotLogUtil, extract_nested_json
4042

4143
warnings.filterwarnings("ignore")
@@ -54,7 +56,7 @@ class LLMService:
5456
sql_message: List[Union[BaseMessage, dict[str, Any]]] = []
5557
chart_message: List[Union[BaseMessage, dict[str, Any]]] = []
5658
history_records: List[ChatRecord] = []
57-
session: SessionDep
59+
session: Session
5860
current_user: CurrentUser
5961
current_assistant: Optional[CurrentAssistant] = None
6062
out_ds_instance: Optional[AssistantOutDs] = None
@@ -63,15 +65,18 @@ class LLMService:
6365
chunk_list: List[str] = []
6466
future: Future
6567

66-
def __init__(self, session: SessionDep, current_user: CurrentUser, chat_question: ChatQuestion,
68+
def __init__(self, current_user: CurrentUser, chat_question: ChatQuestion,
6769
current_assistant: Optional[CurrentAssistant] = None):
6870
self.chunk_list = []
69-
self.session = session
71+
engine = create_engine(str(settings.SQLALCHEMY_DATABASE_URI))
72+
session_maker = sessionmaker(bind=engine)
73+
self.session = session_maker()
74+
7075
self.current_user = current_user
7176
self.current_assistant = current_assistant
7277
# chat = self.session.query(Chat).filter(Chat.id == chat_question.chat_id).first()
7378
chat_id = chat_question.chat_id
74-
chat: Chat = self.session.get(Chat, chat_id)
79+
chat: Chat | None = self.session.get(Chat, chat_id)
7580
if not chat:
7681
raise Exception(f"Chat with id {chat_id} not found")
7782
ds: CoreDatasource | AssistantOutDsSchema | None = None
@@ -375,10 +380,10 @@ def select_datasource(self):
375380
]
376381
""" _ds_list = self.session.exec(select(CoreDatasource).options(
377382
load_only(CoreDatasource.id, CoreDatasource.name, CoreDatasource.description))).all() """
378-
383+
379384
ignore_auto_select = _ds_list and len(_ds_list) == 1
380385
# ignore auto select ds
381-
386+
382387
if not ignore_auto_select:
383388
_ds_list_dict = []
384389
for _ds in _ds_list:
@@ -391,10 +396,10 @@ def select_datasource(self):
391396
history_msg = orjson.loads(self.record.full_select_datasource_message)
392397

393398
self.record = save_full_select_datasource_message_and_answer(session=self.session, record_id=self.record.id,
394-
answer='',
395-
full_message=orjson.dumps(history_msg +
396-
[{'type': msg.type,
397-
'content': msg.content}
399+
answer='',
400+
full_message=orjson.dumps(history_msg +
401+
[{'type': msg.type,
402+
'content': msg.content}
398403
for msg
399404
in
400405
datasource_msg]).decode())
@@ -461,13 +466,13 @@ def select_datasource(self):
461466

462467
if not ignore_auto_select:
463468
self.record = save_full_select_datasource_message_and_answer(session=self.session, record_id=self.record.id,
464-
answer=orjson.dumps({'content': full_text,
465-
'reasoning_content': full_thinking_text}).decode(),
466-
datasource=_datasource,
467-
engine_type=_engine_type,
468-
full_message=orjson.dumps(history_msg +
469-
[{'type': msg.type,
470-
'content': msg.content}
469+
answer=orjson.dumps({'content': full_text,
470+
'reasoning_content': full_thinking_text}).decode(),
471+
datasource=_datasource,
472+
engine_type=_engine_type,
473+
full_message=orjson.dumps(history_msg +
474+
[{'type': msg.type,
475+
'content': msg.content}
471476
for msg
472477
in
473478
datasource_msg]).decode())
@@ -511,7 +516,6 @@ def generate_sql(self):
511516
[{'type': msg.type, 'content': msg.content} for msg in
512517
self.sql_message]).decode())
513518

514-
515519
def build_table_filter(self, sql: str, filters: list):
516520
filter = json.dumps(filters, ensure_ascii=False)
517521
self.chat_question.sql = sql
@@ -561,7 +565,7 @@ def build_table_filter(self, sql: str, filters: list):
561565
# analysis_msg]).decode())
562566
SQLBotLogUtil.info(full_filter_text)
563567
return full_filter_text
564-
568+
565569
def generate_filter(self, sql: str, tables: List):
566570
table_list = self.session.query(CoreTable).filter(
567571
and_(CoreTable.ds_id == self.ds.id, CoreTable.table_name.in_(tables))
@@ -586,7 +590,7 @@ def generate_filter(self, sql: str, tables: List):
586590
filters.append({"table": table.table_name, "filter": where_str})
587591

588592
return self.build_table_filter(sql=sql, filters=filters)
589-
593+
590594
def generate_assistant_filter(self, sql, tables: List):
591595
ds: AssistantOutDsSchema = self.ds
592596
filters = []
@@ -596,7 +600,7 @@ def generate_assistant_filter(self, sql, tables: List):
596600
if not filters:
597601
return None
598602
return self.build_table_filter(sql=sql, filters=filters)
599-
603+
600604
def generate_chart(self):
601605
# append current question
602606
self.chart_message.append(HumanMessage(self.chat_question.chart_user_question()))
@@ -810,7 +814,8 @@ def run_task(self, in_chat: bool = True):
810814
SQLBotLogUtil.info(full_sql_text)
811815

812816
# todo row permission
813-
if (not self.current_assistant and is_normal_user(self.current_user)) or (self.current_assistant and self.current_assistant.type == 1):
817+
if (not self.current_assistant and is_normal_user(self.current_user)) or (
818+
self.current_assistant and self.current_assistant.type == 1):
814819
sql_json_str = extract_nested_json(full_sql_text)
815820
data = orjson.loads(sql_json_str)
816821

@@ -831,7 +836,7 @@ def run_task(self, in_chat: bool = True):
831836
sql_result = self.generate_assistant_filter(data.get('sql'), data.get('tables'))
832837
else:
833838
sql_result = self.generate_filter(data.get('sql'), data.get('tables')) # maybe no sql and tables
834-
839+
835840
if sql_result:
836841
SQLBotLogUtil.info(sql_result)
837842
sql = self.check_save_sql(res=sql_result)
@@ -1060,6 +1065,7 @@ def get_token_usage(chunk: BaseMessageChunk, token_usage: dict = {}):
10601065
except Exception:
10611066
pass
10621067

1068+
10631069
def get_lang_name(lang: str):
10641070
if lang and lang == 'en':
10651071
return '英文'

0 commit comments

Comments
 (0)