
Commit 4cbd858

some refactoring of ai backend & comments & typing

1 parent: 66f3526

File tree: 2 files changed (+49, -42 lines)


src/server/routes/openai.ts

Lines changed: 48 additions & 41 deletions
@@ -21,47 +21,50 @@ const openaiRouter = express.Router()
 const storage = multer.memoryStorage()
 const upload = multer({ storage })
 
-const fileParsing = async (options: any, req: any) => {
+const PostStreamSchemaV2 = z.object({
+  options: z.object({
+    model: z.string(),
+    assistantInstructions: z.string().optional(),
+    messages: z.array(z.any()),
+    userConsent: z.boolean().optional(),
+    modelTemperature: z.number().min(0).max(2),
+    saveConsent: z.boolean().optional(),
+    prevResponseId: z.string().optional(),
+    courseId: z.string().optional(),
+    ragIndexId: z.number().optional().nullable(),
+  }),
+  courseId: z.string().optional(),
+})
+
+type PostStreamBody = z.infer<typeof PostStreamSchemaV2>
+
+const parseFileAndAddToLastMessage = async (options: PostStreamBody['options'], file: Express.Multer.File) => {
   let fileContent = ''
 
   const textFileTypes = ['text/plain', 'text/html', 'text/css', 'text/csv', 'text/markdown', 'text/md']
-  if (textFileTypes.includes(req.file.mimetype)) {
-    const fileBuffer = req.file.buffer
+  if (textFileTypes.includes(file.mimetype)) {
+    const fileBuffer = file.buffer
     fileContent = fileBuffer.toString('utf8')
   }
 
-  if (req.file.mimetype === 'application/pdf') {
-    fileContent = await pdfToText(req.file.buffer)
+  if (file.mimetype === 'application/pdf') {
+    fileContent = await pdfToText(file.buffer)
   }
 
-  const allMessages = options.messages
+  const messageToAddFileTo = options.messages[options.messages.length - 1]
 
   const updatedMessage = {
-    ...allMessages[allMessages.length - 1],
-    content: `${allMessages[allMessages.length - 1].content} ${fileContent}`,
+    ...messageToAddFileTo,
+    content: `${messageToAddFileTo.content} ${fileContent}`,
   }
-  options.messages.pop()
 
+  // Remove the old message and add the new one
+  options.messages.pop()
   options.messages = [...options.messages, updatedMessage]
 
   return options.messages
 }
 
-const PostStreamSchemaV2 = z.object({
-  options: z.object({
-    model: z.string(),
-    assistantInstructions: z.string().optional(),
-    messages: z.array(z.any()),
-    userConsent: z.boolean().optional(),
-    modelTemperature: z.number().min(0).max(2),
-    saveConsent: z.boolean().optional(),
-    prevResponseId: z.string().optional(),
-    courseId: z.string().optional(),
-    ragIndexId: z.number().optional().nullable(),
-  }),
-  courseId: z.string().optional(),
-})
-
 openaiRouter.post('/stream/v2', upload.single('file'), async (r, res) => {
   const req = r as RequestWithUser
   const { options, courseId } = PostStreamSchemaV2.parse(JSON.parse(req.body.data))
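Moving the schema above its first use also lets z.infer supply the body type for the file helper. A minimal sketch of how the schema guards the multipart payload — the trimmed schema and sample payload here are illustrative, not part of the commit:

import { z } from 'zod'

// Trimmed copy of PostStreamSchemaV2, for illustration only
const Schema = z.object({
  options: z.object({
    model: z.string(),
    messages: z.array(z.any()),
    modelTemperature: z.number().min(0).max(2),
  }),
  courseId: z.string().optional(),
})

// Multer leaves non-file multipart fields as strings, hence JSON.parse(req.body.data)
const raw = '{"options":{"model":"gpt-4o","messages":[],"modelTemperature":1}}'
const { options, courseId } = Schema.parse(JSON.parse(raw))
console.log(options.model, courseId) // "gpt-4o" undefined

// An out-of-range temperature now fails at the boundary instead of deep in the handler:
// Schema.parse({ options: { model: 'x', messages: [], modelTemperature: 3 } }) // throws ZodError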
@@ -82,7 +85,8 @@ openaiRouter.post('/stream/v2', upload.single('file'), async (r, res) => {
     throw ApplicationError.NotFound('Course not found')
   }
 
-  const usageAllowed = (courseId ? await checkCourseUsage(user, courseId) : model === FREE_MODEL) || (await checkUsage(user, model))
+  const isFreeModel = model === FREE_MODEL
+  const usageAllowed = (courseId ? await checkCourseUsage(user, courseId) : isFreeModel) || (await checkUsage(user, model))
 
   if (!usageAllowed) {
     throw ApplicationError.Forbidden('Usage limit reached')
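Hoisting isFreeModel names a condition the handler reuses further down. The gate, restated as a standalone sketch (checkCourseUsage and checkUsage are stubbed here; the real checks live elsewhere in this repo):

type User = { id: string; isAdmin?: boolean }

// Stubs for illustration; the real implementations query stored usage
const checkCourseUsage = async (_user: User, _courseId: string): Promise<boolean> => true
const checkUsage = async (_user: User, _model: string): Promise<boolean> => false

const isUsageAllowed = async (user: User, model: string, isFreeModel: boolean, courseId?: string): Promise<boolean> =>
  // Course chats are gated by the course quota; outside a course the free model is
  // always allowed, and any model passes while the user has personal quota left.
  (courseId ? await checkCourseUsage(user, courseId) : isFreeModel) || (await checkUsage(user, model))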
@@ -107,7 +111,7 @@ openaiRouter.post('/stream/v2', upload.single('file'), async (r, res) => {
 
   try {
     if (req.file) {
-      optionsMessagesWithFile = await fileParsing(options, req)
+      optionsMessagesWithFile = await parseFileAndAddToLastMessage(options, req.file)
     }
   } catch (error) {
     logger.error('Error parsing file', { error })
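With the typed signature, the call site hands over req.file rather than the whole request. A usage sketch with a fake in-memory upload shaped like Express.Multer.File (requires @types/multer; the payload values are invented):

declare const parseFileAndAddToLastMessage: (options: any, file: Express.Multer.File) => Promise<any[]>

// Illustrative only: memoryStorage keeps the buffer in RAM, so no temp file exists
const fakeFile = {
  mimetype: 'text/plain',
  buffer: Buffer.from('attached notes'),
} as Express.Multer.File

const options = { messages: [{ role: 'user', content: 'Summarize this:' }] }

// The helper appends the file text to the last message and returns the updated array
const messages = await parseFileAndAddToLastMessage(options, fakeFile)
// messages[messages.length - 1].content === 'Summarize this: attached notes'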
@@ -120,7 +124,7 @@ openaiRouter.post('/stream/v2', upload.single('file'), async (r, res) => {
   let tokenCount = calculateUsage(options as any, encoding)
   const tokenUsagePercentage = Math.round((tokenCount / DEFAULT_TOKEN_LIMIT) * 100)
 
-  if (options.model !== FREE_MODEL && tokenCount > 0.1 * DEFAULT_TOKEN_LIMIT) {
+  if (!isFreeModel && tokenCount > 0.1 * DEFAULT_TOKEN_LIMIT) {
     res.status(201).json({
       tokenConsumtionWarning: true,
       message: `You are about to use ${tokenUsagePercentage}% of your monthly CurreChat usage`,
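The warning fires when a single request would consume more than 10% of the monthly limit. Worked through with an illustrative limit (the real DEFAULT_TOKEN_LIMIT is defined elsewhere in the repo):

const DEFAULT_TOKEN_LIMIT = 100_000 // illustrative value, not the repo's actual constant

const tokenCount = 12_000
const tokenUsagePercentage = Math.round((tokenCount / DEFAULT_TOKEN_LIMIT) * 100) // 12

// 12_000 > 0.1 * 100_000, so a non-free model gets the 201 warning response
// before any tokens are actually spent
const shouldWarn = tokenCount > 0.1 * DEFAULT_TOKEN_LIMIT // true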
@@ -132,7 +136,6 @@ openaiRouter.post('/stream/v2', upload.single('file'), async (r, res) => {
   const contextLimit = getModelContextLimit(options.model)
 
   if (tokenCount > contextLimit) {
-    logger.info('Maximum context reached') // @todo sure we need to log twice the error message?
     throw ApplicationError.BadRequest('Model maximum context reached')
   }

@@ -142,7 +145,6 @@ openaiRouter.post('/stream/v2', upload.single('file'), async (r, res) => {
 
   if (ragIndexId) {
     if (!courseId && !user.isAdmin) {
-      logger.error('User is not admin and trying to access non-course rag')
       throw ApplicationError.Forbidden('User is not admin and trying to access non-course rag')
     }
 
@@ -172,6 +174,7 @@ openaiRouter.post('/stream/v2', upload.single('file'), async (r, res) => {
     user,
   })
 
+  // With the Responses API we only send the latest message plus the id of the previous response
   const latestMessage = options.messages[options.messages.length - 1]
 
   const events = await responsesClient.createResponse({
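The Responses API holds conversation state server-side, keyed by the previous response id, so only the newest turn travels over the wire. A sketch of the pattern (createResponse is this repo's wrapper; the input field name is assumed from the surrounding code, not confirmed by this diff):

const latestMessage = options.messages[options.messages.length - 1]

const events = await responsesClient.createResponse({
  input: [latestMessage],                 // only the delta, not the whole history
  prevResponseId: options.prevResponseId, // links this turn to the stored conversation
  include: ragIndexId ? ['file_search_call.results'] : [],
})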
@@ -180,6 +183,7 @@ openaiRouter.post('/stream/v2', upload.single('file'), async (r, res) => {
     include: ragIndexId ? ['file_search_call.results'] : [],
   })
 
+  // Prepare for streaming response
   res.setHeader('content-type', 'text/event-stream')
 
   const result = await responsesClient.handleResponse({
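Setting content-type to text/event-stream puts the response into server-sent-events mode; handleResponse then forwards events as they arrive. A generic sketch of what such a loop does — the real implementation lives in ResponsesClient.handleResponse and may differ:

import type { Response } from 'express'

declare const res: Response
declare const events: AsyncIterable<unknown>

res.setHeader('content-type', 'text/event-stream')

for await (const event of events) {
  res.write(`data: ${JSON.stringify(event)}\n\n`) // one SSE frame per event
}

res.end()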
@@ -190,15 +194,19 @@ openaiRouter.post('/stream/v2', upload.single('file'), async (r, res) => {
 
   tokenCount += result.tokenCount
 
-  let userToCharge = user
-  if (inProduction && req.hijackedBy) {
-    userToCharge = req.hijackedBy
-  }
+  // Increment user usage if not using free model
+  // If the user is hijacked by admin in production, charge the admin instead
+  if (!isFreeModel) {
+    let userToCharge = user
+    if (inProduction && req.hijackedBy) {
+      userToCharge = req.hijackedBy
+    }
 
-  if (courseId && model !== FREE_MODEL && model !== 'mock') {
-    await incrementCourseUsage(userToCharge, courseId, tokenCount)
-  } else if (model !== FREE_MODEL && model !== 'mock') {
-    await incrementUsage(userToCharge, tokenCount)
+    if (courseId) {
+      await incrementCourseUsage(userToCharge, courseId, tokenCount)
+    } else {
+      await incrementUsage(userToCharge, tokenCount)
+    }
   }
 
   logger.info(`Stream ended. Total tokens: ${tokenCount}`, {
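Wrapping the whole charging section in if (!isFreeModel) states the free-model exemption once instead of repeating model !== FREE_MODEL in each branch. The same logic as a standalone helper — a hypothetical refactor for illustration, not part of this commit:

type User = { id: string; username?: string }

declare const inProduction: boolean
declare function incrementCourseUsage(user: User, courseId: string, tokens: number): Promise<void>
declare function incrementUsage(user: User, tokens: number): Promise<void>

const chargeForUsage = async (user: User, tokenCount: number, courseId?: string, hijackedBy?: User) => {
  // In production, an admin impersonating ("hijacking") a user pays for the tokens themselves
  const userToCharge = inProduction && hijackedBy ? hijackedBy : user

  if (courseId) {
    await incrementCourseUsage(userToCharge, courseId, tokenCount)
  } else {
    await incrementUsage(userToCharge, tokenCount)
  }
}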
@@ -208,10 +216,9 @@ openaiRouter.post('/stream/v2', upload.single('file'), async (r, res) => {
     courseId,
   })
 
-  const consentToSave = courseId && course!.saveDiscussions && options.saveConsent
-
+  // If the course has saveDiscussions turned on and the user has consented, save the discussion
+  const consentToSave = courseId && course?.saveDiscussions && options.saveConsent
   console.log(`Consent to save discussion: ${options.saveConsent} ${user.username}`)
-
   if (consentToSave) {
     // @todo: should file search results also be saved?
     const discussion = {
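Replacing course!.saveDiscussions with course?.saveDiscussions drops a non-null assertion: when course is undefined, the expression short-circuits to a falsy value instead of throwing. A minimal illustration:

type Course = { saveDiscussions: boolean }
const course: Course | undefined = undefined

// course!.saveDiscussions throws a TypeError at runtime when course is undefined;
// optional chaining short-circuits to undefined, so consentToSave is simply falsy
const consentToSave = Boolean(course?.saveDiscussions)
console.log(consentToSave) // false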
@@ -261,7 +268,7 @@ openaiRouter.post('/stream', upload.single('file'), async (r, res) => {
 
   try {
     if (req.file) {
-      optionsMessagesWithFile = await fileParsing(options, req)
+      optionsMessagesWithFile = await parseFileAndAddToLastMessage(options, req.file)
     }
   } catch (error) {
     logger.error('Error parsing file', { error })

src/server/util/azure/ResponsesAPI.ts

Lines changed: 1 addition & 1 deletion
@@ -76,7 +76,7 @@ export class ResponsesClient {
     prevResponseId?: string
     include?: ResponseIncludable[]
     attemptNumber?: number
-  }): Promise<Stream<ResponseStreamEvent> | APIError> {
+  }): Promise<Stream<ResponseStreamEvent>> {
     try {
       const sanitizedInput = validatedInputSchema.parse(input)

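Dropping APIError from the return type means failures now surface as thrown exceptions (the body already wraps the call in try), so callers no longer narrow a union. A hedged before/after sketch of the call-site effect; responsesClient and args are placeholders for the route's wrapper and its options object:

declare const responsesClient: { createResponse: (args: unknown) => Promise<unknown> }
declare const args: unknown

// Before: the union type forced every caller to discriminate the result
// const result = await responsesClient.createResponse(args)
// if (result instanceof APIError) { /* handle */ } // no longer needed

// After: failures travel through try/catch, and the happy path is direct
try {
  const events = await responsesClient.createResponse(args) // Stream<ResponseStreamEvent>
  // ...stream the events to the client
} catch (error) {
  // error handling is centralized here
}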