|
10 | 10 |
|
11 | 11 | from apps.chat.curd.chat import list_chats, get_chat_with_records, create_chat, rename_chat, \ |
12 | 12 | delete_chat, get_chat_chart_data, get_chat_predict_data, get_chat_with_records_with_data, get_chat_record_by_id, \ |
13 | | - format_json_data, format_json_list_data |
14 | | -from apps.chat.models.chat_model import CreateChat, ChatRecord, RenameChat, ChatQuestion, ExcelData |
| 13 | + format_json_data, format_json_list_data, get_chart_config |
| 14 | +from apps.chat.models.chat_model import CreateChat, ChatRecord, RenameChat, ChatQuestion, AxisObj |
15 | 15 | from apps.chat.task.llm import LLMService |
16 | 16 | from common.core.deps import CurrentAssistant, SessionDep, CurrentUser, Trans |
17 | 17 |
|
@@ -42,19 +42,19 @@ def inner(): |
42 | 42 | return await asyncio.to_thread(inner) |
43 | 43 |
|
44 | 44 |
|
45 | | -@router.get("/record/{chart_record_id}/data") |
46 | | -async def chat_record_data(session: SessionDep, chart_record_id: int): |
| 45 | +@router.get("/record/{chat_record_id}/data") |
| 46 | +async def chat_record_data(session: SessionDep, chat_record_id: int): |
47 | 47 | def inner(): |
48 | | - data = get_chat_chart_data(chart_record_id=chart_record_id, session=session) |
| 48 | + data = get_chat_chart_data(chat_record_id=chat_record_id, session=session) |
49 | 49 | return format_json_data(data) |
50 | 50 |
|
51 | 51 | return await asyncio.to_thread(inner) |
52 | 52 |
|
53 | 53 |
|
54 | | -@router.get("/record/{chart_record_id}/predict_data") |
55 | | -async def chat_predict_data(session: SessionDep, chart_record_id: int): |
| 54 | +@router.get("/record/{chat_record_id}/predict_data") |
| 55 | +async def chat_predict_data(session: SessionDep, chat_record_id: int): |
56 | 56 | def inner(): |
57 | | - data = get_chat_predict_data(chart_record_id=chart_record_id, session=session) |
| 57 | + data = get_chat_predict_data(chat_record_id=chat_record_id, session=session) |
58 | 58 | return format_json_list_data(data) |
59 | 59 |
|
60 | 60 | return await asyncio.to_thread(inner) |
@@ -203,17 +203,49 @@ def _err(_e: Exception): |
203 | 203 | return StreamingResponse(llm_service.await_result(), media_type="text/event-stream") |
204 | 204 |
|
205 | 205 |
|
206 | | -@router.post("/excel/export") |
207 | | -async def export_excel(excel_data: ExcelData, trans: Trans): |
208 | | - def inner(): |
| 206 | +@router.get("/record/{chat_record_id}/excel/export") |
| 207 | +async def export_excel(session: SessionDep, chat_record_id: int, trans: Trans): |
| 208 | + chat_record = session.get(ChatRecord, chat_record_id) |
| 209 | + if not chat_record: |
| 210 | + raise HTTPException( |
| 211 | + status_code=500, |
| 212 | + detail=f"ChatRecord with id {chat_record_id} not found" |
| 213 | + ) |
| 214 | + |
| 215 | + is_predict_data = chat_record.predict_record_id is not None |
| 216 | + |
| 217 | + _origin_data = format_json_data(get_chat_chart_data(chat_record_id=chat_record_id, session=session)) |
| 218 | + |
| 219 | + _base_field = _origin_data.get('fields') |
| 220 | + _data = _origin_data.get('data') |
| 221 | + |
| 222 | + if not _data: |
| 223 | + raise HTTPException( |
| 224 | + status_code=500, |
| 225 | + detail=trans("i18n_excel_export.data_is_empty") |
| 226 | + ) |
| 227 | + |
| 228 | + chart_info = get_chart_config(session, chat_record_id) |
209 | 229 |
|
210 | | - if not excel_data.data: |
211 | | - raise HTTPException( |
212 | | - status_code=500, |
213 | | - detail=trans("i18n_excel_export.data_is_empty") |
214 | | - ) |
| 230 | + _title = chart_info.get('title') if chart_info.get('title') else 'Excel' |
| 231 | + |
| 232 | + fields = [] |
| 233 | + if chart_info.get('columns') and len(chart_info.get('columns')) > 0: |
| 234 | + for column in chart_info.get('columns'): |
| 235 | + fields.append(AxisObj(name=column.get('name'), value=column.get('value'))) |
| 236 | + if chart_info.get('axis'): |
| 237 | + for _type in ['x', 'y', 'series']: |
| 238 | + if chart_info.get('axis').get(_type): |
| 239 | + column = chart_info.get('axis').get(_type) |
| 240 | + fields.append(AxisObj(name=column.get('name'), value=column.get('value'))) |
| 241 | + |
| 242 | + _predict_data = [] |
| 243 | + if is_predict_data: |
| 244 | + _predict_data = format_json_list_data(get_chat_predict_data(chat_record_id=chat_record_id, session=session)) |
| 245 | + |
| 246 | + def inner(): |
215 | 247 |
|
216 | | - data, _fields_list, col_formats = LLMService.format_pd_data(excel_data.axis, excel_data.data) |
| 248 | + data, _fields_list, col_formats = LLMService.format_pd_data(fields, _data + _predict_data) |
217 | 249 |
|
218 | 250 | df = pd.DataFrame(data, columns=_fields_list) |
219 | 251 |
|
|
0 commit comments