Skip to content

Commit 0eee5e5

Browse files
committed
Merge branch 'main' of https://github.com/dataease/SQLBot
2 parents fc2761c + eda163c commit 0eee5e5

File tree

11 files changed

+140
-46
lines changed

11 files changed

+140
-46
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
<p align="center"><img src= "TBD" alt="SQLBot" width="300" /></p>
1+
<p align="center"><img src="https://resource-fit2cloud-com.oss-cn-hangzhou.aliyuncs.com/sqlbot/sqlbot.png" alt="SQLBot" width="300" /></p>
22
<h3 align="center">基于大模型和 RAG 的智能问数系统</h3>
33
<p align="center">
44
<a href="https://www.gnu.org/licenses/gpl-3.0.html#license-text"><img src="https://img.shields.io/github/license/1Panel-dev/SQLBot?color=%231890FF" alt="License: GPL v3"></a>
@@ -16,7 +16,7 @@ TBD
1616

1717
SQLBot 的优势包括:
1818

19-
- **开箱即用**: 只需配置大模型和数据源即可开启问数之旅,通过大模型和 RAG 的结合来实现高质量的 NL2SQL 和 Text2SQL
19+
- **开箱即用**: 只需配置大模型和数据源即可开启问数之旅,通过大模型和 RAG 的结合来实现高质量的 text2sql
2020
- **易于集成**: 支持快速嵌入到第三方业务系统,也支持被 n8n、MaxKB、Dify 、Coze 等 AI 应用开发平台集成调用,让各类应用快速拥有智能问数能力;
2121
- **安全可控**: 提供基于工作空间的资源隔离机制,能够实现细粒度的数据权限控制。
2222

backend/apps/chat/models/chat_model.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from apps.template.filter.generator import get_permissions_template
1010
from apps.template.generate_analysis.generator import get_analysis_template
1111
from apps.template.generate_chart.generator import get_chart_template
12+
from apps.template.generate_dynamic.generator import get_dynamic_template
1213
from apps.template.generate_guess_question.generator import get_guess_question_template
1314
from apps.template.generate_predict.generator import get_predict_template
1415
from apps.template.generate_sql.generator import get_sql_template
@@ -107,6 +108,7 @@ class AiModelQuestion(BaseModel):
107108
data: str = ""
108109
lang: str = "简体中文"
109110
filter: str = []
111+
sub_query: Optional[list[dict]] = None
110112

111113
def sql_sys_question(self):
112114
return get_sql_template()['system'].format(engine=self.engine, schema=self.db_schema, question=self.question, lang=self.lang)
@@ -151,7 +153,12 @@ def filter_sys_question(self):
151153

152154
def filter_user_question(self):
153155
return get_permissions_template()['user'].format(sql=self.sql, filter=self.filter)
154-
156+
157+
def dynamic_sys_question(self):
    """Render the system prompt of the ``dynamic_sql`` template.

    The template is filled with this question's reply language (``lang``)
    and database engine (``engine``).
    """
    system_template = get_dynamic_template()['system']
    return system_template.format(lang=self.lang, engine=self.engine)
159+
160+
def dynamic_user_question(self):
    """Render the user prompt of the ``dynamic_sql`` template.

    The template is filled with the SQL statement (``sql``) and the
    sub-query mapping (``sub_query``) currently stored on this question.
    """
    user_template = get_dynamic_template()['user']
    return user_template.format(sql=self.sql, sub_query=self.sub_query)
155162

156163
class ChatQuestion(AiModelQuestion):
157164
question: str = Body(description='用户提问')

backend/apps/chat/task/llm.py

Lines changed: 61 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
from apps.system.schemas.system_schema import AssistantOutDsSchema
3939
from common.core.config import settings
4040
from common.core.deps import CurrentAssistant, CurrentUser
41-
from common.utils.utils import SQLBotLogUtil, extract_nested_json
41+
from common.utils.utils import SQLBotLogUtil, extract_nested_json, prepare_for_orjson
4242

4343
warnings.filterwarnings("ignore")
4444

