Skip to content

Commit 57ee918

Browse files
committed
feat(X-Pack): add export Custom Prompt
1 parent 07f1337 commit 57ee918

File tree

14 files changed

+226
-181
lines changed

14 files changed

+226
-181
lines changed

backend/apps/chat/api/chat.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from apps.chat.models.chat_model import CreateChat, ChatRecord, RenameChat, ChatQuestion, AxisObj
1515
from apps.chat.task.llm import LLMService
1616
from common.core.deps import CurrentAssistant, SessionDep, CurrentUser, Trans
17+
from common.utils.data_format import DataFormat
1718

1819
router = APIRouter(tags=["Data Q&A"], prefix="/chat")
1920

@@ -245,9 +246,9 @@ async def export_excel(session: SessionDep, chat_record_id: int, trans: Trans):
245246

246247
def inner():
247248

248-
data_list = LLMService.convert_large_numbers_in_object_array(_data + _predict_data)
249+
data_list = DataFormat.convert_large_numbers_in_object_array(_data + _predict_data)
249250

250-
md_data, _fields_list = LLMService.convert_object_array_for_pandas(fields, data_list)
251+
md_data, _fields_list = DataFormat.convert_object_array_for_pandas(fields, data_list)
251252

252253
# data, _fields_list, col_formats = LLMService.format_pd_data(fields, _data + _predict_data)
253254

backend/apps/chat/task/llm.py

Lines changed: 6 additions & 118 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
from common.core.db import engine
4646
from common.core.deps import CurrentAssistant, CurrentUser
4747
from common.error import SingleMessageError, SQLBotDBError, ParseSQLResultError, SQLBotDBConnectionError
48+
from common.utils.data_format import DataFormat
4849
from common.utils.utils import SQLBotLogUtil, extract_nested_json, prepare_for_orjson
4950

