Skip to content

Commit eda163c

Browse files
feat: Execute with custom sql dataset
1 parent 16781d0 commit eda163c

File tree

5 files changed

+71
-35
lines changed

5 files changed

+71
-35
lines changed

backend/apps/chat/models/chat_model.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from apps.template.filter.generator import get_permissions_template
1010
from apps.template.generate_analysis.generator import get_analysis_template
1111
from apps.template.generate_chart.generator import get_chart_template
12+
from apps.template.generate_dynamic.generator import get_dynamic_template
1213
from apps.template.generate_guess_question.generator import get_guess_question_template
1314
from apps.template.generate_predict.generator import get_predict_template
1415
from apps.template.generate_sql.generator import get_sql_template
@@ -107,6 +108,7 @@ class AiModelQuestion(BaseModel):
107108
data: str = ""
108109
lang: str = "简体中文"
109110
filter: str = []
111+
sub_query: Optional[list[dict]] = None
110112

111113
def sql_sys_question(self):
112114
return get_sql_template()['system'].format(engine=self.engine, schema=self.db_schema, question=self.question, lang=self.lang)
@@ -151,7 +153,12 @@ def filter_sys_question(self):
151153

152154
def filter_user_question(self):
153155
return get_permissions_template()['user'].format(sql=self.sql, filter=self.filter)
154-
156+
157+
def dynamic_sys_question(self):
158+
return get_dynamic_template()['system'].format(lang=self.lang, engine=self.engine)
159+
160+
def dynamic_user_question(self):
161+
return get_dynamic_template()['user'].format(sql=self.sql, sub_query=self.sub_query)
155162

156163
class ChatQuestion(AiModelQuestion):
157164
question: str = Body(description='用户提问')

backend/apps/chat/task/llm.py

Lines changed: 54 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -516,6 +516,41 @@ def generate_sql(self):
516516
[{'type': msg.type, 'content': msg.content} for msg in
517517
self.sql_message]).decode())
518518

519+
def generate_with_sub_sql(self, sql, sub_mappings: list):
520+
sub_query = json.dumps(sub_mappings, ensure_ascii=False)
521+
self.chat_question.sql = sql
522+
self.chat_question.sub_query = sub_query
523+
msg: List[Union[BaseMessage, dict[str, Any]]] = []
524+
msg.append(SystemMessage(content=self.chat_question.dynamic_sys_question()))
525+
msg.append(HumanMessage(content=self.chat_question.dynamic_user_question()))
526+
full_thinking_text = ''
527+
full_dynamic_text = ''
528+
res = self.llm.stream(msg)
529+
token_usage = {}
530+
for chunk in res:
531+
SQLBotLogUtil.info(chunk)
532+
reasoning_content_chunk = ''
533+
if 'reasoning_content' in chunk.additional_kwargs:
534+
reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '')
535+
if reasoning_content_chunk is None:
536+
reasoning_content_chunk = ''
537+
full_thinking_text += reasoning_content_chunk
538+
full_dynamic_text += chunk.content
539+
get_token_usage(chunk, token_usage)
540+
541+
SQLBotLogUtil.info(full_dynamic_text)
542+
return full_dynamic_text
543+
544+
def generate_assistant_dynamic_sql(self, sql, tables: List):
545+
ds: AssistantOutDsSchema = self.ds
546+
sub_query = []
547+
for table in ds.tables:
548+
if table.name in tables and table.sql:
549+
sub_query.append({"table": table.name, "query": table.sql})
550+
if not sub_query:
551+
return None
552+
return self.generate_with_sub_sql(sql=sql, sub_mappings=sub_query)
553+
519554
def build_table_filter(self, sql: str, filters: list):
520555
filter = json.dumps(filters, ensure_ascii=False)
521556
self.chat_question.sql = sql
@@ -635,27 +670,23 @@ def generate_chart(self):
635670
full_message=orjson.dumps(
636671
[{'type': msg.type, 'content': msg.content} for msg in
637672
self.chart_message]).decode())
638-
639-
def check_save_sql(self, res: str) -> str:
640-
673+
def check_sql(self, res: str) -> tuple[any]:
641674
json_str = extract_nested_json(res)
642-
data = orjson.loads(json_str)
643-
675+
data: dict = orjson.loads(json_str)
644676
sql = ''
645677
message = ''
646-
error = False
647-
648678
if data['success']:
649679
sql = data['sql']
650680
else:
651681
message = data['message']
652-
error = True
653-
654-
if error:
655682
raise Exception(message)
683+
656684
if sql.strip() == '':
657685
raise Exception("SQL query is empty")
658-
686+
return sql, data.get('tables')
687+
688+
def check_save_sql(self, res: str) -> str:
689+
sql, *_ = self.check_sql(res=res)
659690
save_sql(session=self.session, sql=sql, record_id=self.record.id)
660691

