Skip to content

Commit 3ae809f

Browse files
committed
Implement hybrid search for rag
1 parent 215863a commit 3ae809f

File tree

10 files changed

+282
-74
lines changed

10 files changed

+282
-74
lines changed

src/client/components/Rag.tsx

Lines changed: 70 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,26 @@
1-
import React, { useState } from 'react'
1+
import React, { useRef, useState } from 'react'
22
import { TextField, Button, Box, Typography, Table, TableHead, TableBody, TableRow, TableCell, Paper, IconButton, Dialog, DialogTitle, styled } from '@mui/material'
33
import apiClient from '../util/apiClient'
44
import { useMutation, useQuery } from '@tanstack/react-query'
5-
import type { RagIndexAttributes } from '../../server/db/models/ragIndex'
65
import { CloudUpload, Settings } from '@mui/icons-material'
6+
import Markdown from './Banner/Markdown'
77

88
type RagResponse = {
9-
total: number
10-
documents: Array<{
11-
id: string
12-
value: {
13-
title: string
14-
content: string
15-
score: number
16-
}
17-
}>
9+
id: string
10+
value: {
11+
title: string
12+
content: string
13+
score: number
14+
}
15+
}
16+
17+
type RagIndexAttributes = {
18+
id: number
19+
metadata: {
20+
name: string
21+
dim: number
22+
}
23+
numOfChunks: number
1824
}
1925

