
Commit 0d96a29

update tool call function

1 parent 1f6e587 · commit 0d96a29

3 files changed: +170 -96 lines

src/server/routes/openai.ts

Lines changed: 19 additions & 14 deletions

@@ -5,7 +5,7 @@ import { CourseChatRequest, RequestWithUser } from '../types'
 import { isError } from '../util/parser'
 import { calculateUsage, incrementUsage, checkUsage, checkCourseUsage, incrementCourseUsage } from '../services/chatInstances/usage'
 import { getCompletionEvents, streamCompletion } from '../util/azure/client'
-import { streamResponsesEvents, getResponsesEvents } from '../util/azure/clientV2'
+import { ResponsesClient } from '../util/azure/clientV2'
 import { getMessageContext, getModelContextLimit, getCourseModel, getAllowedModels } from '../util/util'
 import getEncoding from '../util/tiktoken'
 import logger from '../util/logger'
@@ -100,17 +100,14 @@ openaiRouter.post('/stream/:version?', upload.single('file'), async (r, res) =>
     return
   }
 
+  const responsesClient = new ResponsesClient(model)
+
   let events
   if (version === 'v2') {
-    events = await getResponsesEvents({
-      model: options.model,
-      input: options.messages,
-      stream: options.stream,
-    })
+    events = await responsesClient.createResponse({ input: options.messages })
   } else {
     events = await getCompletionEvents(options)
   }
-
   if (isError(events)) {
     res.status(424)
     return
@@ -120,7 +117,12 @@ openaiRouter.post('/stream/:version?', upload.single('file'), async (r, res) =>
 
   let completion
   if (version === 'v2') {
-    completion = await streamResponsesEvents(events, encoding, res)
+    completion = await responsesClient.handleResponse({
+      events,
+      prevMessages: options.messages,
+      encoding,
+      res,
+    })
   } else {
     completion = await streamCompletion(events, options, encoding, res)
   }
@@ -214,13 +216,11 @@ openaiRouter.post('/stream/:courseId/:version?', upload.single('file'), async (r
     return
   }
 
+  const responsesClient = new ResponsesClient(model)
+
   let events
   if (version === 'v2') {
-    events = await getResponsesEvents({
-      model: options.model,
-      input: options.messages,
-      stream: options.stream,
-    })
+    events = await responsesClient.createResponse({ input: options.messages })
   } else {
     events = await getCompletionEvents(options)
   }
@@ -234,7 +234,12 @@ openaiRouter.post('/stream/:courseId/:version?', upload.single('file'), async (r
 
   let completion
   if (version === 'v2') {
-    completion = await streamResponsesEvents(events, encoding, res)
+    completion = await responsesClient.handleResponse({
+      events,
+      prevMessages: options.messages,
+      encoding,
+      res,
+    })
   } else {
     completion = await streamCompletion(events, options, encoding, res)
   }
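Taken together, the v2 branch of both routes now follows the same three-step flow. A minimal sketch of that flow in isolation, assuming `model`, `options`, `encoding`, and `res` are in scope exactly as in the handlers above:

// Per-request client; the constructor throws if `model` has no deployment.
const responsesClient = new ResponsesClient(model)

// createResponse returns a Stream<ResponseStreamEvent> or an APIError.
const events = await responsesClient.createResponse({ input: options.messages })
if (isError(events)) {
  res.status(424)
  return
}

// handleResponse streams text deltas to `res` and returns
// { tokenCount, response } for usage accounting.
const completion = await responsesClient.handleResponse({
  events,
  prevMessages: options.messages,
  encoding,
  res,
})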

src/server/util/azure/clientV2.ts

Lines changed: 136 additions & 79 deletions

@@ -1,5 +1,6 @@
 import { Tiktoken } from '@dqbd/tiktoken'
 import { Response } from 'express'
+import { isError } from '../../util/parser'
 
 import { AZURE_RESOURCE, AZURE_API_KEY } from '../config'
 import { validModels, inProduction } from '../../../config'
@@ -9,7 +10,14 @@ import { APIError } from '../../types'
 import { AzureOpenAI } from 'openai'
 // import { EventStream } from '@azure/openai'
 import { Stream } from 'openai/streaming'
-import { ResponseStreamEvent } from 'openai/resources/responses/responses'
+import {
+  FunctionTool,
+  ResponseInput,
+  ResponseInputItem,
+  ResponseStreamEvent,
+} from 'openai/resources/responses/responses'
+
+import { testTool } from './tools'
 
 const endpoint = `https://${AZURE_RESOURCE}.openai.azure.com/`
 
@@ -23,90 +31,139 @@ export const getAzureOpenAIClient = (deployment: string) =>
 
 const client = getAzureOpenAIClient(process.env.GPT_4O)
 
-/**
- * Mock stream for testing
- */
-// const getMockCompletionEvents: () => Promise<
-//   EventStream<ResponseStreamEvent>
-// > = async () => {
-//   const mockStream = new ReadableStream<ResponseStreamEvent>({
-//     start(controller) {
-//       for (let i = 0; i < 10; i += 1) {
-//         controller.enqueue({
-//           event: "response",
-//           data: ""
-//         })
-//       }
-//       controller.close()
-//     },
-//   }) as EventStream<ResponseStreamEvent>
-
-//   return mockStream
-// }
-
-export const getResponsesEvents = async ({
-  model,
-  input,
-  stream,
-}: any): Promise<
-  | Stream<ResponseStreamEvent>
-  // EventStream<ChatCompletionChunk>
-  | APIError
-  | any
-> => {
-  const deploymentId = validModels.find((m) => m.name === model)?.deployment
-
-  if (!deploymentId) throw new Error(`Invalid model: ${model}, not one of ${validModels.map((m) => m.name).join(', ')}`)
-
-  // Mocking disabled because it's difficult to mock a event stream for responses API.
-  // if (deploymentId === 'mock') return getMockCompletionEvents()
-
-  try {
-    const events = await client.responses.create({
-      model: deploymentId,
-      instructions: 'Olet avulias apuri.',
-      input,
-      stream,
-      tools: [],
-    })
+export class ResponsesClient {
+  model: string
+  instructions: string
+  tools: FunctionTool[]
 
-    return events
-  } catch (error: any) {
-    logger.error(error)
+  constructor(model: string, instructions?: string) {
+    const deploymentId = validModels.find((m) => m.name === model)?.deployment
 
-    return { error } as any as APIError
+    if (!deploymentId)
+      throw new Error(
+        `Invalid model: ${model}, not one of ${validModels.map((m) => m.name).join(', ')}`
+      )
+
+    this.model = deploymentId
+    this.instructions = instructions || 'Olet avulias apuri.'
+    this.tools = [testTool.definition]
   }
-}
 
-export const streamResponsesEvents = async (events: Stream<ResponseStreamEvent>, encoding: Tiktoken, res: Response) => {
-  let tokenCount = 0
-  const contents = []
-
-  for await (const event of events) {
-    switch (event.type) {
-      case 'response.output_text.delta':
-        if (!inProduction) logger.info(event.delta)
-
-        await new Promise((resolve) => {
-          if (
-            !res.write(event.delta, (err) => {
-              if (err) logger.error(`${event.delta} ${err}`)
-            })
-          ) {
-            logger.info(`${event.delta} res.write returned false, waiting for drain`)
-            res.once('drain', resolve)
-          } else {
-            process.nextTick(resolve)
-          }
-        })
-        contents.push(event.delta)
-        tokenCount += encoding.encode(event.delta).length ?? 0
-        break
+  async createResponse({
+    input,
+  }: {
+    input: ResponseInput
+  }): Promise<Stream<ResponseStreamEvent> | APIError> {
+    try {
+      return await client.responses.create({
+        model: this.model,
+        instructions: this.instructions,
+        input,
+        stream: true,
+        tools: this.tools,
+      })
+    } catch (error: any) {
+      logger.error(error)
+
+      return { error } as any as APIError
+    }
+  }
+
+  async handleResponse({
+    events,
+    prevMessages,
+    encoding,
+    res,
+  }: {
+    events: Stream<any>
+    prevMessages: ResponseInput
+    encoding: Tiktoken
+    res: Response
+  }) {
+    let tokenCount = 0
+    const contents = []
+
+    for await (const event of events) {
+      console.log('event type:', event.type)
+
+      switch (event.type) {
+        case 'response.output_text.delta':
+          await this.writeDelta(event.delta, res)
+
+          contents.push(event.delta)
+          tokenCount += encoding.encode(event.delta).length ?? 0
+          break
+
+        case 'response.function_call_arguments.done':
+          // WORK IN PROGRESS

+          // const augRetrieval = await this.callToolFunction(
+          //   event.arguments,
+          //   event.call_id
+          // )
+          // const newEvents = await this.createResponse({
+          //   input: [...prevMessages, augRetrieval],
+          // })
+
+          // if (isError(events)) {
+          //   throw new Error(`Error creating response from function call`)
+          // }
+
+          // await this.handleResponse({
+          //   events: newEvents as Stream<ResponseStreamEvent>,
+          //   prevMessages: [...prevMessages, augRetrieval],
+          //   encoding,
+          //   res,
+          // })
+          break
+      }
     }
+
+    return {
+      tokenCount,
+      response: contents.join(''),
+    }
+  }
+
+  private async writeDelta(text: string, res: Response) {
+    // if (!inProduction) logger.info(text)
+
+    await new Promise((resolve) => {
+      if (
+        !res.write(text, (err) => {
+          if (err) logger.error(`${text} ${err}`)
+        })
+      ) {
+        logger.info(`${text} res.write returned false, waiting for drain`)
+        res.once('drain', resolve)
+      } else {
+        process.nextTick(resolve)
+      }
+    })
   }
 
-  return {
-    tokenCount,
-    response: contents.join(''),
+  private async callToolFunction(
+    args: string,
+    callId: string
+  ): Promise<ResponseInputItem[]> {
+    const { query } = JSON.parse(args)
+    try {
+      const retrieval = await testTool.function(query)
+
+      return [
+        {
+          role: 'user',
+          content: retrieval.query,
+        },
+        {
+          type: 'function_call_output',
+          call_id: callId,
+          output: retrieval.result,
+        },
+      ]
+    } catch (error) {
+      logger.error('Error calling tool function:', error)
+      return null
+    }
   }
 }
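The `response.function_call_arguments.done` branch is still commented out. If it were enabled as written, it would look roughly like the sketch below; the sketch also corrects two apparent slips in the WIP comments: the error check should inspect `newEvents` rather than `events`, and `callToolFunction` returns a `ResponseInputItem[]`, so it should be spread into the input array rather than pushed as a single element.

// Hypothetical uncommented form of the WIP branch above (not part of the commit).
case 'response.function_call_arguments.done': {
  // Run the tool and convert its output into Responses API input items.
  const augRetrieval = await this.callToolFunction(event.arguments, event.call_id)

  // Continue the conversation with the tool output appended to the context.
  const newEvents = await this.createResponse({
    input: [...prevMessages, ...augRetrieval],
  })

  if (isError(newEvents)) {
    throw new Error('Error creating response from function call')
  }

  // Recurse so the follow-up stream is written to the same HTTP response.
  await this.handleResponse({
    events: newEvents as Stream<ResponseStreamEvent>,
    prevMessages: [...prevMessages, ...augRetrieval],
    encoding,
    res,
  })
  break
}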

src/server/util/azure/tools.ts

Lines changed: 15 additions & 3 deletions

@@ -1,4 +1,11 @@
-export const testTool = {
+import { FunctionTool } from 'openai/resources/responses/responses'
+
+interface ToolObject {
+  definition: FunctionTool
+  function: (query: string) => Promise<any>
+}
+
+export const testTool: ToolObject = {
   definition: {
     type: 'function',
     name: 'test_knowledge_retrieval',
@@ -13,13 +20,18 @@ export const testTool = {
         },
       },
       required: ['query'],
+      additionalProperties: false,
     },
+    strict: true, // enforce the parameter schema exactly
   },
-  function: async (query: string) => {
+  function: async (
+    query: string
+  ): Promise<{ query: string; result: string }> => {
     // Simulate a tool function that returns a simple message
     return {
       query,
-      message: 'This is a test tool function',
+      result:
+        'This is a test result from the test tool. The secret is: Chili kastike',
     }
   },
 }
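Since `definition` and `function` travel as a pair, the tool can be smoke-tested on its own. A small illustrative snippet, not part of the commit:

import { testTool } from './tools'

const smokeTest = async () => {
  // The FunctionTool schema is what ResponsesClient registers with the model.
  console.log(testTool.definition.name) // 'test_knowledge_retrieval'

  // The paired function is what callToolFunction invokes with the parsed query.
  const retrieval = await testTool.function('What is the secret?')
  console.log(retrieval.result) // '...The secret is: Chili kastike'
}

smokeTest()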
