Skip to content

Commit 3ecf2a0

Browse files
committed
feat: add SQL examples in prompt
1 parent a6c067a commit 3ecf2a0

File tree

2 files changed

+6
-2
lines changed

2 files changed

+6
-2
lines changed

backend/apps/chat/task/llm.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
get_chat_chart_data, list_generate_sql_logs, list_generate_chart_logs, start_log, end_log, \
3030
get_last_execute_sql_error
3131
from apps.chat.models.chat_model import ChatQuestion, ChatRecord, Chat, RenameChat, ChatLog, OperationEnum
32+
from apps.data_training.curd.data_training import get_training_template
3233
from apps.datasource.crud.datasource import get_table_schema
3334
from apps.datasource.crud.permission import get_row_permission_filters, is_normal_user
3435
from apps.datasource.models.datasource import CoreDatasource
@@ -935,6 +936,9 @@ def run_task(self, in_chat: bool = True):
935936
self.chat_question.terminologies = get_terminology_template(self.session, self.chat_question.question,
936937
self.ds.oid if isinstance(self.ds,
937938
CoreDatasource) else 1)
939+
self.chat_question.data_training = get_training_template(self.session, self.chat_question.question,
940+
self.ds.id, self.ds.oid)
941+
938942
self.init_messages()
939943

940944
# return id

backend/apps/data_training/curd/data_training.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ def save_embeddings(session: Session, ids: List[int]):
204204
"""
205205

206206

207-
def select_terminology_by_question(session: SessionDep, question: str, oid: int, datasource: int):
207+
def select_training_by_question(session: SessionDep, question: str, oid: int, datasource: int):
208208
if question.strip() == "":
209209
return []
210210

@@ -303,7 +303,7 @@ def get_training_template(session: SessionDep, question: str, datasource: int, o
303303
oid = 1
304304
if not datasource:
305305
return ''
306-
_results = select_terminology_by_question(session, question, oid, datasource)
306+
_results = select_training_by_question(session, question, oid, datasource)
307307
if _results and len(_results) > 0:
308308
data_training = to_xml_string(_results)
309309
template = get_base_data_training_template().format(data_training=data_training)

0 commit comments

Comments
 (0)