Skip to content

Commit 76fa7e5

Browse files
committed
show citation file names
1 parent c615b97 commit 76fa7e5

File tree

6 files changed

+186
-44
lines changed

6 files changed

+186
-44
lines changed

src/client/components/ChatV2/ChatV2.tsx

Lines changed: 16 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,7 @@ import useLocalStorageState from '../../hooks/useLocalStorageState'
66
import { DEFAULT_MODEL } from '../../../config'
77
import useInfoTexts from '../../hooks/useInfoTexts'
88
import { Message } from '../../types'
9-
import { ResponseStreamValue } from '../../../shared/types'
9+
import { FileCitation, ResponseStreamEventData } from '../../../shared/types'
1010
import useRetryTimeout from '../../hooks/useRetryTimeout'
1111
import { useTranslation } from 'react-i18next'
1212
import { handleCompletionStreamError } from './error'
@@ -21,6 +21,7 @@ import { Close, Settings } from '@mui/icons-material'
2121
import { SettingsModal } from './SettingsModal'
2222
import { Link } from 'react-router-dom'
2323
import { useScrollToBottom } from './useScrollToBottom'
24+
import { set } from 'lodash'
2425

2526
export const ChatV2 = () => {
2627
const { courseId } = useParams()
@@ -46,6 +47,7 @@ export const ChatV2 = () => {
4647
const [activePromptId, setActivePromptId] = useState('')
4748
const [fileName, setFileName] = useState<string>('')
4849
const [completion, setCompletion] = useState('')
50+
const [citations, setCitations] = useState<FileCitation[]>([])
4951
const [streamController, setStreamController] = useState<AbortController>()
5052
const [alertOpen, setAlertOpen] = useState(false)
5153
const [disallowedFileType, setDisallowedFileType] = useState('')
@@ -69,6 +71,7 @@ export const ChatV2 = () => {
6971
const reader = stream.getReader()
7072

7173
let content = ''
74+
const citations: FileCitation[] = []
7275

7376
while (true) {
7477
const { value, done } = await reader.read()
@@ -79,15 +82,22 @@ export const ChatV2 = () => {
7982
for (const chunk of data.split('\n')) {
8083
if (!chunk || chunk.trim().length === 0) continue
8184

82-
const parsedChunk: ResponseStreamValue = JSON.parse(chunk)
85+
const parsedChunk: ResponseStreamEventData = JSON.parse(chunk)
8386

84-
switch (parsedChunk.status) {
87+
switch (parsedChunk.type) {
8588
case 'writing':
8689
setCompletion((prev) => prev + parsedChunk.text)
8790
content += parsedChunk.text
8891
break
8992

93+
case 'annotation':
94+
console.log('Received annotation:', parsedChunk.annotation)
95+
setCitations((prev) => [...prev, parsedChunk.annotation])
96+
citations.push(parsedChunk.annotation)
97+
break
98+
9099
case 'complete':
100+
console.log('Stream completed with response ID:', parsedChunk)
91101
setPrevResponse({ id: parsedChunk.prevResponseId })
92102
break
93103

@@ -101,7 +111,7 @@ export const ChatV2 = () => {
101111
}
102112
}
103113

104-
setMessages((prev: Message[]) => prev.concat({ role: 'assistant', content }))
114+
setMessages((prev: Message[]) => prev.concat({ role: 'assistant', content, citations }))
105115
} catch (err: any) {
106116
handleCompletionStreamError(err, fileName)
107117
} finally {
@@ -120,6 +130,7 @@ export const ChatV2 = () => {
120130
setMessage({ content: '' })
121131
setPrevResponse({ id: '' })
122132
setCompletion('')
133+
setCitations([])
123134
setStreamController(new AbortController())
124135
setRetryTimeout(() => {
125136
if (streamController) {
@@ -193,7 +204,7 @@ export const ChatV2 = () => {
193204
{courseId ? <Link to={'/v2'}>CurreChat</Link> : <Link to={'/v2/sandbox'}>Ohtu Sandbox</Link>}
194205
</Box>
195206
<Box ref={chatContainerRef}>
196-
<Conversation messages={messages} completion={completion} />
207+
<Conversation messages={messages} completion={completion} citations={citations} />
197208
<ChatBox
198209
disabled={false}
199210
onSubmit={(message) => {

src/client/components/ChatV2/Conversation.tsx

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ import { Message } from '../../types'
33
import ReactMarkdown from 'react-markdown'
44
import remarkGfm from 'remark-gfm'
55
import { Assistant } from '@mui/icons-material'
6+
import { FileCitation } from '../../../shared/types'
67

78
const MessageItem = ({ message }: { message: Message }) => (
89
<Paper
@@ -17,6 +18,16 @@ const MessageItem = ({ message }: { message: Message }) => (
1718
}}
1819
>
1920
<ReactMarkdown remarkPlugins={[remarkGfm]}>{message.content}</ReactMarkdown>
21+
{message.citations && message.citations.length > 0 && (
22+
<Box sx={{ mt: 1, fontSize: '0.875rem', color: 'gray' }}>
23+
Citations:
24+
{message.citations.map((citation, index) => (
25+
<Box key={index} sx={{ display: 'block' }}>
26+
{citation.filename}
27+
</Box>
28+
))}
29+
</Box>
30+
)}
2031
</Paper>
2132
)
2233

@@ -41,12 +52,12 @@ const PöhinäLogo = () => (
4152
</Box>
4253
)
4354

44-
export const Conversation = ({ messages, completion }: { messages: Message[]; completion: string }) => (
55+
export const Conversation = ({ messages, completion, citations }: { messages: Message[]; completion: string; citations: FileCitation[] }) => (
4556
<Box sx={{ flex: 1, overflowY: 'auto', gap: 2 }}>
4657
{messages.map((message, idx) => (
4758
<MessageItem key={idx} message={message} />
4859
))}
49-
{completion && <MessageItem message={{ role: 'assistant', content: completion }} />}
60+
{completion && <MessageItem message={{ role: 'assistant', content: completion, citations }} />}
5061
{messages.length === 0 && <PöhinäLogo />}
5162
</Box>
5263
)

src/client/types.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
1+
import { FileCitation } from "../shared/types"
2+
13
export type SetState<T> = React.Dispatch<React.SetStateAction<T>>
24

35
export type Role = 'system' | 'assistant' | 'user'
46

57
export interface Message {
68
role: Role
79
content: string
10+
citations?: FileCitation[]
811
}
912

1013
interface Term {

src/server/routes/openai.ts

Lines changed: 119 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,8 @@ const fileParsing = async (options: any, req: any) => {
4545
return options.messages
4646
}
4747

48-
openaiRouter.post('/stream/:version?', upload.single('file'), async (r, res) => {
48+
openaiRouter.post('/stream/v2', upload.single('file'), async (r, res) => {
4949
const req = r as RequestWithUser
50-
const { version } = r.params
5150
const { options, courseId } = JSON.parse(req.body.data)
5251
const { model, userConsent } = options
5352
const { user } = req
@@ -105,31 +104,134 @@ openaiRouter.post('/stream/:version?', upload.single('file'), async (r, res) =>
105104
model,
106105
})
107106

108-
let events
109-
if (version === 'v2') {
110-
const latestMessage = options.messages[options.messages.length - 1] // Adhoc to input only the latest message
111-
events = await responsesClient.createResponse({ input: [latestMessage], prevResponseId: options.prevResponseId })
112-
} else {
113-
events = await getCompletionEvents(options)
114-
}
107+
const latestMessage = options.messages[options.messages.length - 1] // Adhoc to input only the latest message
108+
const events = await responsesClient.createResponse({ input: [latestMessage], prevResponseId: options.prevResponseId })
109+
115110
if (isError(events)) {
116111
res.status(424)
117112
return
118113
}
119114

120115
res.setHeader('content-type', 'text/event-stream')
121116

122-
let completion
123-
if (version === 'v2') {
124-
completion = await responsesClient.handleResponse({
125-
events,
126-
encoding,
127-
res,
117+
const completion = await responsesClient.handleResponse({
118+
events,
119+
encoding,
120+
res,
121+
})
122+
123+
tokenCount += completion.tokenCount
124+
125+
let userToCharge = user
126+
if (inProduction && req.hijackedBy) {
127+
userToCharge = req.hijackedBy
128+
}
129+
130+
if (courseId) {
131+
await incrementCourseUsage(userToCharge, courseId, tokenCount)
132+
} else if (model !== FREE_MODEL) {
133+
await incrementUsage(userToCharge, tokenCount)
134+
}
135+
136+
logger.info(`Stream ended. Total tokens: ${tokenCount}`, {
137+
tokenCount,
138+
model,
139+
user: user.username,
140+
courseId,
141+
})
142+
143+
const course =
144+
courseId &&
145+
(await ChatInstance.findOne({
146+
where: { courseId },
147+
}))
148+
149+
const consentToSave = courseId && course.saveDiscussions && options.saveConsent
150+
151+
console.log('consentToSave', options.saveConsent, user.username)
152+
153+
if (consentToSave) {
154+
const discussion = {
155+
userId: user.id,
156+
courseId,
157+
response: completion.response,
158+
metadata: options,
159+
}
160+
await Discussion.create(discussion)
161+
}
162+
163+
encoding.free()
164+
165+
res.end()
166+
return
167+
})
168+
169+
openaiRouter.post('/stream', upload.single('file'), async (r, res) => {
170+
const req = r as RequestWithUser
171+
const { options, courseId } = JSON.parse(req.body.data)
172+
const { model, userConsent } = options
173+
const { user } = req
174+
175+
options.options = { temperature: options.modelTemperature }
176+
177+
if (!user.id) {
178+
res.status(401).send('Unauthorized')
179+
return
180+
}
181+
182+
const usageAllowed = courseId ? await checkCourseUsage(user, courseId) : model === FREE_MODEL || (await checkUsage(user, model))
183+
184+
if (!usageAllowed) {
185+
res.status(403).send('Usage limit reached')
186+
return
187+
}
188+
189+
let optionsMessagesWithFile = null
190+
191+
try {
192+
if (req.file) {
193+
optionsMessagesWithFile = await fileParsing(options, req)
194+
}
195+
} catch (error) {
196+
logger.error('Error parsing file', { error })
197+
res.status(400).send('Error parsing file')
198+
return
199+
}
200+
201+
options.messages = getMessageContext(optionsMessagesWithFile || options.messages)
202+
options.stream = true
203+
204+
const encoding = getEncoding(model)
205+
let tokenCount = calculateUsage(options, encoding)
206+
const tokenUsagePercentage = Math.round((tokenCount / DEFAULT_TOKEN_LIMIT) * 100)
207+
208+
if (model !== FREE_MODEL && tokenCount > 0.1 * DEFAULT_TOKEN_LIMIT && !userConsent) {
209+
res.status(201).json({
210+
tokenConsumtionWarning: true,
211+
message: `You are about to use ${tokenUsagePercentage}% of your monthly CurreChat usage`,
128212
})
129-
} else {
130-
completion = await streamCompletion(events, options, encoding, res)
213+
return
131214
}
132215

216+
const contextLimit = getModelContextLimit(model)
217+
218+
if (tokenCount > contextLimit) {
219+
logger.info('Maximum context reached')
220+
res.status(403).send('Model maximum context reached')
221+
return
222+
}
223+
224+
const events = await getCompletionEvents(options)
225+
226+
if (isError(events)) {
227+
res.status(424)
228+
return
229+
}
230+
231+
res.setHeader('content-type', 'text/event-stream')
232+
233+
const completion = await streamCompletion(events, options, encoding, res)
234+
133235
tokenCount += completion.tokenCount
134236

135237
let userToCharge = user

src/server/util/azure/ResponsesAPI.ts

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,12 @@ import { AzureOpenAI } from 'openai'
1111

1212
// import { EventStream } from '@azure/openai'
1313
import { Stream } from 'openai/streaming'
14-
import { FileSearchTool, FunctionTool, ResponseInput, ResponseInputItem, ResponseStreamEvent } from 'openai/resources/responses/responses'
14+
import { FileSearchTool, FunctionTool, ResponseInput, ResponseInputItem, ResponseStreamEvent, ResponseTextAnnotationDeltaEvent } from 'openai/resources/responses/responses'
1515

1616
import { courseAssistants, type CourseAssistant } from './courseAssistants'
1717
import { createFileSearchTool } from './util'
1818

19-
import type { ResponseStreamValue } from '../../../shared/types'
19+
import type { FileCitation, ResponseStreamEventData } from '../../../shared/types'
2020

2121
const endpoint = `https://${AZURE_RESOURCE}.openai.azure.com/`
2222

@@ -92,9 +92,8 @@ export class ResponsesClient {
9292
case 'response.output_text.delta':
9393
await this.write(
9494
{
95-
status: 'writing',
95+
type: 'writing',
9696
text: event.delta,
97-
prevResponseId: null,
9897
},
9998
res,
10099
)
@@ -112,14 +111,19 @@ export class ResponsesClient {
112111
break
113112

114113
case 'response.output_text.annotation.added':
115-
console.log('ANNOTATIONS ADDED', JSON.stringify(event, null, 2))
114+
this.write(
115+
{
116+
type: 'annotation',
117+
annotation: event.annotation as FileCitation,
118+
},
119+
res,
120+
)
116121
break
117122

118123
case 'response.completed':
119124
await this.write(
120125
{
121-
status: 'complete',
122-
text: null,
126+
type: 'complete',
123127
prevResponseId: event.response.id,
124128
},
125129
res,
@@ -134,16 +138,8 @@ export class ResponsesClient {
134138
}
135139
}
136140

137-
private async write({ status, text, prevResponseId }: ResponseStreamValue, res: Response) {
138-
// if (!inProduction) logger.info(message)
139-
141+
private async write(data: ResponseStreamEventData, res: Response) {
140142
await new Promise((resolve) => {
141-
const data: ResponseStreamValue = {
142-
status,
143-
text,
144-
prevResponseId: prevResponseId,
145-
}
146-
147143
const success = res.write(JSON.stringify(data) + '\n', (err) => {
148144
if (err) {
149145
logger.error(err)

src/shared/types.ts

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,27 @@ export type RagIndexMetadata = {
33
dim: number
44
}
55

6-
export type ResponseStreamValue = {
7-
status: 'writing' | 'complete' | 'error'
8-
text: string | null
9-
prevResponseId: string | null
6+
export type FileCitation = {
7+
file_id: string
8+
filename: string
9+
index: number
10+
type: 'file_citation'
1011
}
12+
13+
export type ResponseStreamEventData =
14+
| {
15+
type: 'writing'
16+
text: string
17+
}
18+
| {
19+
type: 'complete'
20+
prevResponseId: string
21+
}
22+
| {
23+
type: 'error'
24+
error: any
25+
}
26+
| {
27+
type: 'annotation'
28+
annotation: FileCitation
29+
}

0 commit comments

Comments
 (0)