@@ -71,7 +71,7 @@ def __init__(self, current_user: CurrentUser, chat_question: ChatQuestion,
7171
engine = create_engine(str(settings.SQLALCHEMY_DATABASE_URI))
7272
session_maker = sessionmaker(bind=engine)
7373
self.session = session_maker()
74-
74+
self.session.exec = self.session.exec if hasattr(self.session, "exec") else self.session.execute
7575
self.current_user = current_user
7676
self.current_assistant = current_assistant
7777
# chat = self.session.query(Chat).filter(Chat.id == chat_question.chat_id).first()
@@ -365,7 +365,7 @@ def select_datasource(self):
365365
datasource_msg: List[Union[BaseMessage, dict[str, Any]]] = []
366366
datasource_msg.append(SystemMessage(self.chat_question.datasource_sys_question()))
367367
if self.current_assistant:
368-
_ds_list = get_assistant_ds(llm_service=self)
368+
_ds_list = get_assistant_ds(session=self.session, llm_service=self)
369369
else:
370370
oid: str = self.current_user.oid
371371
stmt = select(CoreDatasource.id, CoreDatasource.name, CoreDatasource.description).where(
@@ -516,6 +516,41 @@ def generate_sql(self):
516516
[{'type': msg.type, 'content': msg.content} for msg in
517517
self.sql_message]).decode())
518518

519+
def generate_with_sub_sql(self, sql, sub_mappings: list):
    """Ask the LLM to rewrite ``sql`` using table -> sub-query mappings.

    Args:
        sql: SQL statement handed to the ``dynamic_sql`` prompt.
        sub_mappings: JSON-serializable list of mappings; the caller in this
            file passes dicts shaped like ``{"table": ..., "query": ...}``.

    Returns:
        The concatenated ``content`` of every streamed chunk, unparsed.
    """
    # Serialize first so a non-serializable mapping fails before any state
    # on chat_question is touched.
    serialized_mappings = json.dumps(sub_mappings, ensure_ascii=False)
    self.chat_question.sql = sql
    self.chat_question.sub_query = serialized_mappings

    prompt: List[Union[BaseMessage, dict[str, Any]]] = [
        SystemMessage(content=self.chat_question.dynamic_sys_question()),
        HumanMessage(content=self.chat_question.dynamic_user_question()),
    ]

    thinking_text = ''
    answer_text = ''
    stream = self.llm.stream(prompt)
    usage = {}
    for piece in stream:
        SQLBotLogUtil.info(piece)
        # A chunk may carry no reasoning content, or an explicit None.
        reasoning = piece.additional_kwargs.get('reasoning_content', '')
        if reasoning is None:
            reasoning = ''
        thinking_text += reasoning
        answer_text += piece.content
        get_token_usage(piece, usage)

    SQLBotLogUtil.info(answer_text)
    return answer_text
543+
544+
def generate_assistant_dynamic_sql(self, sql, tables: List):
    """Rewrite ``sql`` with the sub-queries configured on the assistant datasource.

    Args:
        sql: SQL statement produced by the previous generation step.
        tables: Names of the tables used by ``sql``. May be ``None`` or empty
            when the model reply carried no ``tables`` entry.

    Returns:
        The raw LLM reply from ``generate_with_sub_sql``, or ``None`` when no
        referenced table has a sub-query and the SQL needs no rewriting.
    """
    # ``tables`` comes from ``data.get('tables')`` and can therefore be None;
    # a membership test against None would raise TypeError.
    if not tables:
        return None

    ds: AssistantOutDsSchema = self.ds
    sub_query = [
        {"table": table.name, "query": table.sql}
        for table in ds.tables
        if table.name in tables and table.sql
    ]
    if not sub_query:
        return None
    return self.generate_with_sub_sql(sql=sql, sub_mappings=sub_query)
553+
519554
def build_table_filter(self, sql: str, filters: list):
520555
filter = json.dumps(filters, ensure_ascii=False)
521556
self.chat_question.sql = sql
@@ -635,27 +670,23 @@ def generate_chart(self):
635670
full_message=orjson.dumps(
636671
[{'type': msg.type, 'content': msg.content} for msg in
637672
self.chart_message]).decode())
638-
639-
def check_save_sql(self, res: str) -> str:
640-
673+
def check_sql(self, res: str) -> tuple[str, Any]:
    """Parse the model's JSON reply and return the SQL it contains.

    Args:
        res: Raw model reply from which a JSON object is extracted.

    Returns:
        A ``(sql, tables)`` pair, where ``tables`` is the reply's ``tables``
        entry or ``None`` when the reply has none.

    Raises:
        Exception: With the reply's ``message`` when ``success`` is falsy, or
            with ``"SQL query is empty"`` when the SQL is blank.
    """
    json_str = extract_nested_json(res)
    data: dict = orjson.loads(json_str)

    if not data['success']:
        raise Exception(data['message'])

    sql = data['sql']
    if sql.strip() == '':
        raise Exception("SQL query is empty")
    return sql, data.get('tables')
687+
688+
def check_save_sql(self, res: str) -> str:
689+
sql, *_ = self.check_sql(res=res)
659690
save_sql(session=self.session, sql=sql, record_id=self.record.id)
660691

