Skip to content

Commit cd87c8b

Browse files
committed
Add generateTimetable and createTimetable function call
1 parent 8fd91ed commit cd87c8b

File tree

4 files changed

+288
-53
lines changed

4 files changed

+288
-53
lines changed

course-matrix/backend/src/constants/availableFunctions.ts

Lines changed: 141 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,15 @@
1+
import {
2+
categorizeValidOfferings,
3+
getMaxDays,
4+
getOfferings,
5+
getValidOfferings,
6+
getValidSchedules,
7+
GroupedOfferingList,
8+
groupOfferings,
9+
Offering,
10+
OfferingList,
11+
trim,
12+
} from "../controllers/generatorController";
113
import { supabase } from "../db/setupDb";
214
import { Request } from "express";
315

@@ -6,6 +18,7 @@ export type FunctionNames =
618
| "getTimetables"
719
| "updateTimetable"
820
| "deleteTimetable"
21+
| "createTimetable"
922
| "generateTimetable";
1023

1124
type AvailableFunctions = {
@@ -169,7 +182,133 @@ export const availableFunctions: AvailableFunctions = {
169182
return { status: 500, error: error };
170183
}
171184
},
185+
createTimetable: async (args: any, req: Request) => {
186+
try {
187+
//Get user id from session authentication to insert in the user_id col
188+
const user_id = (req as any).user.id;
189+
190+
//Retrieve timetable title
191+
const { timetable_title, semester, favorite = false } = args;
192+
if (!timetable_title || !semester) {
193+
return {
194+
status: 400,
195+
error: "timetable title and semester are required",
196+
};
197+
}
198+
199+
// Check if a timetable with the same title already exist for this user
200+
const { data: existingTimetable, error: existingTimetableError } =
201+
await supabase
202+
.schema("timetable")
203+
.from("timetables")
204+
.select("id")
205+
.eq("user_id", user_id)
206+
.eq("timetable_title", timetable_title)
207+
.maybeSingle();
208+
209+
if (existingTimetableError) {
210+
return { status: 400, error: existingTimetableError.message };
211+
}
212+
213+
if (existingTimetable) {
214+
return {
215+
status: 400,
216+
error: "A timetable with this title already exists",
217+
};
218+
}
219+
220+
//Create query to insert the user_id and timetable_title into the db
221+
let insertTimetable = supabase
222+
.schema("timetable")
223+
.from("timetables")
224+
.insert([
225+
{
226+
user_id,
227+
timetable_title,
228+
semester,
229+
favorite,
230+
},
231+
])
232+
.select()
233+
.single();
234+
235+
const { data: timetableData, error: timetableError } =
236+
await insertTimetable;
237+
238+
if (timetableError) {
239+
return { status: 400, error: timetableError.message };
240+
}
241+
242+
return { status: 201, data: timetableData };
243+
} catch (error) {
244+
return { status: 500, error };
245+
}
246+
},
172247
generateTimetable: async (args: any, req: Request) => {
173-
174-
}
248+
try {
249+
// Extract event details and course information from the request
250+
const { name, date, semester, search, courses, restrictions } = args;
251+
const courseOfferingsList: OfferingList[] = [];
252+
const validCourseOfferingsList: GroupedOfferingList[] = [];
253+
const maxdays = await getMaxDays(restrictions);
254+
const validSchedules: Offering[][] = [];
255+
// Fetch offerings for each course
256+
for (const course of courses) {
257+
const { id } = course;
258+
courseOfferingsList.push({
259+
course_id: id,
260+
offerings: (await getOfferings(id, semester)) ?? [],
261+
});
262+
}
263+
264+
const groupedOfferingsList: GroupedOfferingList[] = await groupOfferings(
265+
courseOfferingsList
266+
);
267+
268+
// console.log(JSON.stringify(groupedOfferingsList, null, 2));
269+
270+
// Filter out invalid offerings based on the restrictions
271+
for (const { course_id, groups } of groupedOfferingsList) {
272+
validCourseOfferingsList.push({
273+
course_id: course_id,
274+
groups: await getValidOfferings(groups, restrictions),
275+
});
276+
}
277+
278+
const categorizedOfferings = await categorizeValidOfferings(
279+
validCourseOfferingsList
280+
);
281+
282+
// console.log(typeof categorizedOfferings);
283+
// console.log(JSON.stringify(categorizedOfferings, null, 2));
284+
285+
// Generate valid schedules for the given courses and restrictions
286+
await getValidSchedules(
287+
validSchedules,
288+
categorizedOfferings,
289+
[],
290+
0,
291+
categorizedOfferings.length,
292+
maxdays
293+
);
294+
295+
// Return error if no valid schedules are found
296+
if (validSchedules.length === 0) {
297+
return { status: 404, error: "No valid schedules found." };
298+
}
299+
300+
// MODIFIED FOR TOOL CALL: Return single schedule
301+
return {
302+
status: 200,
303+
data: {
304+
schedule: trim(validSchedules)[0],
305+
},
306+
};
307+
} catch (error) {
308+
// Catch any error and return the error message
309+
const errorMessage =
310+
error instanceof Error ? error.message : "An unknown error occurred";
311+
return { status: 500, error: errorMessage };
312+
}
313+
},
175314
};

