Skip to content

Commit 291b7a3

Browse files
committed
feat: predict data
1 parent 2d7dc6d commit 291b7a3

File tree

10 files changed

+315
-75
lines changed

10 files changed

+315
-75
lines changed

backend/apps/chat/api/chat.py

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -169,8 +169,14 @@ def run_task():
169169
return StreamingResponse(run_task(), media_type="text/event-stream")
170170

171171

172-
@router.post("/record/{chart_record_id}/analysis")
173-
async def analysis(session: SessionDep, current_user: CurrentUser, chart_record_id: int):
172+
@router.post("/record/{chart_record_id}/{action_type}")
173+
async def analysis_or_predict(session: SessionDep, current_user: CurrentUser, chart_record_id: int, action_type: str):
174+
if action_type != 'analysis' and action_type != 'predict':
175+
raise HTTPException(
176+
status_code=404,
177+
detail="Not Found"
178+
)
179+
174180
record = session.query(ChatRecord).get(chart_record_id)
175181
if not record:
176182
raise HTTPException(
@@ -217,13 +223,29 @@ async def analysis(session: SessionDep, current_user: CurrentUser, chart_record_
217223

218224
def run_task():
219225
try:
220-
# generate analysis
221-
analysis_res = llm_service.generate_analysis(session=session)
222-
for chunk in analysis_res:
223-
yield orjson.dumps({'content': chunk, 'type': 'analysis-result'}).decode() + '\n\n'
224-
yield orjson.dumps({'type': 'info', 'msg': 'analysis generated'}).decode() + '\n\n'
226+
if action_type == 'analysis':
227+
# generate analysis
228+
analysis_res = llm_service.generate_analysis(session=session)
229+
for chunk in analysis_res:
230+
yield orjson.dumps({'content': chunk, 'type': 'analysis-result'}).decode() + '\n\n'
231+
yield orjson.dumps({'type': 'info', 'msg': 'analysis generated'}).decode() + '\n\n'
232+
233+
yield orjson.dumps({'type': 'analysis_finish'}).decode() + '\n\n'
234+
235+
elif action_type == 'predict':
236+
# generate predict
237+
analysis_res = llm_service.generate_predict(session=session)
238+
full_text = ''
239+
for chunk in analysis_res:
240+
yield orjson.dumps({'content': chunk, 'type': 'predict-result'}).decode() + '\n\n'
241+
full_text += chunk
242+
yield orjson.dumps({'type': 'info', 'msg': 'predict generated'}).decode() + '\n\n'
243+
244+
_data = llm_service.check_save_predict_data(session=session, res=full_text)
245+
yield orjson.dumps({'type': 'predict', 'content': _data}).decode() + '\n\n'
246+
247+
yield orjson.dumps({'type': 'predict_finish'}).decode() + '\n\n'
225248

226-
yield orjson.dumps({'type': 'analysis_finish'}).decode() + '\n\n'
227249

228250
except Exception as e:
229251
traceback.print_exc()

backend/apps/chat/curd/chat.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,8 @@ def get_chat_with_records(session: SessionDep, chart_id: int, current_user: Curr
5959
record_list = session.query(ChatRecord).options(
6060
load_only(ChatRecord.id, ChatRecord.chat_id, ChatRecord.create_time, ChatRecord.finish_time,
6161
ChatRecord.question, ChatRecord.sql_answer, ChatRecord.sql, ChatRecord.data,
62-
ChatRecord.chart_answer, ChatRecord.chart, ChatRecord.analysis, ChatRecord.predict, ChatRecord.finish,
63-
ChatRecord.error, ChatRecord.run_time)).filter(
62+
ChatRecord.chart_answer, ChatRecord.chart, ChatRecord.analysis, ChatRecord.predict,
63+
ChatRecord.predict_data, ChatRecord.finish, ChatRecord.error, ChatRecord.run_time)).filter(
6464
and_(Chat.create_by == current_user.id, ChatRecord.chat_id == chart_id)).order_by(ChatRecord.create_time).all()
6565

6666
chat_info.records = record_list
@@ -178,12 +178,13 @@ def save_full_analysis_message_and_answer(session: SessionDep, record_id: int, a
178178

179179

180180
def save_full_predict_message_and_answer(session: SessionDep, record_id: int, answer: str,
181-
full_message: str) -> ChatRecord:
181+
full_message: str, data: str) -> ChatRecord:
182182
if not record_id:
183183
raise Exception("Record id cannot be None")
184184
record = session.query(ChatRecord).filter(ChatRecord.id == record_id).first()
185185
record.full_predict_message = full_message
186186
record.predict = answer
187+
record.predict_data = data
187188

188189
result = ChatRecord(**record.model_dump())
189190

@@ -254,6 +255,23 @@ def save_chart(session: SessionDep, record_id: int, chart: str) -> ChatRecord:
254255
return result
255256

256257

258+
def save_predict_data(session: SessionDep, record_id: int, data: str) -> ChatRecord:
259+
if not record_id:
260+
raise Exception("Record id cannot be None")
261+
record = session.query(ChatRecord).filter(ChatRecord.id == record_id).first()
262+
record.predict_data = data
263+
264+
result = ChatRecord(**record.model_dump())
265+
266+
session.add(record)
267+
session.flush()
268+
session.refresh(record)
269+
270+
session.commit()
271+
272+
return result
273+
274+
257275
def save_error_message(session: SessionDep, record_id: int, message: str) -> ChatRecord:
258276
if not record_id:
259277
raise Exception("Record id cannot be None")

backend/apps/chat/models/chat_model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ class ChatRecord(SQLModel, table=True):
4040
chart: str = Field(sa_column=Column(Text, nullable=True))
4141
analysis: str = Field(sa_column=Column(Text, nullable=True))
4242
predict: str = Field(sa_column=Column(Text, nullable=True))
43+
predict_data: str = Field(sa_column=Column(Text, nullable=True))
4344
full_sql_message: str = Field(sa_column=Column(Text, nullable=True))
4445
full_chart_message: str = Field(sa_column=Column(Text, nullable=True))
4546
full_analysis_message: str = Field(sa_column=Column(Text, nullable=True))

backend/apps/chat/task/llm.py

Lines changed: 60 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from apps.ai_model.model_factory import LLMConfig, LLMFactory, get_llm_config
1111
from apps.chat.curd.chat import save_question, save_full_sql_message, save_full_sql_message_and_answer, save_sql, \
1212
save_error_message, save_sql_exec_data, save_full_chart_message, save_full_chart_message_and_answer, save_chart, \
13-
finish_record, save_full_analysis_message_and_answer
13+
finish_record, save_full_analysis_message_and_answer, save_full_predict_message_and_answer, save_predict_data
1414
from apps.chat.models.chat_model import ChatQuestion, ChatRecord
1515
from apps.datasource.models.datasource import CoreDatasource
1616
from apps.db.db import exec_sql
@@ -112,7 +112,7 @@ def get_record(self):
112112
def set_record(self, record: ChatRecord):
113113
self.record = record
114114

115-
def generate_analysis(self, session: SessionDep):
115+
def get_fields_from_chart(self):
116116
chart_info = orjson.loads(self.record.chart)
117117
fields = []
118118
if chart_info.get('columns') and len(chart_info.get('columns')) > 0:
@@ -129,6 +129,10 @@ def generate_analysis(self, session: SessionDep):
129129
if column.get('value') != column.get('name'):
130130
column_str = column_str + '(' + column.get('name') + ')'
131131
fields.append(column_str)
132+
return fields
133+
134+
def generate_analysis(self, session: SessionDep):
135+
fields = self.get_fields_from_chart()
132136

133137
self.chat_question.fields = orjson.dumps(fields).decode()
134138
self.chat_question.data = orjson.dumps(orjson.loads(self.record.data).get('data')).decode()
@@ -169,6 +173,49 @@ def generate_analysis(self, session: SessionDep):
169173
in
170174
analysis_msg]).decode())
171175