661692
self.chat_question.sql = sql
@@ -716,6 +747,10 @@ def save_error(self, message: str):
716747
return save_error_message(session=self.session, record_id=self.record.id, message=message)
717748

718749
def save_sql_data(self, data_obj: Dict[str, Any]):
750+
data_result = data_obj.get('data')
751+
if data_result:
752+
data_result = prepare_for_orjson(data_result)
753+
data_obj['data'] = data_result
719754
return save_sql_exec_data(session=self.session, record_id=self.record.id,
720755
data=orjson.dumps(data_obj).decode())
721756

@@ -812,34 +847,27 @@ def run_task(self, in_chat: bool = True):
812847

813848
# filter sql
814849
SQLBotLogUtil.info(full_sql_text)
850+
use_dynamic_ds: bool = self.current_assistant and self.current_assistant.type == 1
815851

816852
# todo row permission
817-
if (not self.current_assistant and is_normal_user(self.current_user)) or (
818-
self.current_assistant and self.current_assistant.type == 1):
819-
sql_json_str = extract_nested_json(full_sql_text)
820-
data = orjson.loads(sql_json_str)
821-
822-
sql = ''
823-
message = ''
824-
error = False
825-
if data['success']:
826-
sql = data['sql']
827-
else:
828-
message = data['message']
829-
error = True
830-
if error:
831-
raise Exception(message)
832-
if sql.strip() == '':
833-
raise Exception("SQL query is empty")
853+
if (not self.current_assistant and is_normal_user(self.current_user)) or use_dynamic_ds:
854+
sql, tables = self.check_sql(res=full_sql_text)
834855

835856
if self.current_assistant:
836-
sql_result = self.generate_assistant_filter(data.get('sql'), data.get('tables'))
857+
dynamic_sql_result = self.generate_assistant_dynamic_sql(sql, tables)
858+
if dynamic_sql_result:
859+
SQLBotLogUtil.info(dynamic_sql_result)
860+
sql, *_ = self.check_sql(res=dynamic_sql_result)
861+
862+
sql_result = self.generate_assistant_filter(sql, tables)
837863
else:
838-
sql_result = self.generate_filter(data.get('sql'), data.get('tables')) # maybe no sql and tables
864+
sql_result = self.generate_filter(sql, tables) # maybe no sql and tables
839865

840866
if sql_result:
841867
SQLBotLogUtil.info(sql_result)
842868
sql = self.check_save_sql(res=sql_result)
869+
elif dynamic_sql_result:
870+
sql = self.check_save_sql(res=dynamic_sql_result)
843871
else:
844872
sql = self.check_save_sql(res=full_sql_text)
845873
else:

backend/apps/system/api/assistant.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,20 +15,25 @@
1515
router = APIRouter(tags=["system/assistant"], prefix="/system/assistant")
1616

1717
@router.get("/info/{id}")
async def info(request: Request, response: Response, session: SessionDep, id: int) -> AssistantModel:
    """Return the assistant identified by ``id`` after checking the caller's origin.

    Sets ``Access-Control-Allow-Origin`` to the assistant's configured domain.
    The caller's origin is read from the ``Origin`` header, falling back to
    ``Referer``, and must equal that domain once trailing slashes are removed.

    Raises:
        Exception: When ``id`` is falsy.
        RuntimeError: When the assistant does not exist or the origin does
            not match the configured domain.
    """
    if not id:
        raise Exception('miss assistant id')
    db_model = await get_assistant_info(session=session, assistant_id=id)
    if not db_model:
        raise RuntimeError("assistant application not exist")
    db_model = AssistantModel.model_validate(db_model)
    response.headers["Access-Control-Allow-Origin"] = db_model.domain
    # Both headers may be missing; fall back to '' so the request is rejected
    # by the domain check below rather than failing on None.rstrip().
    origin = (request.headers.get("origin") or request.headers.get("referer") or '').rstrip('/')
    if origin != db_model.domain:
        # f-string so the offending origin is actually interpolated.
        raise RuntimeError(f"invalid domain [{origin}]")
    return db_model
2931

3032
@router.get("/validator", response_model=AssistantValidator)
31-
async def info(session: SessionDep, id: str, virtual: Optional[int] = Query(None), online: Optional[bool] = Query(default=False)):
33+
async def validator(session: SessionDep, id: int, virtual: Optional[int] = Query(None)):
34+
if not id:
35+
raise Exception('miss assistant id')
36+
3237
db_model = await get_assistant_info(session=session, assistant_id=id)
3338
if not db_model:
3439
return AssistantValidator()