2026
const useRagIndices = () => {
@@ -60,11 +66,18 @@ const useUploadMutation = (index: RagIndexAttributes | null) => {
6066
Array.from(files).forEach((file) => {
6167
formData.append('files', file)
6268
})
63-
const response = await apiClient.post(`/rag/indices/${index.id}/upload`, formData, {
69+
70+
// Returns a stream
71+
const response = await apiClient.put(`/rag/indices/${index.id}/upload`, formData, {
6472
headers: {
6573
'Content-Type': 'multipart/form-data',
6674
},
75+
responseType: 'stream',
76+
77+
6778
})
79+
80+
console.log('Upload response:', response.data)
6881
return response.data
6982
},
7083
})
@@ -90,21 +103,49 @@ const Rag: React.FC = () => {
90103
const [indexName, setIndexName] = useState('')
91104
const [selectedIndex, setSelectedIndex] = useState<RagIndexAttributes>(null)
92105
const [inputValue, setInputValue] = useState('')
93-
const [response, setResponse] = useState<RagResponse | null>(null)
106+
const [topK, setTopK] = useState(5)
107+
const [response, setResponse] = useState<RagResponse[] | null>(null)
94108
const uploadMutation = useUploadMutation(selectedIndex)
95109
const [modalOpen, setModalOpen] = useState(false)
110+
const progressLogs = useRef<HTMLParagraphElement>()
96111

97112
const handleSubmit = async (event: React.FormEvent) => {
98113
event.preventDefault()
99114
console.log('Form submitted with value:', inputValue)
100115
const res = await apiClient.post('/rag/query', {
101-
prompt: inputValue,
116+
query: inputValue,
117+
indexId: selectedIndex?.id,
118+
topK,
102119
})
103120
console.log('Response from server:', res.data)
104121
setResponse(res.data)
105122
setInputValue('')
106123
}
107124

125+
// Processes the upload progress stream which returns JSON objects
126+
const processUploadProgressStream = (stream) => {
127+
stream.on('data', (data: any) => {
128+
const parsedData = JSON.parse(data.toString())
129+
console.log('Parsed data:', parsedData)
130+
if (parsedData.stage === 'done') {
131+
progressLogs.current.innerHTML += `Upload completed: ${JSON.stringify(parsedData)}\n`
132+
} else if (parsedData.error) {
133+
progressLogs.current.innerHTML += `Error: ${parsedData.error}\n`
134+
} else {
135+
progressLogs.current.innerHTML += `Progress: ${JSON.stringify(parsedData)}\n`
136+
}
137+
})
138+
stream.on('end', () => {
139+
progressLogs.current.innerHTML += 'Upload stream ended.\n'
140+
})
141+
stream.on('error', (err: any) => {
142+
progressLogs.current.innerHTML += `Error: ${err}\n`
143+
})
144+
stream.on('close', () => {
145+
progressLogs.current.innerHTML += 'Upload stream closed.\n'
146+
})
147+
}
148+
108149
return (
109150
<Box sx={{ display: 'flex', gap: 2 }}>
110151
<Dialog open={!!selectedIndex && modalOpen} onClose={() => setModalOpen(false)}>
@@ -118,8 +159,8 @@ const Rag: React.FC = () => {
118159
const files = event.target.files
119160
console.log('Files selected:', files)
120161
if (files && files.length > 0) {
121-
await uploadMutation.mutateAsync(files)
122-
refetch()
162+
const stream = await uploadMutation.mutateAsync(files)
163+
processUploadProgressStream(stream)
123164
}
124165
}}
125166
multiple
@@ -139,6 +180,9 @@ const Rag: React.FC = () => {
139180
Delete Index
140181
</Button>
141182
</Box>
183+
<Box sx={{ padding: 2 }}>
184+
<p ref={progressLogs} style={{ whiteSpace: 'pre-wrap' }} />
185+
</Box>
142186
</Dialog>
143187
<Box>
144188
<Typography variant="h4" mb="1rem">
@@ -174,13 +218,15 @@ const Rag: React.FC = () => {
174218
<TableCell>ID</TableCell>
175219
<TableCell>Name</TableCell>
176220
<TableCell>Dim</TableCell>
221+
<TableCell>Num chunks</TableCell>
177222
</TableRow>
178223
</TableHead>
179224
<TableBody>
180225
<TableRow>
181226
<TableCell>{index.id}</TableCell>
182227
<TableCell>{index.metadata.name}</TableCell>
183228
<TableCell>{index.metadata.dim}</TableCell>
229+
<TableCell>{index.numOfChunks}</TableCell>
184230
</TableRow>
185231
</TableBody>
186232
</Table>
@@ -205,35 +251,25 @@ const Rag: React.FC = () => {
205251
display: 'flex',
206252
flexDirection: 'column',
207253
gap: 2,
208-
width: '300px',
209254
margin: '0 auto',
210255
}}
211256
>
212257
<TextField label="Enter text" variant="outlined" value={inputValue} onChange={(e) => setInputValue(e.target.value)} fullWidth />
213-
<Button type="submit" variant="contained" color="primary">
214-
Submit
258+
<TextField label="top k" variant="outlined" type="number" value={topK} onChange={(e) => setTopK(parseInt(e.target.value, 10))} fullWidth />
259+
<Button type="submit" variant="contained" color="primary" disabled={!inputValue || !selectedIndex}>
260+
Search
215261
</Button>
216262
{response && (
217-
<Box
218-
sx={{
219-
marginTop: 2,
220-
padding: 2,
221-
border: '1px solid #ccc',
222-
borderRadius: '4px',
223-
}}
224-
>
263+
<Box mt={2}>
225264
<Typography variant="h6">Response:</Typography>
226-
<Typography variant="body1">Total: {response.total}</Typography>
227-
{response.documents.map((doc) => (
228-
<Box key={doc.id} sx={{ marginBottom: 1 }}>
229-
<Typography variant="subtitle1">{doc.value.title}</Typography>
230-
<Typography variant="body2">{doc.value.content}</Typography>
265+
{response.map((doc) => (
266+
<Paper key={doc.id} sx={{ marginBottom: 2, p: 1 }} elevation={2}>
231267
<Typography variant="caption">Score: {doc.value.score}</Typography>
232-
</Box>
268+
<Markdown>{doc.value.content}</Markdown>
269+
</Paper>
233270
))}
234271
</Box>
235272
)}
236-
ss
237273
</Box>
238274
</Box>
239275
)

src/server/routes/rag.ts

Lines changed: 50 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,19 @@
1-
import { Router } from 'express'
1+
import { NextFunction, Request, Response, Router } from 'express'
22
import { EMBED_DIM, EMBED_MODEL } from '../../config'
3-
import { createChunkIndex, deleteChunkIndex } from '../services/rag/chunkDb'
3+
import { createChunkIndex, deleteChunkIndex, getNumberOfChunks } from '../services/rag/chunkDb'
44
import { RagIndex } from '../db/models'
55
import { RequestWithUser } from '../types'
66
import z from 'zod'
77
import { queryRagIndex } from '../services/rag/query'
88
import { ingestionPipeline } from '../services/rag/ingestion/pipeline'
99
import { getAzureOpenAIClient } from '../util/azure'
1010
import multer from 'multer'
11-
import { mkdir } from 'fs/promises'
11+
import { mkdir, rm, stat } from 'fs/promises'
1212

1313
const router = Router()
1414

15+
const UPLOAD_DIR = 'uploads/rag'
16+
1517
const IndexCreationSchema = z.object({
1618
name: z.string().min(1).max(100),
1719
dim: z.number().min(1024).max(1024).default(EMBED_DIM),
@@ -31,6 +33,8 @@ router.post('/indices', async (req, res) => {
3133

3234
await createChunkIndex(ragIndex)
3335

36+
// Create upload directory for this index
37+
3438
res.json(ragIndex)
3539
})
3640

@@ -49,6 +53,14 @@ router.delete('/indices/:id', async (req, res) => {
4953

5054
await deleteChunkIndex(ragIndex)
5155

56+
const uploadPath = `${UPLOAD_DIR}/${id}`
57+
try {
58+
await rm(uploadPath, { recursive: true, force: true })
59+
console.log(`Upload directory ${uploadPath} deleted`)
60+
} catch (error) {
61+
console.warn(`Upload directory ${uploadPath} not found, nothing to delete --- `, error)
62+
}
63+
5264
await ragIndex.destroy()
5365

5466
res.json({ message: 'Index deleted' })
@@ -57,27 +69,55 @@ router.delete('/indices/:id', async (req, res) => {
5769
router.get('/indices', async (_req, res) => {
5870
const indices = await RagIndex.findAll()
5971

60-
res.json(indices)
72+
const indicesWithMetadata = await Promise.all(
73+
indices.map(async (index) => {
74+
const numOfChunks = await getNumberOfChunks(index)
75+
76+
return {
77+
...index.toJSON(),
78+
numOfChunks,
79+
}
80+
}),
81+
)
82+
83+
res.json(indicesWithMetadata)
6184
})
6285

86+
const IndexIdSchema = z.coerce.number().min(1)
87+
6388
const upload = multer({
6489
storage: multer.diskStorage({
6590
destination: async (req, file, cb) => {
66-
const uploadPath = `uploads/rag/${req.params.id}`
67-
// Create the directory if it doesn't exist
68-
await mkdir(uploadPath, { recursive: true })
91+
const id = IndexIdSchema.parse(req.params.id)
92+
const uploadPath = `${UPLOAD_DIR}/${id}`
6993
cb(null, uploadPath)
7094
},
95+
filename: (req, file, cb) => {
96+
const uniqueFilename = file.originalname
97+
cb(null, uniqueFilename)
98+
},
7199
}),
72100
limits: {
73101
fileSize: 10 * 1024 * 1024, // 10 MB
74102
},
75103
})
76104
const uploadMiddleware = upload.array('files')
77105

78-
router.post('/indices/:id/upload', uploadMiddleware, async (req, res) => {
106+
const indexUploadDirMiddleware = async (req: Request, _res: Response, next: NextFunction) => {
107+
const id = IndexIdSchema.parse(req.params.id)
108+
const uploadPath = `${UPLOAD_DIR}/${id}`
109+
try {
110+
await stat(uploadPath)
111+
} catch (_error) {
112+
console.warn(`RAG upload dir not found, creating ${uploadPath} --- `)
113+
await mkdir(uploadPath, { recursive: true })
114+
}
115+
next()
116+
}
117+
118+
router.put('/indices/:id/upload', [indexUploadDirMiddleware, uploadMiddleware], async (req, res) => {
79119
const { user } = req as unknown as RequestWithUser
80-
const { id } = req.params
120+
const id = IndexIdSchema.parse(req.params.id)
81121

82122
const ragIndex = await RagIndex.findOne({
83123
where: { id, userId: user.id },
@@ -94,9 +134,7 @@ router.post('/indices/:id/upload', uploadMiddleware, async (req, res) => {
94134

95135
const openAiClient = getAzureOpenAIClient(EMBED_MODEL)
96136

97-
res.json({ message: 'Ingestion started' })
98-
99-
await ingestionPipeline(openAiClient, `uploads/rag/${req.params.id}`, ragIndex)
137+
await ingestionPipeline(openAiClient, `uploads/rag/${id}`, ragIndex)
100138
})
101139

102140
const RagIndexQuerySchema = z.object({

0 commit comments

Comments
 (0)