176+
def generate_predict(self, session: SessionDep):
177+
fields = self.get_fields_from_chart()
178+
179+
self.chat_question.fields = orjson.dumps(fields).decode()
180+
self.chat_question.data = orjson.dumps(orjson.loads(self.record.data).get('data')).decode()
181+
predict_msg: List[Union[BaseMessage, dict[str, Any]]] = []
182+
predict_msg.append(SystemMessage(content=self.chat_question.predict_sys_question()))
183+
predict_msg.append(HumanMessage(content=self.chat_question.predict_user_question()))
184+
185+
history_msg = []
186+
if self.record.full_predict_message and self.record.full_predict_message.strip() != '':
187+
history_msg = orjson.loads(self.record.full_predict_message)
188+
189+
self.record = save_full_predict_message_and_answer(session=session, record_id=self.record.id, answer='',
190+
data='',
191+
full_message=orjson.dumps(history_msg +
192+
[{'type': msg.type,
193+
'content': msg.content} for msg
194+
in
195+
predict_msg]).decode())
196+
197+
full_predict_text = ''
198+
res = self.llm.stream(predict_msg)
199+
for chunk in res:
200+
print(chunk)
201+
if isinstance(chunk, dict):
202+
full_predict_text += chunk['content']
203+
yield chunk['content']
204+
continue
205+
if isinstance(chunk, AIMessageChunk):
206+
full_predict_text += chunk.content
207+
yield chunk.content
208+
continue
209+
210+
predict_msg.append(AIMessage(full_predict_text))
211+
self.record = save_full_predict_message_and_answer(session=session, record_id=self.record.id,
212+
answer=full_predict_text, data='',
213+
full_message=orjson.dumps(history_msg +
214+
[{'type': msg.type,
215+
'content': msg.content} for msg
216+
in
217+
predict_msg]).decode())
218+
172219
def generate_sql(self, session: SessionDep):
173220
# append current question
174221
self.sql_message.append(HumanMessage(self.chat_question.sql_user_question()))
@@ -274,6 +321,17 @@ def check_save_chart(self, session: SessionDep, res: str) -> Dict[str, Any]:
274321

