Skip to content

Commit 95d70c0

Browse files
committed
feat: AI助手现在可以直接操作简单中介效应分析, 并在聊天界面显示结果
1 parent 6f5effa commit 95d70c0

File tree

9 files changed

+277
-55
lines changed

9 files changed

+277
-55
lines changed

src/components/assistant/AI.tsx

Lines changed: 109 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ import { Prompts, Sender } from '@ant-design/x'
99
import type { SenderRef } from '@ant-design/x/es/sender'
1010
import parseThink from '@leaf/parse-think'
1111
import { ExportTypes } from '@psych/sheet'
12-
import { Space, Tag } from 'antd'
12+
import { Popover, Space, Tag } from 'antd'
1313
import type OpenAI from 'openai'
1414
import { useEffect, useRef, useState } from 'react'
1515
import { flushSync } from 'react-dom'
@@ -25,10 +25,12 @@ import {
2525
useNav,
2626
} from '../../hooks/useNav'
2727
import { useStates } from '../../hooks/useStates'
28+
import { shortId, sleep } from '../../lib/utils'
2829
import { Funcs } from '../../tools/enum'
2930
import { funcsTools } from '../../tools/tools'
3031
import type { Variable } from '../../types'
3132
import { ALLOWED_INTERPOLATION_METHODS, ALL_VARS_IDENTIFIER } from '../../types'
33+
import { simpleMediationTestCalculator } from '../statistics/SimpleMediatorTest'
3234
import { Messages } from './Messages'
3335