661692
self.chat_question.sql = sql
@@ -816,34 +847,27 @@ def run_task(self, in_chat: bool = True):
816847

817848
# filter sql
818849
SQLBotLogUtil.info(full_sql_text)
850+
use_dynamic_ds: bool = self.current_assistant and self.current_assistant.type == 1
819851

820852
# todo row permission
821-
if (not self.current_assistant and is_normal_user(self.current_user)) or (
822-
self.current_assistant and self.current_assistant.type == 1):
823-
sql_json_str = extract_nested_json(full_sql_text)
824-
data = orjson.loads(sql_json_str)
825-
826-
sql = ''
827-
message = ''
828-
error = False
829-
if data['success']:
830-
sql = data['sql']
831-
else:
832-
message = data['message']
833-
error = True
834-
if error:
835-
raise Exception(message)
836-
if sql.strip() == '':
837-
raise Exception("SQL query is empty")
853+
if (not self.current_assistant and is_normal_user(self.current_user)) or use_dynamic_ds:
854+
sql, tables = self.check_sql(res=full_sql_text)
838855

839856
if self.current_assistant:
840-
sql_result = self.generate_assistant_filter(data.get('sql'), data.get('tables'))
857+
dynamic_sql_result = self.generate_assistant_dynamic_sql(sql, tables)
858+
if dynamic_sql_result:
859+
SQLBotLogUtil.info(dynamic_sql_result)
860+
sql, *_ = self.check_sql(res=dynamic_sql_result)
861+
862+
sql_result = self.generate_assistant_filter(sql, tables)
841863
else:
842-
sql_result = self.generate_filter(data.get('sql'), data.get('tables')) # maybe no sql and tables
864+
sql_result = self.generate_filter(sql, tables) # maybe no sql and tables
843865

844866
if sql_result:
845867
SQLBotLogUtil.info(sql_result)
846868
sql = self.check_save_sql(res=sql_result)
869+
elif dynamic_sql_result:
870+
sql = self.check_save_sql(res=dynamic_sql_result)
847871
else:
848872
sql = self.check_save_sql(res=full_sql_text)
849873
else:

backend/apps/template/generate_dynamic/__init__.py

Whitespace-only changes.
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from apps.template.template import get_base_template
2+
3+
4+
def get_dynamic_template():
5+
template = get_base_template()
6+
return template['template']['dynamic_sql']

backend/template.yaml

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -214,14 +214,13 @@ template:
214214
215215
### 说明:
216216
提供给你一句SQL和一组子查询映射表,你需要将给定的SQL查询中的表名替换为对应的子查询。请严格保持原始SQL的结构不变,只替换表引用部分,生成符合{engine}数据库引擎规范的新SQL语句。
217-
- 原始SQL(标记为`sql`)
218-
- 子查询映射表(标记为`sub_query`,格式为`[{"原表名": "子查询SQL"},...]`)
217+
- 子查询映射表标记为sub_query,格式为[{{"table":"表名","query":"子查询语句"}},...]
219218
你必须遵守以下规则:
220219
- 生成的SQL必须符合{engine}的规范。
221220
- 不要替换原来SQL中的过滤条件。
222221
- 完全匹配表名(注意大小写敏感)。
223222
- 根据子查询语句以及{engine}数据库引擎规范决定是否需要给子查询添加括号包围
224-
- 若子查询包含别名,保留原表名作为别名
223+
- 若原始SQL中原表名有别名则保留原有别名,否则保留原表名作为别名
225224
- 生成SQL时,必须避免关键字冲突。
226225
- 生成的SQL使用JSON格式返回:
227226
{{"success":true,"sql":"生成的SQL语句"}}
@@ -235,5 +234,5 @@ template:
235234
### sql:
236235
{sql}
237236
238-
### 过滤条件:
237+
### 子查询映射表:
239238
{sub_query}

0 commit comments

Comments
 (0)