
Commit e7143bf

Refactor function calling and add authHandler middleware to /chat endpoint and updateUsername endpoint
1 parent 7acb45d commit e7143bf

File tree (5 files changed, +56 -55 lines):

course-matrix/backend/src/constants/availableFunctions.ts
course-matrix/backend/src/controllers/aiController.ts
course-matrix/backend/src/routes/aiRouter.ts
course-matrix/backend/src/routes/authRouter.ts
course-matrix/frontend/src/pages/Assistant/runtime-provider.tsx

course-matrix/backend/src/constants/availableFunctions.ts

Lines changed: 3 additions & 1 deletion

@@ -22,6 +22,7 @@ export const availableFunctions: AvailableFunctions = {
         .eq("user_id", user_id);
       const { data: timetableData, error: timetableError } =
         await timeTableQuery;
+      // console.log("Timetables: ", timetableData)
 
       if (timetableError) return { status: 400, error: timetableError.message };
 
@@ -35,7 +36,8 @@ export const availableFunctions: AvailableFunctions = {
 
       return { status: 200, data: timetableData };
     } catch (error) {
+      console.log(error)
       return { status: 400, error: error };
     }
   },
-};
+};

course-matrix/backend/src/controllers/aiController.ts

Lines changed: 48 additions & 51 deletions

@@ -1,7 +1,7 @@
 import asyncHandler from "../middleware/asyncHandler";
 import { Request, Response } from "express";
 import { createOpenAI } from "@ai-sdk/openai";
-import { CoreMessage, streamText } from "ai";
+import { CoreMessage, generateObject, InvalidToolArgumentsError, NoSuchToolError, streamText, tool, ToolExecutionError } from "ai";
 import { Index, Pinecone, RecordMetadata } from "@pinecone-database/pinecone";
 import { PineconeStore } from "@langchain/pinecone";
 import { OpenAIEmbeddings } from "@langchain/openai";
@@ -27,6 +27,7 @@ import {
   availableFunctions,
   FunctionNames,
 } from "../constants/availableFunctions";
+import { z } from "zod";
 
 const openai = createOpenAI({
   baseURL: process.env.OPENAI_BASE_URL,
@@ -231,6 +232,8 @@ async function reformulateQuery(
     content: latestQuery,
   });
 
+  console.log(messages)
+
   const response = await openai2.chat.completions.create({
     model: "gpt-4o-mini",
     messages: messages,
@@ -319,55 +322,6 @@ export const chat = asyncHandler(async (req: Request, res: Response) => {
   if (latestMessage.startsWith(CHATBOT_TIMETABLE_CMD)) {
     // ----- Flow 1 - Agent performs action on timetable -----
 
-    const openai2 = new OpenAI({
-      apiKey: process.env.OPENAI_API_KEY,
-    });
-
-    // Call with function definitions
-    const response = await openai2.chat.completions.create({
-      model: "gpt-4o-mini",
-      messages,
-      functions: [
-        {
-          name: "getTimetables",
-          description:
-            "Get all the timetables of the currently logged in user.",
-          parameters: {
-            type: "object",
-            properties: {},
-            required: [],
-          },
-        },
-      ],
-      function_call: "auto",
-    });
-
-    let responseMessage = response.choices[0].message;
-
-    console.log(responseMessage);
-
-    // Check if the model wants to call a tool/function
-    if (responseMessage.function_call) {
-      // Add the assistant's message with tool calls to conversation
-      // messages.push(responseMessage);
-
-      // Process the tool call
-      const toolCall = responseMessage.function_call;
-      const functionName = toolCall.name as FunctionNames;
-      const functionToCall = availableFunctions[functionName];
-      const functionArgs = JSON.parse(toolCall.arguments);
-
-      if (functionToCall) {
-        // Execute the function
-        const functionResponse = await functionToCall(functionArgs, req);
-
-        // Add the tool result to the conversation
-        messages.push({
-          role: "tool",
-          content: JSON.stringify(functionResponse),
-        });
-      }
-
     // Get a new response from the model with all the tool responses
     const result = streamText({
       model: openai("gpt-4o-mini"),
@@ -397,10 +351,53 @@ export const chat = asyncHandler(async (req: Request, res: Response) => {
       - For unrelated questions, politely explain that you're specialized in UTSC academic information
       `,
       messages,
+      tools: {
+        getTimetables: tool({
+          description: "Get all the timetables of the currently logged in user.",
+          parameters: z.object({}),
+          execute: async (args) => {
+            return await availableFunctions.getTimetables(args, req);
+          }
+        })
+      },
+      maxSteps: 3, // Controls how many back and forths the model can take with user or calling multiple tools
+      experimental_repairToolCall: async ({
+        toolCall,
+        tools,
+        parameterSchema,
+        error,
+      }) => {
+        if (NoSuchToolError.isInstance(error)) {
+          return null; // do not attempt to fix invalid tool names
+        }
+
+        const tool = tools[toolCall.toolName as keyof typeof tools];
+        console.log(`The model tried to call the tool "${toolCall.toolName}"` +
+          ` with the following arguments:`,
+          JSON.stringify(toolCall.args),
+          `The tool accepts the following schema:`,
+          JSON.stringify(parameterSchema(toolCall)),
+          'Please fix the arguments.')
+
+        const { object: repairedArgs } = await generateObject({
+          model: openai('gpt-4o', { structuredOutputs: true }),
+          schema: tool.parameters,
+          prompt: [
+            `The model tried to call the tool "${toolCall.toolName}"` +
+              ` with the following arguments:`,
+            JSON.stringify(toolCall.args),
+            `The tool accepts the following schema:`,
+            JSON.stringify(parameterSchema(toolCall)),
+            'Please fix the arguments.',
+          ].join('\n'),
+        });
+
+        return { ...toolCall, args: JSON.stringify(repairedArgs) };
+      },
     });
 
     result.pipeDataStreamToResponse(res);
-    }
+
   } else {
     // ----- Flow 2 - Answer query -----
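Note on the change above: the hand-rolled OpenAI function_call branch is replaced by the AI SDK's built-in tool-calling loop (streamText, tool, maxSteps, plus experimental_repairToolCall as a fallback for malformed arguments). The following is only a minimal standalone sketch of that pattern using the APIs already imported in the diff; the prompt, model name, and the stubbed execute result are illustrative, not part of the commit.

import { CoreMessage, streamText, tool } from "ai";
import { createOpenAI } from "@ai-sdk/openai";
import { z } from "zod";

const openai = createOpenAI({ apiKey: process.env.OPENAI_API_KEY });

// streamText drives the tool-call loop itself: when the model decides to call
// getTimetables, execute() runs, its return value is fed back to the model,
// and maxSteps caps how many model/tool round trips are allowed.
const messages: CoreMessage[] = [
  { role: "user", content: "Which timetables do I have?" },
];

const result = streamText({
  model: openai("gpt-4o-mini"),
  messages,
  tools: {
    getTimetables: tool({
      description: "Get all the timetables of the currently logged in user.",
      parameters: z.object({}), // the tool takes no arguments
      execute: async () => ({ status: 200, data: [] }), // stub result, for illustration only
    }),
  },
  maxSteps: 3,
});

// In the controller the resulting stream is then piped into the Express response:
// result.pipeDataStreamToResponse(res);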

course-matrix/backend/src/routes/aiRouter.ts

Lines changed: 2 additions & 2 deletions

@@ -1,14 +1,14 @@
 import express from "express";
 import { chat, testSimilaritySearch } from "../controllers/aiController";
-import { authRouter } from "./authRouter";
+import { authHandler } from "../middleware/authHandler";
 
 export const aiRouter = express.Router();
 
 /**
  * @route POST /api/ai/chat
  * @description Handles user queries and generates responses using GPT-4o, with optional knowledge retrieval.
  */
-aiRouter.post("/chat", authRouter, chat);
+aiRouter.post("/chat", authHandler, chat);
 /**
  * @route POST /api/ai/test-similarity-search
  * @description Test vector database similarity search feature
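The authHandler middleware itself is not part of this commit (only the import changes, from the mistakenly used authRouter to the middleware), so the following is a hypothetical sketch of what a cookie-based guard of that shape commonly looks like; the cookie name and response body are assumptions, not this repo's code.

import { NextFunction, Request, Response } from "express";

// Hypothetical sketch only: reject requests that carry no session cookie
// before the controller runs, otherwise hand off to the next handler.
export const authHandler = (req: Request, res: Response, next: NextFunction) => {
  const token = req.cookies?.["access_token"]; // assumed cookie name; requires cookie-parser
  if (!token) {
    return res.status(401).json({ error: "Not authenticated" });
  }
  // ...verify the token and attach the user to the request here...
  next();
};

Applying it per-route, as this commit does for /chat and /updateUsername, keeps public routes such as login and signup untouched.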

course-matrix/backend/src/routes/authRouter.ts

Lines changed: 2 additions & 1 deletion

@@ -11,6 +11,7 @@ import {
   accountDelete,
   updateUsername,
 } from "../controllers/userController";
+import { authHandler } from "../middleware/authHandler";
 
 export const authRouter = express.Router();
 
@@ -66,4 +67,4 @@ authRouter.delete("/accountDelete", accountDelete);
  * Route to request to update username
  * @route POST /updateUsername
  */
-authRouter.post("/updateUsername", updateUsername);
+authRouter.post("/updateUsername", authHandler, updateUsername);

course-matrix/frontend/src/pages/Assistant/runtime-provider.tsx

Lines changed: 1 addition & 0 deletions

@@ -61,6 +61,7 @@ export function RuntimeProvider({
   const runtime = useChatRuntime({
     cloud,
     api: `${SERVER_URL}/api/ai/chat`,
+    credentials: "include"
   });
 
   const contextValue = {
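credentials: "include" makes the browser attach cookies to the /api/ai/chat fetch, which is what lets the new authHandler on that route see the session. If the frontend and backend run on different origins, the backend's CORS setup also has to opt in. A sketch assuming the common cors middleware is (or would be) used; FRONTEND_URL is an assumed env var, not from this commit.

import cors from "cors";
import express from "express";

const app = express();

// With credentials: "include" on the client, the server must name a concrete
// origin and set credentials: true, otherwise the browser drops the cookies.
app.use(
  cors({
    origin: process.env.FRONTEND_URL, // assumed env var; cannot be "*" when credentials are allowed
    credentials: true,
  }),
);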