course-matrix/backend/src/controllers/aiController.ts

Lines changed: 34 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ import { z } from "zod";
4040
import { analyzeQuery } from "../utils/analyzeQuery";
4141
import { includeFilters } from "../utils/includeFilters";
4242
import { TimetableFormSchema } from "../models/timetable-form";
43+
import { CreateTimetableArgs } from "../models/timetable-generate";
4344

4445
const openai = createOpenAI({
4546
baseURL: process.env.OPENAI_BASE_URL,
@@ -55,7 +56,7 @@ const pinecone = new Pinecone({
5556
});
5657

5758
const index: Index<RecordMetadata> = pinecone.Index(
58-
process.env.PINECONE_INDEX_NAME!,
59+
process.env.PINECONE_INDEX_NAME!
5960
);
6061

6162
console.log("Connected to OpenAI API");
@@ -64,7 +65,7 @@ export async function searchSelectedNamespaces(
6465
query: string,
6566
k: number,
6667
namespaces: string[],
67-
filters?: Object,
68+
filters?: Object
6869
): Promise<Document[]> {
6970
let allResults: Document[] = [];
7071

@@ -86,7 +87,7 @@ export async function searchSelectedNamespaces(
8687
const results = await namespaceStore.similaritySearch(
8788
query,
8889
Math.max(k, namespaceToMinResults.get(namespace)),
89-
namespace === "courses_v3" ? filters : undefined,
90+
namespace === "courses_v3" ? filters : undefined
9091
);
9192
console.log(`Found ${results.length} results in namespace: ${namespace}`);
9293
allResults = [...allResults, ...results];
@@ -107,7 +108,7 @@ export async function searchSelectedNamespaces(
107108
// Reformulate user query to make more concise query to database, taking into consideration context
108109
export async function reformulateQuery(
109110
latestQuery: string,
110-
conversationHistory: any[],
111+
conversationHistory: any[]
111112
): Promise<string> {
112113
try {
113114
const openai2 = new OpenAI({
@@ -248,7 +249,8 @@ export const chat = asyncHandler(async (req: Request, res: Response) => {
248249
249250
## Tool call guidelines
250251
- Include the timetable ID in all getTimetbles tool call responses
251-
- If the tool call is a getTimetables call, then at the end of each timetable listed, include a link displayed as "View timetable" to ${process.env.CLIENT_APP_URL}/dashboard/timetable?edit=[[TIMETABLE_ID]] , where TIMETABLE_ID is the id of the respective timetable.
252+
- For every tool call, for each timetable that it gets/deletes/modifies/creates, include a link underneath it displayed as "View timetable" to ${process.env.CLIENT_APP_URL}/dashboard/timetable?edit=[[TIMETABLE_ID]] , where TIMETABLE_ID is the id of the respective timetable.
253+
- If the user provides a course code of length 6 like CSCA08, then assume they mean CSCA08H3 (H3 appended)
252254
`,
253255
messages,
254256
tools: {
@@ -282,13 +284,23 @@ export const chat = asyncHandler(async (req: Request, res: Response) => {
282284
return await availableFunctions.deleteTimetable(args, req);
283285
},
284286
}),
287+
createTimetable: tool({
288+
description:
289+
"Create a timetable with the provided meeting sections",
290+
parameters: CreateTimetableArgs,
291+
execute: async (args) => {
292+
return await availableFunctions.createTimetable(args, req);
293+
},
294+
}),
285295
generateTimetable: tool({
286-
description: "Generate a timetable based on selected courses and restrictions",
296+
description:
297+
"Return a list of possible timetables based on provided courses and restrictions",
287298
parameters: TimetableFormSchema,
288299
execute: async (args) => {
289-
return await availableFunctions.generateTimetable(args, req)
290-
}
291-
})
300+
console.log("Args: ", args);
301+
return await availableFunctions.generateTimetable(args, req);
302+
},
303+
}),
292304
},
293305
maxSteps: CHATBOT_TOOL_CALL_MAX_STEPS, // Controls how many back and forths the model can take with user or calling multiple tools
294306
experimental_repairToolCall: async ({
@@ -305,10 +317,10 @@ export const chat = asyncHandler(async (req: Request, res: Response) => {
305317
console.log(
306318
`The model tried to call the tool "${toolCall.toolName}"` +
307319
` with the following arguments:`,
308-
JSON.stringify(toolCall.args),
320+
toolCall.args,
309321
`The tool accepts the following schema:`,
310-
JSON.stringify(parameterSchema(toolCall)),
311-
"Please fix the arguments.",
322+
parameterSchema(toolCall),
323+
"Please fix the arguments."
312324
);
313325

314326
const { object: repairedArgs } = await generateObject({
@@ -343,7 +355,7 @@ export const chat = asyncHandler(async (req: Request, res: Response) => {
343355
// Use GPT-4o to reformulate the query based on conversation history
344356
const reformulatedQuery = await reformulateQuery(
345357
latestMessage,
346-
conversationHistory.slice(-CHATBOT_MEMORY_THRESHOLD), // last K messages
358+
conversationHistory.slice(-CHATBOT_MEMORY_THRESHOLD) // last K messages
347359
);
348360
console.log(">>>> Original query:", latestMessage);
349361
console.log(">>>> Reformulated query:", reformulatedQuery);
@@ -357,8 +369,8 @@ export const chat = asyncHandler(async (req: Request, res: Response) => {
357369
if (requiresSearch) {
358370
console.log(
359371
`Query requires knowledge retrieval, searching namespaces: ${relevantNamespaces.join(
360-
", ",
361-
)}`,
372+
", "
373+
)}`
362374
);
363375

364376
const filters = includeFilters(reformulatedQuery);
@@ -369,7 +381,7 @@ export const chat = asyncHandler(async (req: Request, res: Response) => {
369381
reformulatedQuery,
370382
3,
371383
relevantNamespaces,
372-
Object.keys(filters).length === 0 ? undefined : filters,
384+
Object.keys(filters).length === 0 ? undefined : filters
373385
);
374386
// console.log("Search Results: ", searchResults);
375387

@@ -379,7 +391,7 @@ export const chat = asyncHandler(async (req: Request, res: Response) => {
379391
}
380392
} else {
381393
console.log(
382-
"Query does not require knowledge retrieval, skipping search",
394+
"Query does not require knowledge retrieval, skipping search"
383395
);
384396
}
385397

@@ -444,15 +456,15 @@ export const testSimilaritySearch = asyncHandler(
444456
if (requiresSearch) {
445457
console.log(
446458
`Query requires knowledge retrieval, searching namespaces: ${relevantNamespaces.join(
447-
", ",
448-
)}`,
459+
", "
460+
)}`
449461
);
450462

451463
// Search only the relevant namespaces
452464
const searchResults = await searchSelectedNamespaces(
453465
message,
454466
3,
455-
relevantNamespaces,
467+
relevantNamespaces
456468
);
457469
console.log("Search Results: ", searchResults);
458470

@@ -462,11 +474,11 @@ export const testSimilaritySearch = asyncHandler(
462474
}
463475
} else {
464476
console.log(
465-
"Query does not require knowledge retrieval, skipping search",
477+
"Query does not require knowledge retrieval, skipping search"
466478
);
467479
}
468480

469481
console.log("CONTEXT: ", context);
470482
res.status(200).send(context);
471-
},
483+
}
472484
);

0 commit comments

Comments
 (0)