5051
warnings.filterwarnings("ignore")
@@ -1039,7 +1040,7 @@ def run_task(self, in_chat: bool = True, stream: bool = True,
10391040

10401041
result = self.execute_sql(sql=real_execute_sql)
10411042

1042-
_data = self.convert_large_numbers_in_object_array(result.get('data'))
1043+
_data = DataFormat.convert_large_numbers_in_object_array(result.get('data'))
10431044
result["data"] = _data
10441045

10451046
self.save_sql_data(session=_session, data_obj=result)
@@ -1057,15 +1058,15 @@ def run_task(self, in_chat: bool = True, stream: bool = True,
10571058
for field in result.get('fields'):
10581059
_column_list.append(AxisObj(name=field, value=field))
10591060

1060-
md_data, _fields_list = self.convert_object_array_for_pandas(_column_list, result.get('data'))
1061+
md_data, _fields_list = DataFormat.convert_object_array_for_pandas(_column_list, result.get('data'))
10611062

10621063
# data, _fields_list, col_formats = self.format_pd_data(_column_list, result.get('data'))
10631064

10641065
if not _data or not _fields_list:
10651066
yield 'The SQL execution result is empty.\n\n'
10661067
else:
10671068
df = pd.DataFrame(_data, columns=_fields_list)
1068-
df_safe = self.safe_convert_to_string(df)
1069+
df_safe = DataFormat.safe_convert_to_string(df)
10691070
markdown_table = df_safe.to_markdown(index=False)
10701071
yield markdown_table + '\n\n'
10711072
else:
@@ -1115,15 +1116,15 @@ def run_task(self, in_chat: bool = True, stream: bool = True,
11151116
_column_list.append(
11161117
AxisObj(name=field if not _fields.get(field) else _fields.get(field), value=field))
11171118

1118-
md_data, _fields_list = self.convert_object_array_for_pandas(_column_list, result.get('data'))
1119+
md_data, _fields_list = DataFormat.convert_object_array_for_pandas(_column_list, result.get('data'))
11191120

11201121
# data, _fields_list, col_formats = self.format_pd_data(_column_list, result.get('data'))
11211122

11221123
if not md_data or not _fields_list:
11231124
yield 'The SQL execution result is empty.\n\n'
11241125
else:
11251126
df = pd.DataFrame(md_data, columns=_fields_list)
1126-
df_safe = self.safe_convert_to_string(df)
1127+
df_safe = DataFormat.safe_convert_to_string(df)
11271128
markdown_table = df_safe.to_markdown(index=False)
11281129
yield markdown_table + '\n\n'
11291130

@@ -1176,120 +1177,7 @@ def run_task(self, in_chat: bool = True, stream: bool = True,
11761177
self.finish(_session)
11771178
session_maker.remove()
11781179

1179-
@staticmethod
1180-
def safe_convert_to_string(df):
1181-
df_copy = df.copy()
1182-
1183-
def format_value(x):
1184-
if pd.isna(x):
1185-
return ""
1186-
1187-
return "\u200b" + str(x)
1188-
1189-
for col in df_copy.columns:
1190-
df_copy[col] = df_copy[col].apply(format_value)
1191-
1192-
return df_copy
1193-
1194-
@staticmethod
1195-
def convert_large_numbers_in_object_array(obj_array, int_threshold=1e15, float_threshold=1e10):
1196-
"""处理对象数组,将每个对象中的大数字转换为字符串"""
1197-
1198-
def format_float_without_scientific(value):
1199-
"""格式化浮点数,避免科学记数法"""
1200-
if value == 0:
1201-
return "0"
1202-
formatted = f"{value:.15f}"
1203-
if '.' in formatted:
1204-
formatted = formatted.rstrip('0').rstrip('.')
1205-
return formatted
1206-
1207-
def process_object(obj):
1208-
"""处理单个对象"""
1209-
if not isinstance(obj, dict):
1210-
return obj
1211-
1212-
processed_obj = {}
1213-
for key, value in obj.items():
1214-
if isinstance(value, (int, float)):
1215-
# 只转换大数字
1216-
if isinstance(value, int) and abs(value) >= int_threshold:
1217-
processed_obj[key] = str(value)
1218-
elif isinstance(value, float) and (abs(value) >= float_threshold or abs(value) < 1e-6):
1219-
processed_obj[key] = format_float_without_scientific(value)
1220-
else:
1221-
processed_obj[key] = value
1222-
elif isinstance(value, dict):
1223-
# 处理嵌套对象
1224-
processed_obj[key] = process_object(value)
1225-
elif isinstance(value, list):
1226-
# 处理对象中的数组
1227-
processed_obj[key] = [process_item(item) for item in value]
1228-
else:
1229-
processed_obj[key] = value
1230-
return processed_obj
1231-
1232-
def process_item(item):
1233-
"""处理数组中的项目"""
1234-
if isinstance(item, dict):
1235-
return process_object(item)
1236-
return item
1237-
1238-
return [process_item(obj) for obj in obj_array]
12391180

1240-
@staticmethod
1241-
def convert_object_array_for_pandas(column_list: list, data_list: list):
1242-
_fields_list = []
1243-
for field_idx, field in enumerate(column_list):
1244-
_fields_list.append(field.name)
1245-
1246-
md_data = []
1247-
for inner_data in data_list:
1248-
_row = []
1249-
for field_idx, field in enumerate(column_list):
1250-
value = inner_data.get(field.value)
1251-
_row.append(value)
1252-
md_data.append(_row)
1253-
return md_data, _fields_list
1254-
1255-
@staticmethod
1256-
def format_pd_data(column_list: list, data_list: list, col_formats: dict = None):
1257-
# 预处理数据并记录每列的格式类型
1258-
# 格式类型:'text'(文本)、'number'(数字)、'default'(默认)
1259-
_fields_list = []
1260-
1261-
if col_formats is None:
1262-
col_formats = {}
1263-
for field_idx, field in enumerate(column_list):
1264-
_fields_list.append(field.name)
1265-
col_formats[field_idx] = 'default' # 默认不特殊处理
1266-
1267-
data = []
1268-
1269-
for _data in data_list:
1270-
_row = []
1271-
for field_idx, field in enumerate(column_list):
1272-
value = _data.get(field.value)
1273-
if value is not None:
1274-
# 检查是否为数字且需要特殊处理
1275-
if isinstance(value, (int, float)):
1276-
# 整数且超过15位 → 转字符串并标记为文本列
1277-
if isinstance(value, int) and len(str(abs(value))) > 15:
1278-
value = str(value)
1279-
col_formats[field_idx] = 'text'
1280-
# 小数且超过15位有效数字 → 转字符串并标记为文本列
1281-
elif isinstance(value, float):
1282-
decimal_str = format(value, '.16f').rstrip('0').rstrip('.')
1283-
if len(decimal_str) > 15:
1284-
value = str(value)
1285-
col_formats[field_idx] = 'text'
1286-
# 其他数字列标记为数字格式(避免科学记数法)
1287-
elif col_formats[field_idx] != 'text':
1288-
col_formats[field_idx] = 'number'
1289-
_row.append(value)
1290-
data.append(_row)
1291-
1292-
return data, _fields_list, col_formats
12931181

12941182
def run_recommend_questions_task_async(self):
12951183
self.future = executor.submit(self.run_recommend_questions_task_cache)

backend/apps/data_training/api/data_training.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,11 @@
77
from fastapi.responses import StreamingResponse
88

99
from apps.chat.models.chat_model import AxisObj
10-
from apps.chat.task.llm import LLMService
1110
from apps.data_training.curd.data_training import page_data_training, create_training, update_training, delete_training, \
1211
enable_training, get_all_data_training
1312
from apps.data_training.models.data_training_model import DataTrainingInfo
1413
from common.core.deps import SessionDep, CurrentUser, Trans
14+
from common.utils.data_format import DataFormat
1515

1616
router = APIRouter(tags=["DataTraining"], prefix="/system/data-training")
1717

@@ -53,9 +53,9 @@ async def enable(session: SessionDep, id: int, enabled: bool, trans: Trans):
5353

5454
@router.get("/export")
5555
async def export_excel(session: SessionDep, trans: Trans, current_user: CurrentUser,
56-
word: Optional[str] = Query(None, description="搜索术语(可选)")):
56+
question: Optional[str] = Query(None, description="搜索术语(可选)")):
5757
def inner():
58-
_list = get_all_data_training(session, word, oid=current_user.oid)
58+
_list = get_all_data_training(session, question, oid=current_user.oid)
5959

6060
data_list = []
6161
for obj in _list:
@@ -75,7 +75,7 @@ def inner():
7575
fields.append(
7676
AxisObj(name=trans('i18n_data_training.advanced_application'), value='advanced_application_name'))
7777

78-
md_data, _fields_list = LLMService.convert_object_array_for_pandas(fields, data_list)
78+
md_data, _fields_list = DataFormat.convert_object_array_for_pandas(fields, data_list)
7979

8080
df = pd.DataFrame(md_data, columns=_fields_list)
8181

backend/apps/terminology/api/terminology.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,11 @@
77
from fastapi.responses import StreamingResponse
88

99
from apps.chat.models.chat_model import AxisObj
10-
from apps.chat.task.llm import LLMService
1110
from apps.terminology.curd.terminology import page_terminology, create_terminology, update_terminology, \
1211
delete_terminology, enable_terminology, get_all_terminology
1312
from apps.terminology.models.terminology_model import TerminologyInfo
1413
from common.core.deps import SessionDep, CurrentUser, Trans
14+
from common.utils.data_format import DataFormat
1515

1616
router = APIRouter(tags=["Terminology"], prefix="/system/terminology")
1717

@@ -62,8 +62,8 @@ def inner():
6262
"word": obj.word,
6363
"other_words": ', '.join(obj.other_words) if obj.other_words else '',
6464
"description": obj.description,
65-
"all_data_sources": 'Y' if obj.specific_ds else 'N',
66-
"datasource": ', '.join(obj.datasource_names) if obj.datasource_names else '',
65+
"all_data_sources": 'N' if obj.specific_ds else 'Y',
66+
"datasource": ', '.join(obj.datasource_names) if obj.datasource_names and obj.specific_ds else '',
6767
}
6868
data_list.append(_data)
6969

@@ -74,7 +74,7 @@ def inner():
7474
fields.append(AxisObj(name=trans('i18n_terminology.effective_data_sources'), value='datasource'))
7575
fields.append(AxisObj(name=trans('i18n_terminology.all_data_sources'), value='all_data_sources'))
7676

77-
md_data, _fields_list = LLMService.convert_object_array_for_pandas(fields, data_list)
77+
md_data, _fields_list = DataFormat.convert_object_array_for_pandas(fields, data_list)
7878

7979
df = pd.DataFrame(md_data, columns=_fields_list)
8080

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
import pandas as pd
2+
3+
class DataFormat:
4+
@staticmethod
5+
def safe_convert_to_string(df):
6+
df_copy = df.copy()
7+
8+
def format_value(x):
9+
if pd.isna(x):
10+
return ""
11+
12+
return "\u200b" + str(x)
13+
14+
for col in df_copy.columns:
15+
df_copy[col] = df_copy[col].apply(format_value)
16+
17+
return df_copy
18+
19+
@staticmethod
20+
def convert_large_numbers_in_object_array(obj_array, int_threshold=1e15, float_threshold=1e10):
21+
"""处理对象数组,将每个对象中的大数字转换为字符串"""
22+
23+
def format_float_without_scientific(value):
24+
"""格式化浮点数,避免科学记数法"""
25+
if value == 0:
26+
return "0"
27+
formatted = f"{value:.15f}"
28+
if '.' in formatted:
29+
formatted = formatted.rstrip('0').rstrip('.')
30+
return formatted
31+
32+
def process_object(obj):
33+
"""处理单个对象"""
34+
if not isinstance(obj, dict):
35+
return obj
36+
37+
processed_obj = {}
38+
for key, value in obj.items():
39+
if isinstance(value, (int, float)):
40+
# 只转换大数字
41+
if isinstance(value, int) and abs(value) >= int_threshold:
42+
processed_obj[key] = str(value)
43+
elif isinstance(value, float) and (abs(value) >= float_threshold or abs(value) < 1e-6):
44+
processed_obj[key] = format_float_without_scientific(value)
45+
else:
46+
processed_obj[key] = value
47+
elif isinstance(value, dict):
48+
# 处理嵌套对象
49+
processed_obj[key] = process_object(value)
50+
elif isinstance(value, list):
51+
# 处理对象中的数组
52+
processed_obj[key] = [process_item(item) for item in value]
53+
else:
54+
processed_obj[key] = value
55+
return processed_obj
56+
57+
def process_item(item):
58+
"""处理数组中的项目"""
59+
if isinstance(item, dict):
60+
return process_object(item)
61+
return item
62+
63+
return [process_item(obj) for obj in obj_array]
64+
65+
@staticmethod
66+
def convert_object_array_for_pandas(column_list: list, data_list: list):
67+
_fields_list = []
68+
for field_idx, field in enumerate(column_list):
69+
_fields_list.append(field.name)
70+
71+
md_data = []
72+
for inner_data in data_list:
73+
_row = []
74+
for field_idx, field in enumerate(column_list):
75+
value = inner_data.get(field.value)
76+
_row.append(value)
77+
md_data.append(_row)
78+
return md_data, _fields_list
79+
80+
@staticmethod
81+
def format_pd_data(column_list: list, data_list: list, col_formats: dict = None):
82+
# 预处理数据并记录每列的格式类型
83+
# 格式类型:'text'(文本)、'number'(数字)、'default'(默认)
84+
_fields_list = []
85+
86+
if col_formats is None:
87+
col_formats = {}
88+
for field_idx, field in enumerate(column_list):
89+
_fields_list.append(field.name)
90+
col_formats[field_idx] = 'default' # 默认不特殊处理
91+
92+
data = []
93+
94+
for _data in data_list:
95+
_row = []
96+
for field_idx, field in enumerate(column_list):
97+
value = _data.get(field.value)
98+
if value is not None:
99+
# 检查是否为数字且需要特殊处理
100+
if isinstance(value, (int, float)):
101+
# 整数且超过15位 → 转字符串并标记为文本列
102+
if isinstance(value, int) and len(str(abs(value))) > 15:
103+
value = str(value)
104+
col_formats[field_idx] = 'text'
105+
# 小数且超过15位有效数字 → 转字符串并标记为文本列
106+
elif isinstance(value, float):
107+
decimal_str = format(value, '.16f').rstrip('0').rstrip('.')
108+
if len(decimal_str) > 15:
109+
value = str(value)
110+
col_formats[field_idx] = 'text'
111+
# 其他数字列标记为数字格式(避免科学记数法)
112+
elif col_formats[field_idx] != 'text':
113+
col_formats[field_idx] = 'number'
114+
_row.append(value)
115+
data.append(_row)
116+
117+
return data, _fields_list, col_formats

backend/locales/en.json

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,11 @@
6262
},
6363
"i18n_custom_prompt": {
6464
"exists_in_db": "Template name already exists",
65-
"not_exists": "This template does not exist"
65+
"not_exists": "This template does not exist",
66+
"prompt_word_name": "Prompt word name",
67+
"prompt_word_content": "Prompt word content",
68+
"effective_data_sources": "Effective Data Sources",
69+
"all_data_sources": "All Data Sources"
6670
},
6771
"i18n_excel_export": {
6872
"data_is_empty": "Form data is empty, unable to export data"

0 commit comments

Comments
 (0)