@@ -9,7 +9,7 @@ import { Prompts, Sender } from '@ant-design/x'
99import type { SenderRef } from '@ant-design/x/es/sender'
1010import parseThink from '@leaf/parse-think'
1111import { ExportTypes } from '@psych/sheet'
12- import { Space , Tag } from 'antd'
12+ import { Popover , Space , Tag } from 'antd'
1313import type OpenAI from 'openai'
1414import { useEffect , useRef , useState } from 'react'
1515import { flushSync } from 'react-dom'
@@ -25,10 +25,12 @@ import {
2525 useNav ,
2626} from '../../hooks/useNav'
2727import { useStates } from '../../hooks/useStates'
28+ import { shortId , sleep } from '../../lib/utils'
2829import { Funcs } from '../../tools/enum'
2930import { funcsTools } from '../../tools/tools'
3031import type { Variable } from '../../types'
3132import { ALLOWED_INTERPOLATION_METHODS , ALL_VARS_IDENTIFIER } from '../../types'
33+ import { simpleMediationTestCalculator } from '../statistics/SimpleMediatorTest'
3234import { Messages } from './Messages'
3335
3436const GREETTING =
@@ -96,9 +98,10 @@ export function AI() {
9698 const [ input , setInput ] = useState ( '' )
9799 const [ loading , setLoading ] = useState ( false )
98100 const [ showLoading , setShowLoading ] = useState ( false )
99- const [ messages , setMessages ] = useState < OpenAI . ChatCompletionMessageParam [ ] > (
100- [ ] ,
101- )
101+ const [ messages , setMessages ] = useState <
102+ ( OpenAI . ChatCompletionMessageParam & { id : string } ) [ ]
103+ > ( [ ] )
104+ const [ tokenUsage , setTokenUsage ] = useState < number > ( 0 )
102105
103106 // 数据被清除时重置对话
104107 useEffect ( ( ) => {
@@ -117,12 +120,15 @@ export function AI() {
117120 }
118121 const onSubmit = async ( ) => {
119122 abortRef . current = false
120- const old = JSON . parse ( JSON . stringify ( messages ) )
123+ const old = JSON . parse (
124+ JSON . stringify ( messages ) ,
125+ ) as ( OpenAI . ChatCompletionMessageParam & { id : string } ) [ ]
121126 const snapshot = input
122127 try {
123- const user : OpenAI . ChatCompletionUserMessageParam = {
128+ const user : OpenAI . ChatCompletionUserMessageParam & { id : string } = {
124129 role : 'user' ,
125130 content : snapshot ,
131+ id : shortId ( ) ,
126132 }
127133 flushSync ( ( ) => {
128134 setLoading ( true )
@@ -143,7 +149,7 @@ export function AI() {
143149 usableCount : dataRows . length ,
144150 } )
145151 // 初始化消息数组和当前状态
146- let currentMessages : OpenAI . ChatCompletionMessageParam [ ] = [ ...old , user ]
152+ let currentMessages = [ ...old , user ]
147153 let hasToolCall = true
148154 // 使用while循环处理连续的函数调用
149155 while ( hasToolCall ) {
@@ -155,9 +161,19 @@ export function AI() {
155161 }
156162 const stream = await ai . chat . completions . create ( {
157163 model : model ,
158- messages : [ { role : 'system' , content : system } , ...currentMessages ] ,
164+ messages : [
165+ { role : 'system' , content : system } ,
166+ ...( currentMessages . map ( ( message ) =>
167+ Object . fromEntries (
168+ Object . entries ( message ) . filter ( ( [ key ] ) => key !== 'id' ) ,
169+ ) ,
170+ ) as OpenAI . ChatCompletionMessageParam [ ] ) ,
171+ ] ,
159172 stream : true ,
160173 tools : funcsTools ,
174+ stream_options : {
175+ include_usage : true ,
176+ }
161177 } )
162178 if ( abortRef . current ) {
163179 throw new Error ( '已取消本次请求' )
@@ -168,6 +184,10 @@ export function AI() {
168184 if ( abortRef . current ) {
169185 throw new Error ( '已取消本次请求' )
170186 }
187+ if ( chunk . usage ) {
188+ setTokenUsage ( chunk . usage . total_tokens )
189+ break
190+ }
171191 const delta = chunk . choices [ 0 ] . delta
172192 if ( delta . tool_calls ?. length ) {
173193 if ( toolCall ) {
@@ -192,7 +212,7 @@ export function AI() {
192212 setShowLoading ( false )
193213 setMessages ( [
194214 ...currentMessages ,
195- { role : 'assistant' , content : rawResponse } ,
215+ { role : 'assistant' , content : rawResponse , id : shortId ( ) } ,
196216 ] )
197217 } )
198218 }
@@ -202,12 +222,76 @@ export function AI() {
202222 }
203223 // 处理函数调用
204224 if ( toolCall ) {
205- const newMessages : OpenAI . ChatCompletionMessageParam [ ] = [
206- { role : 'assistant' , content : '' , tool_calls : [ toolCall ] } ,
207- { role : 'tool' , content : '' , tool_call_id : toolCall . id } ,
225+ const newMessages : ( OpenAI . ChatCompletionMessageParam & {
226+ id : string
227+ } ) [ ] = [
228+ {
229+ role : 'assistant' ,
230+ content : '' ,
231+ tool_calls : [ toolCall ] ,
232+ id : shortId ( ) ,
233+ } ,
234+ {
235+ role : 'tool' ,
236+ content : '' ,
237+ tool_call_id : toolCall . id ,
238+ id : shortId ( ) ,
239+ } ,
208240 ]
209241 try {
210242 switch ( toolCall . function . name ) {
243+ case Funcs . SIMPLE_MEDIATOR_TEST : {
244+ const { x, m, y, B } = JSON . parse ( toolCall . function . arguments )
245+ if (
246+ typeof x !== 'string' ||
247+ typeof m !== 'string' ||
248+ typeof y !== 'string' ||
249+ typeof B !== 'number'
250+ ) {
251+ throw new Error ( '参数错误' )
252+ }
253+ if (
254+ ! dataCols . some ( ( col ) => col . name === x ) ||
255+ ! dataCols . some ( ( col ) => col . name === m ) ||
256+ ! dataCols . some ( ( col ) => col . name === y )
257+ ) {
258+ throw new Error ( '变量名参数错误' )
259+ }
260+ try {
261+ messageApi ?. loading ( '正在处理数据...' , 0 )
262+ await sleep ( )
263+ const timestamp = Date . now ( )
264+ const filteredRows = dataRows . filter ( ( row ) =>
265+ [ x , m , y ] . every (
266+ ( variable ) => typeof row [ variable ] === 'number' ,
267+ ) ,
268+ )
269+ const xData = filteredRows . map ( ( row ) => row [ x ] ) as number [ ]
270+ const mData = filteredRows . map ( ( row ) => row [ m ] ) as number [ ]
271+ const yData = filteredRows . map ( ( row ) => row [ y ] ) as number [ ]
272+ const result = simpleMediationTestCalculator ( {
273+ x,
274+ m,
275+ y,
276+ B,
277+ N : filteredRows . length ,
278+ xData,
279+ mData,
280+ yData,
281+ } )
282+ newMessages [ 1 ] . content = `##### 统计结果\n\n${ result } `
283+ messageApi ?. destroy ( )
284+ messageApi ?. success (
285+ `数据处理完成, 用时 ${ Date . now ( ) - timestamp } 毫秒` ,
286+ )
287+ break
288+ } catch ( e ) {
289+ messageApi ?. destroy ( )
290+ throw new Error (
291+ `数据处理失败: ${ e instanceof Error ? e . message : String ( e ) } ` ,
292+ )
293+ }
294+ }
211295 case Funcs . DEFINE_INTERPOLATE : {
212296 const { variable_names, method, reference_variable } =
213297 JSON . parse ( toolCall . function . arguments )
@@ -478,7 +562,10 @@ export function AI() {
478562 } else {
479563 // 如果没有工具调用,处理普通响应
480564 const { content } = parseThink ( rawResponse )
481- setMessages ( [ ...currentMessages , { role : 'assistant' , content } ] )
565+ setMessages ( [
566+ ...currentMessages ,
567+ { role : 'assistant' , content, id : shortId ( ) } ,
568+ ] )
482569 hasToolCall = false // 结束循环
483570 }
484571 }
@@ -638,10 +725,19 @@ export function AI() {
638725 const { SendButton, LoadingButton, ClearButton } = info . components
639726 return (
640727 < Space size = 'small' >
728+ < Popover
729+ trigger = { [ 'hover' , 'click' ] }
730+ content = { < span >
731+ 上次 Tokens 使用量< Tag style = { { marginLeft : '0.3rem' , marginRight : '0' } } > { tokenUsage } </ Tag >
732+ </ span > }
733+ >
734+ < InfoCircleOutlined />
735+ </ Popover >
641736 < ClearButton
642737 disabled = { loading || disabled || ! messages . length }
643738 onClick = { ( ) => {
644739 setInput ( '' )
740+ setTokenUsage ( 0 )
645741 setMessages ( [ ] )
646742 messageApi ?. success ( '已清空历史对话' )
647743 } }
0 commit comments