Skip to content

Commit 55e612c

Browse files
committed
rag setup
1 parent 9e7d9a7 commit 55e612c

21 files changed

+685
-238
lines changed

package-lock.json

Lines changed: 10 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

package.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,8 @@
112112
"unfuck-utf8-headers-middleware": "^1.0.1",
113113
"vite": "^5.0.12",
114114
"winston": "^3.11.0",
115-
"winston-gelf-transporter": "^1.0.2"
115+
"winston-gelf-transporter": "^1.0.2",
116+
"zod": "^3.24.4"
116117
},
117118
"nodemonConfig": {
118119
"ignore": [

src/config.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@ export const DEFAUL_CONTEXT_LIMIT =
2626
export const DEFAULT_RESET_CRON =
2727
process.env.DEFAULT_RESET_CRON || '0 0 1 */3 *'
2828

29+
export const EMBED_MODEL = 'text-embedding-small'
30+
export const EMBED_DIM = 1024
31+
2932
export const validModels = [
3033
{
3134
name: 'gpt-4',
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import { DataTypes } from 'sequelize'
2+
3+
import { Migration } from '../connection'
4+
5+
export const up: Migration = ({ context: queryInterface }) =>
6+
queryInterface.createTable('rag_indices', {
7+
id: {
8+
type: DataTypes.INTEGER,
9+
allowNull: false,
10+
primaryKey: true,
11+
autoIncrement: true,
12+
},
13+
user_id: {
14+
type: DataTypes.STRING,
15+
allowNull: false,
16+
},
17+
course_id: {
18+
type: DataTypes.STRING,
19+
allowNull: true,
20+
},
21+
metadata: {
22+
type: DataTypes.JSONB,
23+
allowNull: true,
24+
},
25+
created_at: {
26+
type: DataTypes.DATE,
27+
allowNull: false,
28+
},
29+
updated_at: {
30+
type: DataTypes.DATE,
31+
allowNull: false,
32+
},
33+
})
34+
35+
export const down: Migration = ({ context: queryInterface }) =>
36+
queryInterface.dropTable('rag_indices')

src/server/db/models/index.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import Prompt from './prompt'
55
import Enrolment from './enrolment'
66
import Responsibility from './responsibilities'
77
import Discussion from './discussion'
8+
import RagIndex from './ragIndex'
89

910
User.belongsToMany(ChatInstance, {
1011
through: UserChatInstanceUsage,
@@ -42,6 +43,10 @@ Responsibility.belongsTo(ChatInstance, { as: 'chatInstance' })
4243

4344
ChatInstance.hasMany(Responsibility, { as: 'responsibilities' })
4445

46+
User.hasMany(RagIndex, { as: 'ragIndices' })
47+
48+
RagIndex.belongsTo(User, { as: 'user' })
49+
4550
export {
4651
User,
4752
ChatInstance,
@@ -50,4 +55,5 @@ export {
5055
Enrolment,
5156
Responsibility,
5257
Discussion,
58+
RagIndex,
5359
}

src/server/db/models/ragIndex.ts

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import {
2+
Model,
3+
InferAttributes,
4+
InferCreationAttributes,
5+
CreationOptional,
6+
DataTypes,
7+
} from 'sequelize'
8+
9+
import { sequelize } from '../connection'
10+
11+
export type RagIndexMetadata = {
12+
name: string
13+
dim: number
14+
}
15+
16+
class RagIndex extends Model<
17+
InferAttributes<RagIndex>,
18+
InferCreationAttributes<RagIndex>
19+
> {
20+
declare id: CreationOptional<number>
21+
22+
declare userId: string
23+
24+
declare courseId?: string
25+
26+
declare metadata: RagIndexMetadata
27+
}
28+
29+
RagIndex.init(
30+
{
31+
id: {
32+
type: DataTypes.INTEGER,
33+
allowNull: false,
34+
primaryKey: true,
35+
autoIncrement: true,
36+
},
37+
userId: {
38+
type: DataTypes.STRING,
39+
allowNull: false,
40+
},
41+
courseId: {
42+
type: DataTypes.STRING,
43+
allowNull: true,
44+
},
45+
metadata: {
46+
type: DataTypes.JSONB,
47+
allowNull: true,
48+
},
49+
},
50+
{
51+
underscored: true,
52+
sequelize,
53+
}
54+
)
55+
56+
export default RagIndex

src/server/index.ts

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ import { connectToDatabase } from './db/connection'
2020
import seed from './db/seeders'
2121
import setupCron from './util/cron'
2222
import { updateLastRestart } from './util/lastRestart'
23-
import { initRag } from './util/rag'
2423

2524
const app = express()
2625

@@ -45,9 +44,6 @@ app.listen(PORT, async () => {
4544
await connectToDatabase()
4645
await seed()
4746
await updateLastRestart()
48-
if (RAG_ENABLED) {
49-
await initRag()
50-
}
5147
if (inProduction || inStaging) {
5248
await setupCron()
5349
}

src/server/routes/rag.ts

Lines changed: 69 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,78 @@
11
import { Router } from 'express'
2-
import { searchRag } from '../util/rag'
2+
import { EMBED_DIM } from '../../config'
3+
import { createChunkIndex, deleteChunkIndex } from '../services/rag/chunkDb'
4+
import { RagIndex } from '../db/models'
5+
import { RequestWithUser } from '../types'
6+
import z from 'zod'
7+
import { queryRagIndex } from '../services/rag/query'
38

49
const router = Router()
510

11+
const IndexCreationSchema = z.object({
12+
name: z.string().min(1).max(100),
13+
dim: z.number().min(1024).max(1024).default(EMBED_DIM),
14+
})
15+
16+
router.post('/indices', async (req, res) => {
17+
const { user } = req as RequestWithUser
18+
const { name, dim } = IndexCreationSchema.parse(req.body)
19+
20+
const ragIndex = await RagIndex.create({
21+
userId: user.id,
22+
metadata: {
23+
name,
24+
dim,
25+
},
26+
})
27+
28+
await createChunkIndex(ragIndex)
29+
30+
res.json(ragIndex)
31+
})
32+
33+
router.delete('/indices/:id', async (req, res) => {
34+
const { user } = req as unknown as RequestWithUser // <- fix type
35+
const { id } = req.params
36+
37+
const ragIndex = await RagIndex.findOne({
38+
where: { id, userId: user.id },
39+
})
40+
41+
if (!ragIndex) {
42+
res.status(404).json({ error: 'Index not found' })
43+
return
44+
}
45+
46+
await deleteChunkIndex(ragIndex)
47+
48+
await ragIndex.destroy()
49+
50+
res.json({ message: 'Index deleted' })
51+
})
52+
53+
router.get('/indices', async (_req, res) => {
54+
const indices = await RagIndex.findAll()
55+
56+
res.json(indices)
57+
})
58+
59+
const RagIndexQuerySchema = z.object({
60+
query: z.string().min(1).max(1000),
61+
topK: z.number().min(1).max(100).default(5),
62+
indexId: z.number(),
63+
})
664
router.post('/query', async (req, res) => {
7-
console.log('Received request on /rag/query')
8-
console.log('Request body:', req.body)
9-
console.log('Request headers:', req.headers)
10-
console.log('Request method:', req.method)
11-
try {
12-
const prompt = req.body.prompt
13-
const answer = await searchRag(prompt)
14-
res.json(answer)
15-
} catch (error) {
16-
console.error('Error in /rag/query:', error)
17-
res.status(500).json({ error: 'Rag failed' })
65+
const { query, topK, indexId } = RagIndexQuerySchema.parse(req.body)
66+
67+
const ragIndex = await RagIndex.findByPk(indexId)
68+
69+
if (!ragIndex) {
70+
res.status(404).json({ error: 'Index not found' })
71+
return
1872
}
73+
74+
const results = await queryRagIndex(ragIndex, query, topK)
75+
res.json(results)
1976
})
2077

2178
export default router

src/server/services/rag/chunkDb.ts

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
import { RagIndex } from '../../db/models'
2+
import { redisClient } from '../../util/redis'
3+
4+
export const createChunkIndex = async (ragIndex: RagIndex) => {
5+
try {
6+
await redisClient.ft.create(
7+
ragIndex.metadata.name,
8+
{
9+
metadata: {
10+
type: 'TEXT',
11+
},
12+
content: {
13+
type: 'TEXT',
14+
},
15+
embedding: {
16+
type: 'VECTOR',
17+
TYPE: 'FLOAT32',
18+
ALGORITHM: 'HNSW',
19+
DIM: ragIndex.metadata.dim,
20+
DISTANCE_METRIC: 'L2',
21+
},
22+
},
23+
{
24+
ON: 'HASH',
25+
PREFIX: `idx:${ragIndex.metadata.name}`,
26+
}
27+
)
28+
29+
console.log(`Index ${ragIndex.metadata.name} created`)
30+
} catch (err: any) {
31+
if (err.message.includes('Index already exists')) {
32+
console.log(`Index ${ragIndex.metadata.name} already exists`)
33+
} else {
34+
console.error('Error creating index', err)
35+
}
36+
}
37+
}
38+
39+
export const deleteChunkIndex = async (ragIndex: RagIndex) => {
40+
try {
41+
await redisClient.ft.dropIndex(ragIndex.metadata.name)
42+
console.log(`Index ${ragIndex.metadata.name} deleted`)
43+
} catch (err: any) {
44+
if (err.message.includes('Index not found')) {
45+
console.log(`Index ${ragIndex.metadata.name} not found`)
46+
} else {
47+
console.error('Error deleting index', err)
48+
}
49+
}
50+
}
51+
52+
export const addChunk = async (
53+
ragIndex: RagIndex,
54+
{
55+
id,
56+
metadata,
57+
content,
58+
embedding,
59+
}: {
60+
id: string
61+
metadata?: { [key: string]: any }
62+
content: string
63+
embedding: number[]
64+
}
65+
) => {
66+
const embeddingBuffer = Buffer.copyBytesFrom(new Float32Array(embedding))
67+
68+
// Check if the embedding length is correct
69+
if (embeddingBuffer.length !== 4 * ragIndex.metadata.dim) {
70+
throw new Error(
71+
`Embedding length is incorrect, got ${embeddingBuffer.length} bytes`
72+
)
73+
}
74+
75+
await redisClient.hSet(`idx:${ragIndex.metadata.name}:${id}`, {
76+
metadata: JSON.stringify(metadata || {}),
77+
content,
78+
embedding: embeddingBuffer,
79+
})
80+
81+
console.log(`Document ${id} added to index ${ragIndex.metadata.name}`)
82+
}
83+
84+
export const searchKChunks = async (
85+
ragIndex: RagIndex,
86+
embedding: number[],
87+
k: number
88+
) => {
89+
const embeddingBuffer = Buffer.copyBytesFrom(new Float32Array(embedding))
90+
91+
if (embeddingBuffer.length !== 4 * ragIndex.metadata.dim) {
92+
throw new Error(
93+
`Embedding length is incorrect, got ${embeddingBuffer.length} bytes`
94+
)
95+
}
96+
97+
const queryString = `(*)=>[KNN ${k} @embedding $vec_param AS score]`
98+
99+
const results = await redisClient.ft.search(
100+
ragIndex.metadata.name,
101+
queryString,
102+
{
103+
PARAMS: {
104+
vec_param: embeddingBuffer,
105+
},
106+
DIALECT: 2,
107+
RETURN: ['content', 'title', 'score'], // Specify the fields to return
108+
}
109+
)
110+
111+
return results as {
112+
total: number
113+
documents: {
114+
id: string
115+
value: {
116+
content: string
117+
title: string
118+
score: number
119+
metadata: string
120+
}
121+
}[]
122+
}
123+
}

0 commit comments

Comments
 (0)