backend/apps/system/crud/assistant.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,8 @@ def get_assistant_user(*, id: int):
2929
3030

3131

32-
def get_assistant_ds(llm_service) -> list[dict]:
32+
def get_assistant_ds(session: Session, llm_service) -> list[dict]:
3333
assistant: AssistantHeader = llm_service.current_assistant
34-
session: Session = llm_service.session
3534
type = assistant.type
3635
if type == 0:
3736
configuration = assistant.configuration

backend/apps/template/generate_dynamic/__init__.py

Whitespace-only changes.
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from apps.template.template import get_base_template
2+
3+
4+
def get_dynamic_template():
    """Return the ``dynamic_sql`` section of the base prompt template."""
    return get_base_template()['template']['dynamic_sql']

backend/common/utils/utils.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import base64
12
import hashlib
23
import inspect
34
import logging
@@ -215,6 +216,20 @@ def critical(msg: str, *args, **kwargs):
215216
logger = SQLBotLogUtil._get_logger()
216217
if logger.isEnabledFor(logging.CRITICAL):
217218
logger._log(logging.CRITICAL, msg, args, **kwargs)
219+
220+
def prepare_for_orjson(data):
    """Recursively replace binary values in ``data`` with base64 text.

    orjson does not serialize raw binary values, so every ``bytes``,
    ``bytearray`` or ``memoryview`` found in ``data`` (at any depth inside
    dicts, lists and tuples) is replaced by its base64 encoding as ``str``.

    Args:
        data: Any value; containers are walked recursively.

    Returns:
        ``data`` with binary values replaced. Non-empty lists and tuples are
        returned as new lists, non-empty dicts as new dicts; every other
        value, including empty containers, is returned unchanged.
    """
    # Binary values must be handled before the truthiness shortcut: b'' is
    # falsy but still cannot be serialized.
    if isinstance(data, (bytes, bytearray, memoryview)):
        return base64.b64encode(data).decode('utf-8')
    if not data:
        return data
    if isinstance(data, dict):
        return {k: prepare_for_orjson(v) for k, v in data.items()}
    if isinstance(data, (list, tuple)):
        return [prepare_for_orjson(item) for item in data]
    return data
231+
232+
218233

219234

220235

backend/template.yaml

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,3 +208,31 @@ template:
208208
209209
### 过滤条件:
210210
{filter}
211+
dynamic_sql:
212+
system: |
213+
### 请使用语言:{lang} 回答
214+
215+
### 说明:
216+
提供给你一句SQL和一组子查询映射表,你需要将给定的SQL查询中的表名替换为对应的子查询。请严格保持原始SQL的结构不变,只替换表引用部分,生成符合{engine}数据库引擎规范的新SQL语句。
217+
- 子查询映射表标记为sub_query,格式为[{{"table":"表名","query":"子查询语句"}},...]
218+
你必须遵守以下规则:
219+
- 生成的SQL必须符合{engine}的规范。
220+
- 不要替换原来SQL中的过滤条件。
221+
- 完全匹配表名(注意大小写敏感)。
222+
- 根据子查询语句以及{engine}数据库引擎规范决定是否需要给子查询添加括号包围
223+
- 若原始SQL中原表名有别名则保留原有别名,否则保留原表名作为别名
224+
- 生成SQL时,必须避免关键字冲突。
225+
- 生成的SQL使用JSON格式返回:
226+
{{"success":true,"sql":"生成的SQL语句"}}
227+
- 如果不能生成SQL,回答:
228+
{{"success":false,"message":"无法生成SQL的原因"}}
229+
230+
### 响应, 请直接返回JSON结果:
231+
```json
232+
233+
user: |
234+
### sql:
235+
{sql}
236+
237+
### 子查询映射表:
238+
{sub_query}

frontend/public/assistant.js

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
; (function () {
1+
;(function () {
22
window.sqlbot_assistant_handler = window.sqlbot_assistant_handler || {}
33
const defaultData = {
44
id: '1',
@@ -461,8 +461,10 @@
461461
let tempData = Object.assign(defaultData, { id, domain_url })
462462
if (config_json) {
463463
const config = JSON.parse(config_json)
464-
Object.assign(tempData, config)
465-
tempData = Object.assign(tempData, config)
464+
if (config) {
465+
delete config.id
466+
tempData = Object.assign(tempData, config)
467+
}
466468
}
467469
tempData['online'] = online && online.toString().toLowerCase() == 'true'
468470
initsqlbot_assistant(tempData)

0 commit comments

Comments
 (0)