Skip to content

Commit 2b6d5b8

Browse files
committed
Merge branch 'main' of https://github.com/dataease/SQLBot
2 parents 63092e3 + 0ea06a7 commit 2b6d5b8

File tree

11 files changed

+240
-17
lines changed

11 files changed

+240
-17
lines changed

Dockerfile

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# Build sqlbot
2+
FROM ghcr.io/1panel-dev/maxkb-vector-model:v1.0.1 AS vector-model
23
FROM registry.cn-qingdao.aliyuncs.com/dataease/sqlbot-base:latest AS sqlbot-builder
34

45
# Set build environment variables
@@ -58,15 +59,16 @@ COPY start.sh /opt/sqlbot/app/start.sh
5859
COPY g2-ssr/*.ttf /usr/share/fonts/truetype/liberation/
5960
COPY --from=sqlbot-builder ${SQLBOT_HOME} ${SQLBOT_HOME}
6061
COPY --from=ssr-builder /app /opt/sqlbot/g2-ssr
62+
COPY --from=vector-model /opt/maxkb/app/model /opt/sqlbot/models
6163

6264
WORKDIR ${SQLBOT_HOME}/app
6365

6466
RUN mkdir -p /opt/sqlbot/images /opt/sqlbot/g2-ssr
6567

66-
EXPOSE 3000 8000
68+
EXPOSE 3000 8000 8001
6769

6870
# Add health check
6971
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
7072
CMD curl -f http://localhost:8000 || exit 1
7173

72-
ENTRYPOINT ["sh", "start.sh"]
74+
ENTRYPOINT ["sh", "start.sh"]

backend/apps/ai_model/embedding.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import os.path
12
import threading
23
from typing import Optional
34

@@ -14,7 +15,9 @@ class EmbeddingModelInfo(BaseModel):
1415
device: str = 'cpu'
1516

1617

17-
local_embedding_model = EmbeddingModelInfo(folder=settings.LOCAL_MODEL_PATH, name=settings.DEFAULT_EMBEDDING_MODEL)
18+
local_embedding_model = EmbeddingModelInfo(folder=settings.LOCAL_MODEL_PATH,
19+
name=os.path.join(settings.LOCAL_MODEL_PATH, 'embedding',
20+
"shibing624_text2vec-base-chinese"))
1821

1922
_lock = threading.Lock()
2023
locks = {}

backend/apps/chat/api/chat.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ async def analysis_or_predict(session: SessionDep, current_user: CurrentUser, ch
186186
detail=f"Chat record with id {chat_record_id} has not generated chart, do not support to analyze it"
187187
)
188188

189-
request_question = ChatQuestion(chat_id=record.chat_id, question='')
189+
request_question = ChatQuestion(chat_id=record.chat_id, question=record.question)
190190

191191
try:
192192
llm_service = LLMService(current_user, request_question, current_assistant)

backend/apps/chat/models/chat_model.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,10 +170,11 @@ class AiModelQuestion(BaseModel):
170170
lang: str = "简体中文"
171171
filter: str = []
172172
sub_query: Optional[list[dict]] = None
173+
terminologies: str = ""
173174

174175
def sql_sys_question(self):
175176
return get_sql_template()['system'].format(engine=self.engine, schema=self.db_schema, question=self.question,
176-
lang=self.lang)
177+
lang=self.lang, terminologies=self.terminologies)
177178

178179
def sql_user_question(self):
179180
return get_sql_template()['user'].format(engine=self.engine, schema=self.db_schema, question=self.question,
@@ -186,7 +187,7 @@ def chart_user_question(self):
186187
return get_chart_template()['user'].format(sql=self.sql, question=self.question, rule=self.rule)
187188

188189
def analysis_sys_question(self):
189-
return get_analysis_template()['system'].format(lang=self.lang)
190+
return get_analysis_template()['system'].format(lang=self.lang, terminologies=self.terminologies)
190191

191192
def analysis_user_question(self):
192193
return get_analysis_template()['user'].format(fields=self.fields, data=self.data)

backend/apps/chat/task/llm.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from apps.db.db import exec_sql, get_version
3232
from apps.system.crud.assistant import AssistantOutDs, AssistantOutDsFactory, get_assistant_ds
3333
from apps.system.schemas.system_schema import AssistantOutDsSchema
34+
from apps.terminology.curd.terminology import get_terminology_template
3435
from common.core.config import settings
3536
from common.core.deps import CurrentAssistant, CurrentUser
3637
from common.error import SingleMessageError
@@ -124,8 +125,6 @@ def __init__(self, current_user: CurrentUser, chat_question: ChatQuestion,
124125
llm_instance = LLMFactory.create_llm(self.config)
125126
self.llm = llm_instance.llm
126127

127-
self.init_messages()
128-
129128
def is_running(self, timeout=0.5):
130129
try:
131130
r = concurrent.futures.wait([self.future], timeout)
@@ -210,6 +209,9 @@ def generate_analysis(self):
210209
data = get_chat_chart_data(self.session, self.record.id)
211210
self.chat_question.data = orjson.dumps(data.get('data')).decode()
212211
analysis_msg: List[Union[BaseMessage, dict[str, Any]]] = []
212+
213+
self.chat_question.terminologies = get_terminology_template(self.session, self.chat_question.question)
214+
213215
analysis_msg.append(SystemMessage(content=self.chat_question.analysis_sys_question()))
214216
analysis_msg.append(HumanMessage(content=self.chat_question.analysis_user_question()))
215217

@@ -860,6 +862,9 @@ def run_task_cache(self, in_chat: bool = True):
860862

861863
def run_task(self, in_chat: bool = True):
862864
try:
865+
self.chat_question.terminologies = get_terminology_template(self.session, self.chat_question.question)
866+
self.init_messages()
867+
863868
# return id
864869
if in_chat:
865870
yield 'data:' + orjson.dumps({'type': 'id', 'id': self.get_record().id}).decode() + '\n\n'

backend/apps/template/generate_chart/generator.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,7 @@
44
def get_chart_template():
    """Return the ``chart`` section of the base prompt template."""
    return get_base_template()['template']['chart']
7+
8+
def get_base_terminology_template():
    """Return the ``terminology`` section of the base prompt template."""
    return get_base_template()['template']['terminology']

backend/apps/terminology/curd/terminology.py

Lines changed: 191 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,24 @@
11
import datetime
2+
import logging
3+
import traceback
4+
from concurrent.futures import ThreadPoolExecutor
25
from typing import List, Optional
6+
from xml.dom.minidom import parseString
37

4-
from sqlalchemy import and_, or_, select, func, delete, update
8+
import dicttoxml
9+
from sqlalchemy import and_, or_, select, func, delete, update, union
10+
from sqlalchemy import create_engine, text
511
from sqlalchemy.orm import aliased
12+
from sqlalchemy.orm import sessionmaker
613

14+
from apps.ai_model.embedding import EmbeddingModelCache
15+
from apps.template.generate_chart.generator import get_base_terminology_template
716
from apps.terminology.models.terminology_model import Terminology, TerminologyInfo
17+
from common.core.config import settings
818
from common.core.deps import SessionDep
919

20+
# Shared thread pool on which the embedding jobs of this module are submitted
# (see run_save_embeddings and fill_empty_embeddings).
executor = ThreadPoolExecutor(max_workers=200)
21+
1022

1123
def page_terminology(session: SessionDep, current_page: int = 1, page_size: int = 10, name: Optional[str] = None):
1224
_list: List[TerminologyInfo] = []
@@ -24,7 +36,7 @@ def page_terminology(session: SessionDep, current_page: int = 1, page_size: int
2436
# 步骤1:先找到所有匹配的节点ID(无论是父节点还是子节点)
2537
matched_ids_subquery = (
2638
select(Terminology.id)
27-
.where(Terminology.word.like(keyword_pattern)) # LIKE查询条件
39+
.where(Terminology.word.ilike(keyword_pattern)) # LIKE查询条件
2840
.subquery()
2941
)
3042

@@ -82,7 +94,6 @@ def page_terminology(session: SessionDep, current_page: int = 1, page_size: int
8294
.where(Terminology.id.in_(paginated_parent_ids))
8395
.order_by(Terminology.create_time.desc())
8496
)
85-
print(str(stmt))
8697
else:
8798
parent_ids_subquery = (
8899
select(Terminology.id)
@@ -113,7 +124,6 @@ def page_terminology(session: SessionDep, current_page: int = 1, page_size: int
113124
.group_by(Terminology.id, Terminology.word)
114125
.order_by(Terminology.create_time.desc())
115126
)
116-
print(str(stmt))
117127

118128
result = session.execute(stmt)
119129

@@ -145,13 +155,16 @@ def create_terminology(session: SessionDep, info: TerminologyInfo):
145155
_list: List[Terminology] = []
146156
if info.other_words:
147157
for other_word in info.other_words:
158+
if other_word.strip() == "":
159+
continue
148160
_list.append(
149161
Terminology(pid=result.id, word=other_word, create_time=create_time))
150162
session.bulk_save_objects(_list)
151163
session.flush()
152164
session.commit()
153165

154-
# todo embedding
166+
# embedding
167+
run_save_embeddings([result.id])
155168

156169
return result.id
157170

@@ -172,13 +185,16 @@ def update_terminology(session: SessionDep, info: TerminologyInfo):
172185
_list: List[Terminology] = []
173186
if info.other_words:
174187
for other_word in info.other_words:
188+
if other_word.strip() == "":
189+
continue
175190
_list.append(
176191
Terminology(pid=info.id, word=other_word, create_time=create_time))
177192
session.bulk_save_objects(_list)
178193
session.flush()
179194
session.commit()
180195

181-
# todo embedding
196+
# embedding
197+
run_save_embeddings([info.id])
182198

183199
return info.id
184200

@@ -187,3 +203,172 @@ def delete_terminology(session: SessionDep, ids: list[int]):
187203
stmt = delete(Terminology).where(or_(Terminology.id.in_(ids), Terminology.pid.in_(ids)))
188204
session.execute(stmt)
189205
session.commit()
206+
207+
208+
def run_save_embeddings(ids: List[int]) -> None:
    """Submit ``save_embeddings(ids)`` to the module-level executor.

    The future returned by ``submit`` is discarded, so the caller neither
    waits for the job nor sees its result.
    """
    executor.submit(save_embeddings, ids)
210+
211+
212+
def fill_empty_embeddings() -> None:
    """Submit ``run_fill_empty_embeddings`` to the module-level executor.

    The future returned by ``submit`` is discarded, so the caller neither
    waits for the job nor sees its result.
    """
    executor.submit(run_fill_empty_embeddings)
214+
215+
216+
def run_fill_empty_embeddings():
    """Compute embeddings for every terminology group that still lacks one.

    Collects the ids of top-level terms (``pid IS NULL``) whose ``embedding``
    is NULL together with the distinct parent ids of child terms whose
    ``embedding`` is NULL, then hands the combined id list to
    ``save_embeddings``. Does nothing when ``settings.EMBEDDING_ENABLED`` is
    false.

    The function opens its own engine and session instead of borrowing a
    request-scoped one; both are released before returning so repeated calls
    do not accumulate open connections.
    """
    if not settings.EMBEDDING_ENABLED:
        return

    engine = create_engine(str(settings.SQLALCHEMY_DATABASE_URI))
    try:
        # The session is closed by the context manager even if the query fails.
        with sessionmaker(bind=engine)() as session:
            stmt1 = select(Terminology.id).where(
                and_(Terminology.embedding.is_(None), Terminology.pid.is_(None)))
            stmt2 = select(Terminology.pid).where(
                and_(Terminology.embedding.is_(None), Terminology.pid.isnot(None))).distinct()
            results = session.execute(union(stmt1, stmt2)).scalars().all()
    except Exception:
        # This runs on the executor and its future is discarded, so an
        # uncaught exception would disappear without a trace. Log it instead.
        logging.getLogger(__name__).exception("Failed to collect terminologies without embeddings")
        return
    finally:
        # Release the pooled connections owned by this throw-away engine.
        engine.dispose()

    save_embeddings(results)
227+
228+
229+
def save_embeddings(ids: List[int]):
    """Recompute and store the embedding of each term in the given groups.

    Loads every ``Terminology`` row whose ``id`` or ``pid`` is in ``ids``,
    embeds the rows' ``word`` values with the model returned by
    ``EmbeddingModelCache.get_model()``, writes each vector to the row's
    ``embedding`` column and commits.

    Best-effort: any failure is logged and swallowed so callers never see an
    exception. Does nothing when ``settings.EMBEDDING_ENABLED`` is false or
    ``ids`` is empty.

    :param ids: ids of the top-level terms whose groups should be embedded.
    """
    if not settings.EMBEDDING_ENABLED:
        return

    if not ids:
        return

    engine = None
    try:
        engine = create_engine(str(settings.SQLALCHEMY_DATABASE_URI))
        # The context manager closes the session on success and on error.
        with sessionmaker(bind=engine)() as session:
            _list = session.query(Terminology).filter(
                or_(Terminology.id.in_(ids), Terminology.pid.in_(ids))).all()
            if not _list:
                # Nothing matched; skip the model call entirely.
                return

            _words_list = [item.word for item in _list]

            model = EmbeddingModelCache.get_model()
            results = model.embed_documents(_words_list)

            # Vectors are paired with their rows by position in the input list.
            for term, vector in zip(_list, results):
                stmt = update(Terminology).where(Terminology.id == term.id).values(embedding=vector)
                session.execute(stmt)
            session.commit()
    except Exception:
        logging.getLogger(__name__).exception("Failed to save terminology embeddings")
    finally:
        # Release the pooled connections owned by this throw-away engine.
        if engine is not None:
            engine.dispose()
256+
257+
258+
embedding_sql = f"""
259+
SELECT id, pid, word, description, similarity
260+
FROM
261+
(SELECT id, pid, word,
262+
COALESCE(
263+
description,
264+
(SELECT description FROM terminology AS parent WHERE parent.id = child.pid)
265+
) AS description,
266+
( 1 - (embedding <=> :embedding_array) ) AS similarity
267+
FROM terminology AS child
268+
) TEMP
269+
WHERE similarity > {settings.EMBEDDING_SIMILARITY}
270+
ORDER BY similarity DESC
271+
LIMIT {settings.EMBEDDING_TOP_COUNT}
272+
"""
273+
274+
275+
def select_terminology_by_word(session: SessionDep, word: str):
    """Find the terminology groups relevant to a sentence.

    Two lookups are combined:

    1. Substring match: every term whose ``word`` occurs (case-insensitively)
       inside ``word``.
    2. Embedding match, only when ``settings.EMBEDDING_ENABLED`` is true: the
       rows returned by ``embedding_sql`` for the embedded sentence. This
       step is best-effort; a failure is logged and the substring matches are
       still returned.

    Rows are de-duplicated by id and grouped under their parent id (or their
    own id for top-level terms).

    :param session: database session used for both lookups.
    :param word: the sentence to search terms in.
    :return: a list of ``{'words': [...], 'description': ...}`` dicts, one per
        group; an empty list when ``word`` is empty or blank.
    """
    if not word or not word.strip():
        return []

    matches: List[Terminology] = []

    # The parent lookup needs its own alias of the table. Comparing
    # Terminology.id with Terminology.pid on the same un-aliased table makes
    # the subquery test each row against itself rather than against the outer
    # row, so child terms would never inherit their parent's description.
    parent = aliased(Terminology)
    stmt = (
        select(
            Terminology.id,
            Terminology.pid,
            Terminology.word,
            func.coalesce(
                Terminology.description,
                select(parent.description)
                .where(parent.id == Terminology.pid)
                .scalar_subquery()
            ).label('description')
        )
        .where(
            text(":sentence ILIKE '%' || word || '%'")
        )
    )

    for row in session.execute(stmt, {'sentence': word}).fetchall():
        matches.append(Terminology(id=row.id, word=row.word, pid=row.pid, description=row.description))

    if settings.EMBEDDING_ENABLED:
        try:
            model = EmbeddingModelCache.get_model()
            embedding = model.embed_query(word)

            # Run the vector query inside a SAVEPOINT: if it fails, only the
            # savepoint is rolled back and the caller's session stays usable
            # instead of being left in an aborted transaction.
            with session.begin_nested():
                rows = session.execute(text(embedding_sql), {'embedding_array': str(embedding)}).fetchall()

            for row in rows:
                matches.append(Terminology(id=row.id, word=row.word, pid=row.pid, description=row.description))

        except Exception:
            logging.getLogger(__name__).exception("Embedding lookup for terminology failed")

    # Group the distinct rows by parent id; the first row seen for a group
    # decides the group's description.
    grouped: dict = {}
    seen_ids: set[int] = set()
    for row in matches:
        if row.id in seen_ids:
            continue
        seen_ids.add(row.id)
        group_key = str(row.pid) if row.pid else str(row.id)
        group = grouped.setdefault(group_key, {'words': [], 'description': row.description})
        group['words'].append(row.word)

    return list(grouped.values())
337+
338+
339+
def get_example():
    """Return a sample terminology entry rendered as XML under an ``example`` root."""
    sample_term = {
        'words': ['GDP', '国内生产总值'],
        'description': '指在一个季度或一年,一个国家或地区的经济中所生产出的全部最终产品和劳务的价值。',
    }
    return to_xml_string({'terminologies': [sample_term]}, 'example')
347+
348+
349+
def to_xml_string(_dict: list[dict] | dict, root: str = 'terminologies') -> str:
    """Render ``_dict`` as pretty-printed XML without the XML declaration.

    List items are named after their parent element: children of
    ``terminologies`` become ``terminology``, children of ``words`` become
    ``word``, and anything else becomes ``item``.

    :param _dict: the data to serialise.
    :param root: name of the root element.
    :return: the indented XML text.
    """

    def _item_tag(parent_name):
        # Map a parent element name to the tag used for its list items.
        if parent_name == 'terminologies':
            return 'terminology'
        if parent_name == 'words':
            return 'word'
        return 'item'

    dicttoxml.LOG.setLevel(logging.ERROR)
    raw_bytes = dicttoxml.dicttoxml(
        _dict,
        custom_root=root,
        item_func=_item_tag,
        xml_declaration=False,
        encoding='utf-8',
        attr_type=False,
    )
    document = parseString(raw_bytes.decode('utf-8'))
    rendered = document.toprettyxml()

    # Drop the "<?xml ...>" declaration if the pretty-printer emitted one.
    if rendered.startswith('<?xml'):
        declaration_end = rendered.find('>') + 1
        rendered = rendered[declaration_end:].lstrip()

    return rendered
365+
366+
367+
def get_terminology_template(session: SessionDep, question: str) -> str:
    """Build the terminology prompt fragment for ``question``.

    Looks up the terminology groups matching the question, renders them as
    XML and inserts the XML into the base terminology template.

    :return: the filled-in template, or an empty string when nothing matches.
    """
    matched_groups = select_terminology_by_word(session, question)
    if not matched_groups:
        return ''

    terminology_xml = to_xml_string(matched_groups)
    return get_base_terminology_template().format(terminologies=terminology_xml)

backend/common/core/config.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,9 +88,9 @@ def SQLALCHEMY_DATABASE_URI(self) -> PostgresDsn | str:
8888

8989
LOCAL_MODEL_PATH: str = '/opt/sqlbot/models'
9090
DEFAULT_EMBEDDING_MODEL: str = 'shibing624/text2vec-base-chinese'
91-
92-
EMBEDDING_SIMILARITY: float = 0.6
93-
EMBEDDING_TOP_COUNT: int = 3
91+
EMBEDDING_ENABLED: bool = True
92+
EMBEDDING_SIMILARITY: float = 0.4
93+
EMBEDDING_TOP_COUNT: int = 5
9494

9595

9696
settings = Settings() # type: ignore

0 commit comments

Comments
 (0)