275322
return chart
276323

324+
def check_save_predict_data(self, session: SessionDep, res: str) -> Dict[str, Any]:
325+
326+
json_str = extract_nested_json(res)
327+
328+
if not json_str:
329+
json_str = ''
330+
331+
save_predict_data(session=session, record_id=self.record.id, data=json_str)
332+
333+
return json_str
334+
277335
def save_error(self, session: SessionDep, message: str):
278336
return save_error_message(session=session, record_id=self.record.id, message=message)
279337

backend/template.yaml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,10 @@ template:
134134
predict:
135135
system: |
136136
### 说明:
137-
你是一个数据分析师,你的任务是根据给定的数据进行数据预测,我将以json格式给你一组数据,你帮我预测之后1-2个周期的数据,用json格式返回。
137+
你是一个数据分析师,你的任务是根据给定的数据进行数据预测,我将以json格式给你一组数据,你帮我预测之后1-2个周期的数据,并将预测数据拼接在原始数据后,用json格式返回。
138+
```json
139+
140+
无法预测或者不支持预测的数据请直接返回:"抱歉,该数据无法进行预测。"
138141
139142
user: |
140143
### 请使用 i18n: {lang} 对应的语言输出你的结果

frontend/src/api/chat.ts

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ export class ChatRecord {
3838
chart?: string
3939
analysis?: string
4040
predict?: string
41+
predict_data?: string
4142
finish?: boolean = false
4243
error?: string
4344
run_time: number = 0
@@ -56,6 +57,7 @@ export class ChatRecord {
5657
chart: string | undefined,
5758
analysis: string | undefined,
5859
predict: string | undefined,
60+
predict_data: string | undefined,
5961
finish: boolean,
6062
error: string | undefined,
6163
run_time: number
@@ -73,6 +75,7 @@ export class ChatRecord {
7375
chart?: string,
7476
analysis?: string,
7577
predict?: string,
78+
predict_data?: string,
7679
finish?: boolean,
7780
error?: string,
7881
run_time?: number
@@ -89,6 +92,7 @@ export class ChatRecord {
8992
this.chart = chart
9093
this.analysis = analysis
9194
this.predict = predict
95+
this.predict_data = predict_data
9296
this.finish = finish
9397
this.error = error
9498
this.run_time = run_time ?? 0
@@ -207,6 +211,7 @@ function toChatRecord(data?: any): ChatRecord | undefined {
207211
data.chart,
208212
data.analysis,
209213
data.predict,
214+
data.predict_data,
210215
data.finish,
211216
data.error,
212217
data.run_time

frontend/src/views/chat/ChatAnswer.vue

Lines changed: 11 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
<script setup lang="ts">
22
import type { ChatMessage } from '@/api/chat.ts'
3-
import { computed, nextTick, ref } from 'vue'
3+
import { computed, ref } from 'vue'
44
import { Loading } from '@element-plus/icons-vue'
5-
import ChartComponent from './component/ChartComponent.vue'
65
import MdComponent from './component/MdComponent.vue'
6+
import DisplayChartBlock from './component/DisplayChartBlock.vue'
77
import type { ChartTypes } from '@/views/chat/component/BaseChart.ts'
88
import { ArrowDown } from '@element-plus/icons-vue'
99
import ICON_BAR from '@/assets/svg/chart/bar.svg'
@@ -70,25 +70,6 @@ const chartObject = computed<{
7070
return {}
7171
})
7272
73-
const xAxis = computed(() => {
74-
if (chartObject.value?.axis?.x) {
75-
return [chartObject.value.axis.x]
76-
}
77-
return []
78-
})
79-
const yAxis = computed(() => {
80-
if (chartObject.value?.axis?.y) {
81-
return [chartObject.value.axis.y]
82-
}
83-
return []
84-
})
85-
const series = computed(() => {
86-
if (chartObject.value?.axis?.series) {
87-
return [chartObject.value.axis.series]
88-
}
89-
return []
90-
})
91-
9273
const currentChartType = ref<ChartTypes | undefined>(undefined)
9374
9475
const chartType = computed<ChartTypes>({
@@ -163,10 +144,7 @@ const currentChartTypeIcon = computed(() => {
163144
const chartRef = ref()
164145
165146
function onTypeChange() {
166-
nextTick(() => {
167-
chartRef.value?.destroyChart()
168-
chartRef.value?.renderChart()
169-
})
147+
chartRef.value?.onTypeChange()
170148
}
171149
</script>
172150

@@ -250,31 +228,21 @@ function onTypeChange() {
250228
</div>
251229
</template>
252230
<template v-else-if="settings.type === 'chart'">
253-
<div>
254-
<div v-if="message.record.chart" class="chart-base-container">
255-
<div>
256-
<ChartComponent
257-
v-if="message.record.id"
258-
:id="message.record.id"
259-
ref="chartRef"
260-
:type="chartType"
261-
:columns="chartObject?.columns"
262-
:x="xAxis"
263-
:y="yAxis"
264-
:series="series"
265-
:data="dataObject.data"
266-
/>
267-
</div>
268-
</div>
269-
</div>
231+
<DisplayChartBlock
232+
:id="message.record.id"
233+
ref="chartRef"
234+
:chart-type="chartType"
235+
:message="message"
236+
:data="dataObject.data"
237+
/>
270238
<div v-if="message.record.error" style="color: red">
271239
{{ message.record.error }}
272240
</div>
273241
</template>
274242
</div>
275243
</template>
276244
</el-container>
277-
<slot name="footer"></slot>
245+
<slot :data="{ id: message.record?.id, chartType: chartType, chartObject: chartObject }"></slot>
278246
</el-container>
279247
</template>
280248

0 commit comments

Comments
 (0)