3436
const GREETTING =
@@ -96,9 +98,10 @@ export function AI() {
9698
const [input, setInput] = useState('')
9799
const [loading, setLoading] = useState(false)
98100
const [showLoading, setShowLoading] = useState(false)
99-
const [messages, setMessages] = useState<OpenAI.ChatCompletionMessageParam[]>(
100-
[],
101-
)
101+
const [messages, setMessages] = useState<
102+
(OpenAI.ChatCompletionMessageParam & { id: string })[]
103+
>([])
104+
const [tokenUsage, setTokenUsage] = useState<number>(0)
102105

103106
// 数据被清除时重置对话
104107
useEffect(() => {
@@ -117,12 +120,15 @@ export function AI() {
117120
}
118121
const onSubmit = async () => {
119122
abortRef.current = false
120-
const old = JSON.parse(JSON.stringify(messages))
123+
const old = JSON.parse(
124+
JSON.stringify(messages),
125+
) as (OpenAI.ChatCompletionMessageParam & { id: string })[]
121126
const snapshot = input
122127
try {
123-
const user: OpenAI.ChatCompletionUserMessageParam = {
128+
const user: OpenAI.ChatCompletionUserMessageParam & { id: string } = {
124129
role: 'user',
125130
content: snapshot,
131+
id: shortId(),
126132
}
127133
flushSync(() => {
128134
setLoading(true)
@@ -143,7 +149,7 @@ export function AI() {
143149
usableCount: dataRows.length,
144150
})
145151
// 初始化消息数组和当前状态
146-
let currentMessages: OpenAI.ChatCompletionMessageParam[] = [...old, user]
152+
let currentMessages = [...old, user]
147153
let hasToolCall = true
148154
// 使用while循环处理连续的函数调用
149155
while (hasToolCall) {
@@ -155,9 +161,19 @@ export function AI() {
155161
}
156162
const stream = await ai.chat.completions.create({
157163
model: model,
158-
messages: [{ role: 'system', content: system }, ...currentMessages],
164+
messages: [
165+
{ role: 'system', content: system },
166+
...(currentMessages.map((message) =>
167+
Object.fromEntries(
168+
Object.entries(message).filter(([key]) => key !== 'id'),
169+
),
170+
) as OpenAI.ChatCompletionMessageParam[]),
171+
],
159172
stream: true,
160173
tools: funcsTools,
174+
stream_options: {
175+
include_usage: true,
176+
}
161177
})
162178
if (abortRef.current) {
163179
throw new Error('已取消本次请求')
@@ -168,6 +184,10 @@ export function AI() {
168184
if (abortRef.current) {
169185
throw new Error('已取消本次请求')
170186
}
187+
if (chunk.usage) {
188+
setTokenUsage(chunk.usage.total_tokens)
189+
break
190+
}
171191
const delta = chunk.choices[0].delta
172192
if (delta.tool_calls?.length) {
173193
if (toolCall) {
@@ -192,7 +212,7 @@ export function AI() {
192212
setShowLoading(false)
193213
setMessages([
194214
...currentMessages,
195-
{ role: 'assistant', content: rawResponse },
215+
{ role: 'assistant', content: rawResponse, id: shortId() },
196216
])
197217
})
198218
}
@@ -202,12 +222,76 @@ export function AI() {
202222
}
203223
// 处理函数调用
204224
if (toolCall) {
205-
const newMessages: OpenAI.ChatCompletionMessageParam[] = [
206-
{ role: 'assistant', content: '', tool_calls: [toolCall] },
207-
{ role: 'tool', content: '', tool_call_id: toolCall.id },
225+
const newMessages: (OpenAI.ChatCompletionMessageParam & {
226+
id: string
227+
})[] = [
228+
{
229+
role: 'assistant',
230+
content: '',
231+
tool_calls: [toolCall],
232+
id: shortId(),
233+
},
234+
{
235+
role: 'tool',
236+
content: '',
237+
tool_call_id: toolCall.id,
238+
id: shortId(),
239+
},
208240
]
209241
try {
210242
switch (toolCall.function.name) {
243+
case Funcs.SIMPLE_MEDIATOR_TEST: {
244+
const { x, m, y, B } = JSON.parse(toolCall.function.arguments)
245+
if (
246+
typeof x !== 'string' ||
247+
typeof m !== 'string' ||
248+
typeof y !== 'string' ||
249+
typeof B !== 'number'
250+
) {
251+
throw new Error('参数错误')
252+
}
253+
if (
254+
!dataCols.some((col) => col.name === x) ||
255+
!dataCols.some((col) => col.name === m) ||
256+
!dataCols.some((col) => col.name === y)
257+
) {
258+
throw new Error('变量名参数错误')
259+
}
260+
try {
261+
messageApi?.loading('正在处理数据...', 0)
262+
await sleep()
263+
const timestamp = Date.now()
264+
const filteredRows = dataRows.filter((row) =>
265+
[x, m, y].every(
266+
(variable) => typeof row[variable] === 'number',
267+
),
268+
)
269+
const xData = filteredRows.map((row) => row[x]) as number[]
270+
const mData = filteredRows.map((row) => row[m]) as number[]
271+
const yData = filteredRows.map((row) => row[y]) as number[]
272+
const result = simpleMediationTestCalculator({
273+
x,
274+
m,
275+
y,
276+
B,
277+
N: filteredRows.length,
278+
xData,
279+
mData,
280+
yData,
281+
})
282+
newMessages[1].content = `##### 统计结果\n\n${result}`
283+
messageApi?.destroy()
284+
messageApi?.success(
285+
`数据处理完成, 用时 ${Date.now() - timestamp} 毫秒`,
286+
)
287+
break
288+
} catch (e) {
289+
messageApi?.destroy()
290+
throw new Error(
291+
`数据处理失败: ${e instanceof Error ? e.message : String(e)}`,
292+
)
293+
}
294+
}
211295
case Funcs.DEFINE_INTERPOLATE: {
212296
const { variable_names, method, reference_variable } =
213297
JSON.parse(toolCall.function.arguments)
@@ -478,7 +562,10 @@ export function AI() {
478562
} else {
479563
// 如果没有工具调用,处理普通响应
480564
const { content } = parseThink(rawResponse)
481-
setMessages([...currentMessages, { role: 'assistant', content }])
565+
setMessages([
566+
...currentMessages,
567+
{ role: 'assistant', content, id: shortId() },
568+
])
482569
hasToolCall = false // 结束循环
483570
}
484571
}
@@ -638,10 +725,19 @@ export function AI() {
638725
const { SendButton, LoadingButton, ClearButton } = info.components
639726
return (
640727
<Space size='small'>
728+
<Popover
729+
trigger={['hover', 'click']}
730+
content={<span>
731+
上次 Tokens 使用量<Tag style={{ marginLeft: '0.3rem', marginRight: '0' }}>{tokenUsage}</Tag>
732+
</span>}
733+
>
734+
<InfoCircleOutlined />
735+
</Popover>
641736
<ClearButton
642737
disabled={loading || disabled || !messages.length}
643738
onClick={() => {
644739
setInput('')
740+
setTokenUsage(0)
645741
setMessages([])
646742
messageApi?.success('已清空历史对话')
647743
}}

src/components/assistant/Messages.tsx

Lines changed: 37 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ import type {
1313
} from 'openai/resources/index.mjs'
1414
import { useEffect, useRef } from 'react'
1515
import { useStates } from '../../hooks/useStates'
16-
import { shortId } from '../../lib/utils'
16+
import { Result } from '../widgets/Result'
1717
import { ToolCall } from './ToolCall'
1818

1919
export function Messages({
@@ -24,12 +24,12 @@ export function Messages({
2424
setMessages,
2525
loading,
2626
}: {
27-
messages: ChatCompletionMessageParam[]
27+
messages: (ChatCompletionMessageParam & { id: string })[]
2828
greeting: string
2929
showLoading: boolean
3030
setInput: React.Dispatch<React.SetStateAction<string>>
3131
setMessages: React.Dispatch<
32-
React.SetStateAction<ChatCompletionMessageParam[]>
32+
React.SetStateAction<(ChatCompletionMessageParam & { id: string })[]>
3333
>
3434
loading: boolean
3535
}) {
@@ -42,9 +42,14 @@ export function Messages({
4242
}, [messages])
4343

4444
const messagesToShow = [
45-
{ role: 'assistant', content: greeting },
46-
...messages.filter((message) => message.role !== 'tool'),
47-
...(showLoading ? [{ role: 'assistant', content: '__loading__' }] : []),
45+
{ role: 'assistant', content: greeting, id: 'messages_greeting' },
46+
...messages.filter(
47+
(message) =>
48+
message.role !== 'tool' ||
49+
(typeof message.content === 'string' &&
50+
message.content.startsWith('##### 统计结果')),
51+
),
52+
...(showLoading ? [{ role: 'assistant', content: '__loading__', id: 'messages_loading' }] : []),
4853
]
4954

5055
return (
@@ -57,12 +62,37 @@ export function Messages({
5762
.tool_calls
5863
return (
5964
<Bubble
60-
key={shortId()}
65+
key={message.id}
6166
className='w-full'
6267
placement={message.role === 'user' ? 'end' : 'start'}
6368
content={
6469
tool_calls?.length ? (
6570
<ToolCall toolCall={tool_calls[0]} />
71+
) : message.role === 'tool' ? (
72+
<div className='overflow-hidden'>
73+
<div className='w-full'>
74+
<Result result={message.content as string} fitHeight />
75+
</div>
76+
<div className='w-full mt-2'>
77+
<Button
78+
block
79+
autoInsertSpace={false}
80+
onClick={() => {
81+
navigator.clipboard
82+
.writeText(message.content as string)
83+
.then(() => messageApi?.success('已复制结果到剪贴板'))
84+
.catch((e) =>
85+
messageApi?.error(
86+
`复制失败: ${e instanceof Error ? e.message : String(e)}`,
87+
),
88+
)
89+
}}
90+
>
91+
复制结果的 Markdown 文本
92+
</Button>
93+
</div>
94+
<hr className='w-dvw opacity-0 h-0 p-0 m-0' />
95+
</div>
6696
) : (
6797
<Typography>
6898
<div

src/components/assistant/ToolCall.tsx

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import type OpenAI from 'openai'
22
import { useState } from 'react'
3+
import { DefaultTool } from '../../tools/components/DefaultTool'
34
import { ExportDataTool } from '../../tools/components/data/ExportDataTool'
45
import { NavToPageTool } from '../../tools/components/nav/NavToPageTool'
56
import { ApplyFilterTool } from '../../tools/components/variable/ApplyFilterTool'
@@ -11,6 +12,7 @@ import { CreateSubVarTool } from '../../tools/components/variable/CreateSubVarTo
1112
import { DefineInterpolateTool } from '../../tools/components/variable/DefineInterpolateTool'
1213
import { DefineMissingValueTool } from '../../tools/components/variable/DefineMissingValueTool'
1314
import { Funcs } from '../../tools/enum'
15+
import { funcsLabel } from '../../tools/tools'
1416
import type {
1517
ALLOWED_DISCRETE_METHODS,
1618
ALLOWED_INTERPOLATION_METHODS,
@@ -203,7 +205,10 @@ export function ToolCall({
203205
break
204206
}
205207
default: {
206-
throw new Error(`未知函数 (${toolCall.function.name})`)
208+
element = (
209+
<DefaultTool label={funcsLabel.get(name as Funcs) ?? '未知函数'} />
210+
)
211+
break
207212
}
208213
}
209214
return <div className='flex flex-col gap-3'>{element}</div>

0 commit comments

Comments
 (0)