Skip to content

Commit 4cb5b39

Browse files
committed
Improved querying of year level and breadth requirement from vector db
1 parent 8de4414 commit 4cb5b39

File tree

7 files changed

+182
-19
lines changed

7 files changed

+182
-19
lines changed

course-matrix/backend/src/constants/constants.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ export const yearToCode = (year: number) => {
3939

4040
// Set minimum results wanted for a similarity search on the associated namespace.
4141
export const namespaceToMinResults = new Map();
42-
namespaceToMinResults.set("courses_v2", 10);
42+
namespaceToMinResults.set("courses_v3", 10);
4343
namespaceToMinResults.set("offerings", 16); // Typically, more offering info is wanted.
4444
namespaceToMinResults.set("prerequisites", 5);
4545
namespaceToMinResults.set("corequisites", 5);

course-matrix/backend/src/constants/promptKeywords.ts

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
// Keywords related to each namespace
22
export const NAMESPACE_KEYWORDS = {
3-
courses_v2: [
3+
courses_v3: [
44
"course",
55
"class",
66
"description",
@@ -61,6 +61,41 @@ export const NAMESPACE_KEYWORDS = {
6161
programs: ["program", "major", "minor", "specialist", "degree", "stream"],
6262
};
6363

64+
// Phrases that signal a breadth requirement in a user query.
// Each key is a breadth requirement code (includeFilters passes it to
// convertBreadthRequirement); each value lists the phrases searched for in
// the query text.
// Several entries contain uppercase characters, so any consumer must compare
// case-insensitively on both sides.
export const BREADTH_REQUIREMENT_KEYWORDS = {
65+
ART_LIT_LANG: [
66+
"ART_LIT_LANG",
67+
"art literature",
68+
"arts literature",
69+
"art language",
70+
"arts language",
71+
"literature language",
72+
"art literature language",
73+
"arts literature language",
74+
],
75+
HIS_PHIL_CUL: [
76+
"HIS_PHIL_CUL",
77+
"history philosophy culture",
78+
"history, philosophy, culture",
79+
"history, philosophy, and culture",
80+
"history, philosophy",
81+
"history philosophy",
82+
"philosophy culture",
83+
"philosophy, culture",
84+
"history culture",
85+
"History, Philosophy and Cultural Studies",
86+
],
87+
SOCIAL_SCI: ["SOCIAL_SCI", "social science", "social sciences"],
88+
NAT_SCI: ["NAT_SCI", "natural science", "natural sciences"],
89+
QUANT: ["QUANT", "quantitative reasoning"],
90+
};
91+
92+
// Phrases that signal a course year level in a user query.
// Each key is a year-level identifier (includeFilters passes it to
// convertYearLevel); each value lists the phrases searched for in the query.
// Entries such as "A-level" contain uppercase characters, so any consumer
// must compare case-insensitively on both sides.
export const YEAR_LEVEL_KEYWORDS = {
93+
first_year: ["first year", "first-year", "A-level", "A level", "1st year"],
94+
second_year: ["second year", "second-year", "B-level", "B level", "2nd year"],
95+
third_year: ["third year", "third-year", "C-level", "C level", "3rd year"],
96+
fourth_year: ["fourth year", "fourth-year", "D-level", "D level", "4th year"],
97+
};
98+
6499
// General academic terms that might indicate a search is needed
65100
export const GENERAL_ACADEMIC_TERMS = ["credit", "enroll", "drop"];
66101

course-matrix/backend/src/controllers/aiController.ts

Lines changed: 89 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,14 @@ import {
1212
DEPARTMENT_CODES,
1313
ASSISTANT_TERMS,
1414
USEFUL_INFO,
15+
BREADTH_REQUIREMENT_KEYWORDS,
16+
YEAR_LEVEL_KEYWORDS,
1517
} from "../constants/promptKeywords";
1618
import { CHATBOT_MEMORY_THRESHOLD, codeToYear } from "../constants/constants";
1719
import { namespaceToMinResults } from "../constants/constants";
1820
import OpenAI from "openai";
21+
import { convertBreadthRequirement } from "../utils/convert-breadth-requirement";
22+
import { convertYearLevel } from "../utils/convert-year-level";
1923

2024
const openai = createOpenAI({
2125
baseURL: process.env.OPENAI_BASE_URL,
@@ -31,7 +35,7 @@ const pinecone = new Pinecone({
3135
});
3236

3337
const index: Index<RecordMetadata> = pinecone.Index(
34-
process.env.PINECONE_INDEX_NAME!,
38+
process.env.PINECONE_INDEX_NAME!
3539
);
3640

3741
console.log("Connected to OpenAI API");
@@ -58,8 +62,8 @@ function analyzeQuery(query: string): {
5862

5963
// If a course code is detected, add these namespaces
6064
if (containsCourseCode) {
61-
if (!relevantNamespaces.includes("courses_v2"))
62-
relevantNamespaces.push("courses_v2");
65+
if (!relevantNamespaces.includes("courses_v3"))
66+
relevantNamespaces.push("courses_v3");
6367
if (!relevantNamespaces.includes("offerings"))
6468
relevantNamespaces.push("offerings");
6569
if (!relevantNamespaces.includes("prerequisites"))
@@ -70,8 +74,8 @@ function analyzeQuery(query: string): {
7074
if (DEPARTMENT_CODES.some((code) => lowerQuery.includes(code))) {
7175
if (!relevantNamespaces.includes("departments"))
7276
relevantNamespaces.push("departments");
73-
if (!relevantNamespaces.includes("courses_v2"))
74-
relevantNamespaces.push("courses_v2");
77+
if (!relevantNamespaces.includes("courses_v3"))
78+
relevantNamespaces.push("courses_v3");
7579
}
7680

7781
// If search is required at all
@@ -83,12 +87,12 @@ function analyzeQuery(query: string): {
8387
// If no specific namespaces identified & search required, then search all
8488
if (requiresSearch && relevantNamespaces.length === 0) {
8589
relevantNamespaces.push(
86-
"courses_v2",
90+
"courses_v3",
8791
"offerings",
8892
"prerequisites",
8993
"corequisites",
9094
"departments",
91-
"programs",
95+
"programs"
9296
);
9397
}
9498

@@ -106,6 +110,7 @@ async function searchSelectedNamespaces(
106110
query: string,
107111
k: number,
108112
namespaces: string[],
113+
filters?: Object
109114
): Promise<Document[]> {
110115
let allResults: Document[] = [];
111116

@@ -127,6 +132,7 @@ async function searchSelectedNamespaces(
127132
const results = await namespaceStore.similaritySearch(
128133
query,
129134
Math.max(k, namespaceToMinResults.get(namespace)),
135+
namespace === "courses_v3" ? filters : undefined
130136
);
131137
console.log(`Found ${results.length} results in namespace: ${namespace}`);
132138
allResults = [...allResults, ...results];
@@ -147,7 +153,7 @@ async function searchSelectedNamespaces(
147153
// Reformulate user query to make more concise query to database, taking into consideration context
148154
async function reformulateQuery(
149155
latestQuery: string,
150-
conversationHistory: any[],
156+
conversationHistory: any[]
151157
): Promise<string> {
152158
try {
153159
const openai = new OpenAI({
@@ -227,6 +233,69 @@ async function reformulateQuery(
227233
}
228234
}
229235

236+
/**
 * Determines whether to apply metadata filtering based on the user query, and
 * builds the filter object when it applies.
 *
 * The query is checked against BREADTH_REQUIREMENT_KEYWORDS and
 * YEAR_LEVEL_KEYWORDS with a case-insensitive substring match. Both the query
 * and each keyword are lower-cased, because the keyword lists contain
 * mixed-case entries (e.g. "ART_LIT_LANG", "A-level") that would otherwise
 * never match the lower-cased query.
 *
 * NOTE: matching is substring based, so a short keyword can also match inside
 * a longer word.
 *
 * @param query - Query text to inspect (the caller passes the reformulated query).
 * @returns `{}` when no keyword matched; a single `$or` clause over
 *   `breadth_requirement` or `year_level` when only one category matched;
 *   otherwise an `$and` of the breadth clause followed by the year-level clause.
 */
function includeFilters(query: string): Record<string, unknown> {
  const lowerQuery = query.toLocaleLowerCase();

  // True when any keyword occurs in the query, ignoring case on both sides.
  const mentionsAny = (keywords: string[]): boolean =>
    keywords.some((keyword) =>
      lowerQuery.includes(keyword.toLocaleLowerCase())
    );

  const relevantBreadthRequirements: string[] = Object.entries(
    BREADTH_REQUIREMENT_KEYWORDS
  )
    .filter(([, keywords]) => mentionsAny(keywords))
    .map(([code]) => convertBreadthRequirement(code));

  const relevantYearLevels: string[] = Object.entries(YEAR_LEVEL_KEYWORDS)
    .filter(([, keywords]) => mentionsAny(keywords))
    .map(([code]) => convertYearLevel(code));

  // One `$or` clause per category that matched; breadth always comes first.
  const clauses: Record<string, unknown>[] = [];
  if (relevantBreadthRequirements.length > 0) {
    clauses.push({
      $or: relevantBreadthRequirements.map((req) => ({
        breadth_requirement: { $eq: req },
      })),
    });
  }
  if (relevantYearLevels.length > 0) {
    clauses.push({
      $or: relevantYearLevels.map((yl) => ({ year_level: { $eq: yl } })),
    });
  }

  if (clauses.length === 0) return {};
  // A single category is returned bare; two categories must both hold.
  return clauses.length === 1 ? clauses[0] : { $and: clauses };
}
283+
284+
/**
285+
* @description Handles user queries and generates responses using GPT-4o, with optional knowledge retrieval.
286+
*
287+
* @param {Request} req - The Express request object, containing:
288+
* @param {Object[]} req.body.messages - Array of message objects representing the conversation history.
289+
* @param {string} req.body.messages[].role - The role of the message sender (e.g., "user", "assistant").
290+
* @param {Object[]} req.body.messages[].content - An array containing message content objects.
291+
* @param {string} req.body.messages[].content[].text - The actual text of the message.
292+
*
293+
* @param {Response} res - The Express response object used to stream the generated response.
294+
*
295+
* @returns {void} Responds with a streamed text response of the AI output
296+
*
297+
* @throws {Error} If query reformulation or knowledge retrieval fails.
298+
*/
230299
export const chat = asyncHandler(async (req: Request, res: Response) => {
231300
const { messages } = req.body;
232301
const latestMessage = messages[messages.length - 1].content[0].text;
@@ -240,7 +309,7 @@ export const chat = asyncHandler(async (req: Request, res: Response) => {
240309
// Use GPT-4o to reformulate the query based on conversation history
241310
const reformulatedQuery = await reformulateQuery(
242311
latestMessage,
243-
conversationHistory.slice(-CHATBOT_MEMORY_THRESHOLD), // last K messages
312+
conversationHistory.slice(-CHATBOT_MEMORY_THRESHOLD) // last K messages
244313
);
245314
console.log(">>>> Original query:", latestMessage);
246315
console.log(">>>> Reformulated query:", reformulatedQuery);
@@ -254,15 +323,19 @@ export const chat = asyncHandler(async (req: Request, res: Response) => {
254323
if (requiresSearch) {
255324
console.log(
256325
`Query requires knowledge retrieval, searching namespaces: ${relevantNamespaces.join(
257-
", ",
258-
)}`,
326+
", "
327+
)}`
259328
);
260329

330+
const filters = includeFilters(reformulatedQuery);
331+
// console.log("Filters: ", JSON.stringify(filters))
332+
261333
// Search only relevant namespaces
262334
const searchResults = await searchSelectedNamespaces(
263335
reformulatedQuery,
264336
3,
265337
relevantNamespaces,
338+
Object.keys(filters).length === 0 ? undefined : filters
266339
);
267340
// console.log("Search Results: ", searchResults);
268341

@@ -330,15 +403,15 @@ export const testSimilaritySearch = asyncHandler(
330403
if (requiresSearch) {
331404
console.log(
332405
`Query requires knowledge retrieval, searching namespaces: ${relevantNamespaces.join(
333-
", ",
334-
)}`,
406+
", "
407+
)}`
335408
);
336409

337410
// Search only the relevant namespaces
338411
const searchResults = await searchSelectedNamespaces(
339412
message,
340413
3,
341-
relevantNamespaces,
414+
relevantNamespaces
342415
);
343416
console.log("Search Results: ", searchResults);
344417

@@ -348,11 +421,11 @@ export const testSimilaritySearch = asyncHandler(
348421
}
349422
} else {
350423
console.log(
351-
"Query does not require knowledge retrieval, skipping search",
424+
"Query does not require knowledge retrieval, skipping search"
352425
);
353426
}
354427

355428
console.log("CONTEXT: ", context);
356429
res.status(200).send(context);
357-
},
430+
}
358431
);

course-matrix/backend/src/routes/aiRouter.ts

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,13 @@ import { authRouter } from "./authRouter";
44

55
export const aiRouter = express.Router();
66

7+
/**
8+
* @route POST /api/ai/chat
9+
* @description Handles user queries and generates responses using GPT-4o, with optional knowledge retrieval.
10+
*/
711
aiRouter.post("/chat", authRouter, chat);
12+
/**
13+
* @route POST /api/ai/test-similarity-search
14+
* @description Test vector database similarity search feature
15+
*/
816
aiRouter.post("/test-similarity-search", testSimilaritySearch);
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
/**
 * Translates a breadth requirement code into its full display name.
 *
 * @param code - Breadth requirement code, e.g. "NAT_SCI".
 * @returns The matching display name, or an empty string for any code that is
 *   not recognized.
 */
export const convertBreadthRequirement = (code: string) => {
  switch (code) {
    case "ART_LIT_LANG":
      return "Arts, Literature and Language";
    case "HIS_PHIL_CUL":
      return "History, Philosophy and Cultural Studies";
    case "SOCIAL_SCI":
      return "Social and Behavioral Sciences";
    case "NAT_SCI":
      return "Natural Sciences";
    case "QUANT":
      return "Quantitative Reasoning";
    default:
      return "";
  }
};
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
/**
 * Translates a year-level identifier into the label used for that level.
 *
 * @param code - Year-level identifier, e.g. "first_year".
 * @returns The matching label ("1st year" … "4th year"), or an empty string
 *   for any identifier that is not recognized.
 */
export const convertYearLevel = (code: string) => {
  switch (code) {
    case "first_year":
      return "1st year";
    case "second_year":
      return "2nd year";
    case "third_year":
      return "3rd year";
    case "fourth_year":
      return "4th year";
    default:
      return "";
  }
};

course-matrix/backend/src/utils/embeddings.ts

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import { PineconeStore } from "@langchain/pinecone";
55
import { Pinecone } from "@pinecone-database/pinecone";
66
import config from "../config/config";
77
import path from "path";
8+
import { convertBreadthRequirement } from "./convert-breadth-requirement";
89

910
console.log("Running embeddings process...");
1011

@@ -37,6 +38,35 @@ async function processCSV(filePath: string, namespace: string) {
3738
});
3839
}
3940

41+
// Returns the trimmed text that follows the first ": " on the given line, or
// "" when the line is missing or has no ": " separator. Trimming also drops a
// trailing "\r" left behind by CRLF line endings.
function extractCsvField(lines: string[], lineIndex: number): string {
  return lines[lineIndex]?.split(": ")[1]?.trim() ?? "";
}

/**
 * Generate embeddings for courses.csv and store them in the given Pinecone
 * namespace, attaching `breadth_requirement` and `year_level` metadata to
 * every row.
 *
 * Each loaded document's page content is split into lines. The line at
 * index 1 is read as the breadth requirement code and the line at index 10
 * as the year level.
 * NOTE(review): assumes CSVLoader emits one "column: value" line per column
 * and that these positions match the column order of the courses CSV — confirm
 * against the data file.
 *
 * Rows that are shorter than expected, or whose line lacks a ": " separator,
 * get an empty-string metadata value instead of aborting the whole run.
 *
 * @param filePath - Path of the CSV file to load.
 * @param namespace - Pinecone namespace that receives the embeddings.
 */
async function processCoursesCSV(filePath: string, namespace: string) {
  const fileName = path.basename(filePath);
  const loader = new CSVLoader(filePath);
  let docs = await loader.load();

  docs = docs.map((doc, rowIndex) => {
    // Split once per row instead of once per extracted field.
    const lines = doc.pageContent.split("\n");
    return {
      ...doc,
      metadata: {
        ...doc.metadata,
        source: fileName,
        row: rowIndex + 1,
        breadth_requirement: convertBreadthRequirement(
          extractCsvField(lines, 1)
        ),
        year_level: extractCsvField(lines, 10),
      },
    };
  });
  console.log("Sample doc: ", docs[0]);

  const pineconeIndex = pinecone.Index(process.env.PINECONE_INDEX_NAME!);

  // Store each row as an individual embedding
  await PineconeStore.fromDocuments(docs, embeddings, {
    pineconeIndex: pineconeIndex as any,
    namespace: namespace,
  });
}
69+
4070
// Generate embeddings for pdfs
4171
async function processPDF(filePath: string, namespace: string) {
4272
const fileName = path.basename(filePath);
@@ -71,7 +101,7 @@ async function processPDF(filePath: string, namespace: string) {
71101
// console.log("Sample split docs: ", splitDocs.slice(0, 6))
72102

73103
console.log(
74-
`Split into ${splitDocs.length} sections by "Calendar Section:" delimiter`,
104+
`Split into ${splitDocs.length} sections by "Calendar Section:" delimiter`
75105
);
76106

77107
// Store the split documents as embeddings
@@ -98,6 +128,7 @@ async function processPDF(filePath: string, namespace: string) {
98128
// processCSV("../data/tables/offerings_winter_2026.csv", "offerings")
99129
// processCSV("../data/tables/departments.csv", "departments")
100130
// processCSV("../data/tables/courses_with_year.csv", "courses_v2")
131+
// processCoursesCSV("../data/tables/courses_with_year.csv", "courses_v3");
101132

102133
console.log("embeddings done.");
103134

0 commit comments

Comments
 (0)