Skip to content

Commit 98fe6d9

Browse files
committed
improve v2 endpoint
1 parent 5d0f1ed commit 98fe6d9

File tree

8 files changed

+66
-146
lines changed

8 files changed

+66
-146
lines changed

src/server/routes/openai.ts

Lines changed: 42 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ const PostStreamSchemaV2 = z.object({
6363
openaiRouter.post('/stream/v2', upload.single('file'), async (r, res) => {
6464
const req = r as RequestWithUser
6565
const { options, courseId } = PostStreamSchemaV2.parse(JSON.parse(req.body.data))
66-
const { userConsent } = options
66+
const { userConsent, ragIndexId } = options
6767
const { user } = req
6868

6969
console.log('options', options)
@@ -73,8 +73,26 @@ openaiRouter.post('/stream/v2', upload.single('file'), async (r, res) => {
7373
return
7474
}
7575

76+
const course =
77+
courseId &&
78+
(await ChatInstance.findOne({
79+
where: { courseId },
80+
}))
81+
82+
if (courseId && !course) {
83+
res.status(404).send('Course not found')
84+
return
85+
}
86+
7687
// Check if the user has usage limits for the course or model
77-
const usageAllowed = courseId ? await checkCourseUsage(user, courseId) : options.model === FREE_MODEL || (await checkUsage(user, options.model))
88+
let usageAllowed = false
89+
if (courseId) {
90+
usageAllowed = await checkCourseUsage(user, courseId)
91+
} else if (options.model === FREE_MODEL) {
92+
usageAllowed = true
93+
} else {
94+
usageAllowed = await checkUsage(user, options.model)
95+
}
7896

7997
if (!usageAllowed) {
8098
res.status(403).send('Usage limit reached')
@@ -123,6 +141,7 @@ openaiRouter.post('/stream/v2', upload.single('file'), async (r, res) => {
123141
return
124142
}
125143

144+
// Check context limit
126145
const contextLimit = getModelContextLimit(options.model)
127146

128147
if (tokenCount > contextLimit) {
@@ -131,25 +150,40 @@ openaiRouter.post('/stream/v2', upload.single('file'), async (r, res) => {
131150
return
132151
}
133152

153+
// Check rag index
134154
let vectorStoreId: string | undefined = undefined
135-
if (courseId) {
136-
const ragIndex = await RagIndex.findOne({
137-
where: { courseId },
138-
})
155+
let instructions: string | undefined = undefined
156+
157+
if (ragIndexId) {
158+
const ragIndex = await RagIndex.findByPk(ragIndexId)
139159
if (ragIndex) {
160+
if (courseId && ragIndex.courseId !== courseId) {
161+
logger.error('RagIndex does not belong to the course', { ragIndexId, courseId })
162+
res.status(403).send('RagIndex does not belong to the course')
163+
return
164+
}
165+
140166
vectorStoreId = ragIndex.metadata.azureVectorStoreId
167+
instructions = ragIndex.metadata.instructions
168+
141169
console.log('using', ragIndex.toJSON())
170+
} else {
171+
logger.error('RagIndex not found', { ragIndexId })
172+
res.status(404).send('RagIndex not found')
173+
return
142174
}
143175
}
144176

145177
const responsesClient = new ResponsesClient({
146178
model: options.model,
147179
courseId,
148180
vectorStoreId,
181+
instructions,
149182
})
150183

151-
const latestMessage = options.messages[options.messages.length - 1] // Adhoc to input only the latest message
152-
const events = await responsesClient.createResponse({ input: [latestMessage], prevResponseId: options.prevResponseId })
184+
const latestMessage = options.messages[options.messages.length - 1]
185+
186+
const events = await responsesClient.createResponse({ input: [latestMessage], prevResponseId: options.prevResponseId, include: ragIndexId ? ['file_search_call.results'] : [] })
153187

154188
if (isError(events)) {
155189
res.status(424)
@@ -184,12 +218,6 @@ openaiRouter.post('/stream/v2', upload.single('file'), async (r, res) => {
184218
courseId,
185219
})
186220

187-
const course =
188-
courseId &&
189-
(await ChatInstance.findOne({
190-
where: { courseId },
191-
}))
192-
193221
const consentToSave = courseId && course.saveDiscussions && options.saveConsent
194222

195223
console.log('consentToSave', options.saveConsent, user.username)

src/server/routes/rag.ts

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ router.post('/indices', async (req, res) => {
4545
return
4646
}
4747

48-
const client = getAzureOpenAIClient('curredev4omini')
48+
const client = getAzureOpenAIClient()
4949
const vectorStore = await client.vectorStores.create({
5050
name,
5151
})
@@ -77,7 +77,7 @@ router.delete('/indices/:id', async (req, res) => {
7777
return
7878
}
7979

80-
const client = getAzureOpenAIClient('curredev4omini')
80+
const client = getAzureOpenAIClient()
8181
try {
8282
await client.vectorStores.del(ragIndex.metadata.azureVectorStoreId)
8383
} catch (error) {
@@ -123,7 +123,7 @@ router.get('/indices', async (req, res) => {
123123
})
124124

125125
if (includeExtras) {
126-
const client = getAzureOpenAIClient('curredev4omini')
126+
const client = getAzureOpenAIClient()
127127

128128
// Add ragFileCount to each index
129129
const indicesWithCount = await Promise.all(
@@ -159,7 +159,7 @@ router.get('/indices/:id', async (req, res) => {
159159
return
160160
}
161161

162-
const client = getAzureOpenAIClient('curredev4omini')
162+
const client = getAzureOpenAIClient()
163163
const vectorStore = await client.vectorStores.retrieve(ragIndex.metadata.azureVectorStoreId)
164164

165165
res.json({
@@ -265,7 +265,7 @@ router.post('/indices/:id/upload', [indexUploadDirMiddleware, uploadMiddleware],
265265
),
266266
)
267267

268-
const client = getAzureOpenAIClient('curredev4omini')
268+
const client = getAzureOpenAIClient()
269269

270270
const uploadDirPath = `${UPLOAD_DIR}/${id}`
271271

src/server/util/azure/ResponsesAPI.ts

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,20 @@
11
import { Tiktoken } from '@dqbd/tiktoken'
22
import { Response } from 'express'
3-
import { isError } from '../parser'
43

54
import { AZURE_RESOURCE, AZURE_API_KEY } from '../config'
6-
import { validModels, inProduction } from '../../../config'
5+
import { validModels } from '../../../config'
76
import logger from '../logger'
87

98
import { APIError } from '../../types'
109
import { AzureOpenAI } from 'openai'
1110

1211
// import { EventStream } from '@azure/openai'
1312
import { Stream } from 'openai/streaming'
14-
import { FileSearchTool, FunctionTool, ResponseInput, ResponseInputItem, ResponseStreamEvent, ResponseTextAnnotationDeltaEvent } from 'openai/resources/responses/responses'
13+
import { FileSearchTool, ResponseIncludable, ResponseInput, ResponseStreamEvent } from 'openai/resources/responses/responses'
1514

16-
import { courseAssistants } from './courseAssistants'
1715
import { createFileSearchTool } from './util'
1816

19-
import type { CourseAssistant, FileCitation, ResponseStreamEventData } from '../../../shared/types'
17+
import type { FileCitation, ResponseStreamEventData } from '../../../shared/types'
2018

2119
const endpoint = `https://${AZURE_RESOURCE}.openai.azure.com/`
2220

@@ -34,21 +32,12 @@ export class ResponsesClient {
3432
model: string
3533
instructions: string
3634
tools: FileSearchTool[]
37-
courseAssistant: CourseAssistant
3835

39-
constructor({ model, courseId, vectorStoreId }: { model: string; courseId?: string; vectorStoreId?: string }) {
36+
constructor({ model, courseId, vectorStoreId, instructions }: { model: string; courseId?: string; vectorStoreId?: string; instructions?: string }) {
4037
const deploymentId = validModels.find((m) => m.name === model)?.deployment
4138

4239
if (!deploymentId) throw new Error(`Invalid model: ${model}, not one of ${validModels.map((m) => m.name).join(', ')}`)
4340

44-
if (courseId) {
45-
this.courseAssistant = courseAssistants.find((assistant) => assistant.course_id === courseId)
46-
47-
if (!this.courseAssistant) throw new Error(`No course assistant found for course ID: ${courseId}`)
48-
} else {
49-
this.courseAssistant = courseAssistants.find((assistant) => assistant.name === 'default')
50-
}
51-
5241
const fileSearchTool = courseId
5342
? [
5443
createFileSearchTool({
@@ -58,11 +47,19 @@ export class ResponsesClient {
5847
: [] // needs to retrun empty array for null
5948

6049
this.model = deploymentId
61-
this.instructions = this.courseAssistant.assistant_instruction
50+
this.instructions = instructions
6251
this.tools = fileSearchTool
6352
}
6453

65-
async createResponse({ input, prevResponseId }: { input: ResponseInput; prevResponseId?: string }): Promise<Stream<ResponseStreamEvent> | APIError> {
54+
async createResponse({
55+
input,
56+
prevResponseId,
57+
include,
58+
}: {
59+
input: ResponseInput
60+
prevResponseId?: string
61+
include?: ResponseIncludable[]
62+
}): Promise<Stream<ResponseStreamEvent> | APIError> {
6663
try {
6764
return await client.responses.create({
6865
model: this.model,
@@ -73,6 +70,7 @@ export class ResponsesClient {
7370
tools: this.tools,
7471
tool_choice: 'auto',
7572
store: true,
73+
include,
7674
})
7775
} catch (error: any) {
7876
logger.error(error)

src/server/util/azure/client.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import { OpenAIClient, AzureKeyCredential, EventStream, ChatCompletions, GetEmbeddingsOptions } from '@azure/openai'
1+
import { OpenAIClient, AzureKeyCredential, EventStream, ChatCompletions } from '@azure/openai'
22
import { Tiktoken } from '@dqbd/tiktoken'
33
import { Response } from 'express'
44

@@ -12,7 +12,7 @@ const endpoint = `https://${AZURE_RESOURCE}.openai.azure.com/`
1212

1313
const oldClient = new OpenAIClient(endpoint, new AzureKeyCredential(AZURE_API_KEY))
1414

15-
export const getAzureOpenAIClient = (deployment: string) =>
15+
export const getAzureOpenAIClient = (deployment?: string) =>
1616
new AzureOpenAI({
1717
apiKey: AZURE_API_KEY,
1818
deployment,

src/server/util/azure/courseAssistants.ts

Lines changed: 0 additions & 19 deletions
This file was deleted.

src/server/util/azure/fileSearchTools.ts

Lines changed: 0 additions & 15 deletions
This file was deleted.

src/server/util/azure/functionTools.ts

Lines changed: 0 additions & 73 deletions
This file was deleted.

src/shared/types.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
export type RagIndexMetadata = {
22
name: string
3-
dim: number
3+
dim?: number
44
azureVectorStoreId: string
5+
instructions?: string
56
}
67

78
export type RagFileAttributes = {

0 commit comments

Comments
 (0)