|
1 | 1 | import { Router } from 'express'
|
2 |
| -import { searchRag } from '../util/rag' |
| 2 | +import { EMBED_DIM } from '../../config' |
| 3 | +import { createChunkIndex, deleteChunkIndex } from '../services/rag/chunkDb' |
| 4 | +import { RagIndex } from '../db/models' |
| 5 | +import { RequestWithUser } from '../types' |
| 6 | +import z from 'zod' |
| 7 | +import { queryRagIndex } from '../services/rag/query' |
3 | 8 |
|
4 | 9 | const router = Router()
|
5 | 10 |
|
| 11 | +const IndexCreationSchema = z.object({ |
| 12 | + name: z.string().min(1).max(100), |
| 13 | + dim: z.number().min(1024).max(1024).default(EMBED_DIM), |
| 14 | +}) |
| 15 | + |
| 16 | +router.post('/indices', async (req, res) => { |
| 17 | + const { user } = req as RequestWithUser |
| 18 | + const { name, dim } = IndexCreationSchema.parse(req.body) |
| 19 | + |
| 20 | + const ragIndex = await RagIndex.create({ |
| 21 | + userId: user.id, |
| 22 | + metadata: { |
| 23 | + name, |
| 24 | + dim, |
| 25 | + }, |
| 26 | + }) |
| 27 | + |
| 28 | + await createChunkIndex(ragIndex) |
| 29 | + |
| 30 | + res.json(ragIndex) |
| 31 | +}) |
| 32 | + |
| 33 | +router.delete('/indices/:id', async (req, res) => { |
| 34 | + const { user } = req as unknown as RequestWithUser // <- fix type |
| 35 | + const { id } = req.params |
| 36 | + |
| 37 | + const ragIndex = await RagIndex.findOne({ |
| 38 | + where: { id, userId: user.id }, |
| 39 | + }) |
| 40 | + |
| 41 | + if (!ragIndex) { |
| 42 | + res.status(404).json({ error: 'Index not found' }) |
| 43 | + return |
| 44 | + } |
| 45 | + |
| 46 | + await deleteChunkIndex(ragIndex) |
| 47 | + |
| 48 | + await ragIndex.destroy() |
| 49 | + |
| 50 | + res.json({ message: 'Index deleted' }) |
| 51 | +}) |
| 52 | + |
| 53 | +router.get('/indices', async (_req, res) => { |
| 54 | + const indices = await RagIndex.findAll() |
| 55 | + |
| 56 | + res.json(indices) |
| 57 | +}) |
| 58 | + |
| 59 | +const RagIndexQuerySchema = z.object({ |
| 60 | + query: z.string().min(1).max(1000), |
| 61 | + topK: z.number().min(1).max(100).default(5), |
| 62 | + indexId: z.number(), |
| 63 | +}) |
6 | 64 | router.post('/query', async (req, res) => {
|
7 |
| - console.log('Received request on /rag/query') |
8 |
| - console.log('Request body:', req.body) |
9 |
| - console.log('Request headers:', req.headers) |
10 |
| - console.log('Request method:', req.method) |
11 |
| - try { |
12 |
| - const prompt = req.body.prompt |
13 |
| - const answer = await searchRag(prompt) |
14 |
| - res.json(answer) |
15 |
| - } catch (error) { |
16 |
| - console.error('Error in /rag/query:', error) |
17 |
| - res.status(500).json({ error: 'Rag failed' }) |
| 65 | + const { query, topK, indexId } = RagIndexQuerySchema.parse(req.body) |
| 66 | + |
| 67 | + const ragIndex = await RagIndex.findByPk(indexId) |
| 68 | + |
| 69 | + if (!ragIndex) { |
| 70 | + res.status(404).json({ error: 'Index not found' }) |
| 71 | + return |
18 | 72 | }
|
| 73 | + |
| 74 | + const results = await queryRagIndex(ragIndex, query, topK) |
| 75 | + res.json(results) |
19 | 76 | })
|
20 | 77 |
|
21 | 78 | export default router
|
0 commit comments