Skip to content

Commit cdc026b

Browse files
committed
RAG search UI improvements
1 parent adc1038 commit cdc026b

File tree

10 files changed

+166
-53
lines changed

10 files changed

+166
-53
lines changed

src/client/components/Rag/Rag.tsx

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,15 +84,15 @@ const Rag: React.FC = () => {
8484
<TableRow>
8585
<TableCell>ID</TableCell>
8686
<TableCell>Name</TableCell>
87-
<TableCell>Vector Dimensions</TableCell>
87+
<TableCell>Language</TableCell>
8888
<TableCell>Number of files</TableCell>
8989
</TableRow>
9090
</TableHead>
9191
<TableBody>
9292
<TableRow>
9393
<TableCell>{index.id}</TableCell>
9494
<TableCell>{index.metadata?.name}</TableCell>
95-
<TableCell>{index.metadata?.dim}</TableCell>
95+
<TableCell>{index.metadata?.language}</TableCell>
9696
<TableCell>{index.ragFileCount}</TableCell>
9797
</TableRow>
9898
</TableBody>
Lines changed: 82 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,102 @@
11
import { useState } from 'react'
22
import { RagIndexAttributes } from '../../../shared/types'
3-
import { RagChunk } from '../../../shared/rag'
3+
import { RagChunk, SearchInputParams, SearchParams } from '../../../shared/rag'
44
import apiClient from '../../util/apiClient'
5+
import { Box, Checkbox, Collapse, Fade, FormControl, FormControlLabel, Grow, LinearProgress, TextField, Typography, Zoom } from '@mui/material'
6+
import { OutlineButtonBlue } from '../ChatV2/general/Buttons'
7+
import { amber, green, blue } from '@mui/material/colors'
8+
9+
/** Background colours cycled through for the segments of the timings bar. */
const TimeLineColors = [blue[200], amber[300], green[300]]

/** Payload returned by the index search endpoint: matched chunks plus per-stage durations (rendered as ms). */
type SearchResponse = { results: RagChunk[]; timings: Record<string, number> }

/**
 * Search form for a single RAG index.
 *
 * Lets the user enter a query and toggle semantic search, keyword search and
 * reranking, posts the parameters to `/rag/indices/:id/search`, and renders
 * the returned stage timings as a proportional bar followed by the matching
 * chunks.
 *
 * @param ragIndex - the index to search; only its `id` is used to build the URL.
 */
export const Search = ({ ragIndex }: { ragIndex: RagIndexAttributes }) => {
  const [query, setQuery] = useState('')
  const [vector, setVector] = useState(true)
  const [ft, setFt] = useState(true)
  const [rerank, setRerank] = useState(true)
  const [isLoading, setIsLoading] = useState(false)
  const [results, setResults] = useState<SearchResponse>()

  const handleSubmit = async (event: React.FormEvent<HTMLFormElement>) => {
    event.preventDefault()
    const searchParams: SearchInputParams = {
      query,
      vector,
      ft,
      rerank,
    }
    setIsLoading(true)
    // Clear the previous results so stale chunks are not shown while loading.
    setResults(undefined)
    try {
      const response = await apiClient.post<SearchResponse>(`/rag/indices/${ragIndex.id}/search`, searchParams)
      setResults(response.data)
    } finally {
      // Always hide the progress bar, even when the request rejects;
      // the error itself still propagates exactly as before.
      setIsLoading(false)
    }
  }

  const handleInputChange = (event: React.ChangeEvent<HTMLInputElement>) => {
    setQuery(event.target.value)
  }

  const timingEntries = Object.entries(results?.timings ?? {})
  const totalTime = timingEntries.reduce((acc, [, duration]) => acc + duration, 0)

  // Share of the bar a stage occupies. When every stage took 0 ms the total
  // is 0, so fall back to equal shares instead of producing `NaN%`.
  const widthPercent = (duration: number) => (totalTime > 0 ? (duration / totalTime) * 100 : 100 / timingEntries.length)

  return (
    <Box my="2rem">
      <form onSubmit={handleSubmit}>
        <FormControl>
          <TextField type="text" value={query} onChange={handleInputChange} label="Search Query" />
          <FormControlLabel control={<Checkbox checked={vector} onChange={(e) => setVector(e.target.checked)} />} label="Use semantic search" />
          <FormControlLabel control={<Checkbox checked={ft} onChange={(e) => setFt(e.target.checked)} />} label="Use keyword search" />
          <FormControlLabel control={<Checkbox checked={rerank} onChange={(e) => setRerank(e.target.checked)} />} label="Use reranking" />
        </FormControl>
        <OutlineButtonBlue type="submit">Search</OutlineButtonBlue>
      </form>
      {isLoading && <LinearProgress />}
      <Fade in={!!results?.timings}>
        {results?.timings ? (
          <Box my="2rem">
            <Typography variant="h6">Timings</Typography>
            <Box display="flex" width="70vw">
              {timingEntries.map(([key, value], idx) => (
                <Zoom
                  in={true}
                  key={key}
                  style={{ width: `${widthPercent(value)}%`, minWidth: '1rem', whiteSpace: 'nowrap' }}
                  // Stagger the entrance so the segments appear one after another.
                  timeout={{ enter: 1000 + idx * 1000 }}
                >
                  <div>
                    {key}: {value} ms
                    <Box width="100%" height="1rem" bgcolor={TimeLineColors[idx % TimeLineColors.length]} border="1px solid black" borderRadius="1rem" />
                  </div>
                </Zoom>
              ))}
            </Box>
          </Box>
        ) : (
          // Fade needs a single element child even when there is nothing to show.
          <div />
        )}
      </Fade>
      <Fade in={!!results?.results}>
        {results?.results ? (
          <Box my="1rem">
            <Typography variant="h6" mb="1rem">
              Results
            </Typography>
            <Box>
              {results.results.map((chunk) => (
                <Box key={chunk.id} my="1rem">
                  <Typography variant="subtitle2" color="text.secondary">
                    Source: {chunk.metadata?.ragFileName ?? 'unknown'}
                  </Typography>
                  <Typography whiteSpace="pre-line" variant="body2">
                    {chunk.content}
                  </Typography>
                </Box>
              ))}
            </Box>
          </Box>
        ) : (
          <div />
        )}
      </Fade>
    </Box>
  )
}

src/server/routes/rag/rag.ts

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,6 @@ router.post('/indices', async (req, res) => {
6868
userId: user.id,
6969
metadata: {
7070
name,
71-
dim,
72-
// azureVectorStoreId: vectorStore.id,
73-
ragIndexFilterValue: randomUUID(),
7471
language,
7572
},
7673
})

src/server/routes/rag/ragIndex.ts

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import { ApplicationError } from '../../util/ApplicationError'
99
import { ingestRagFiles } from '../../services/rag/ingestion'
1010
import { search } from '../../services/rag/search'
1111
import { getRedisVectorStore } from '../../services/rag/vectorStore'
12+
import { SearchSchema } from '../../../shared/rag'
1213

1314
const ragIndexRouter = Router()
1415

@@ -213,18 +214,14 @@ ragIndexRouter.post('/upload', [indexUploadDirMiddleware, uploadMiddleware], asy
213214
res.json({ message: 'Files uploaded successfully' })
214215
})
215216

216-
const RagIndexSearchSchema = z.object({
217-
query: z.string().min(1).max(1000),
218-
})
219-
220217
/**
 * POST /search — runs a search against the rag index attached to the request.
 *
 * The body is validated with `SearchSchema.parse`, which throws on invalid
 * input (how that error is turned into a response is decided outside this
 * handler). Responds with `{ results, timings }` as JSON.
 */
ragIndexRouter.post('/search', async (req, res) => {
  const { ragIndex } = req as unknown as RagIndexRequest
  const params = SearchSchema.parse(req.body)

  const outcome = await search(ragIndex, params)

  res.json({ results: outcome.results, timings: outcome.timings })
})
229226

230227
export default ragIndexRouter

src/server/routes/testUtils.ts

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,8 @@ router.post('/reset-test-data', async (req, res) => {
6666
defaults: {
6767
userId,
6868
metadata: {
69-
ragIndexFilterValue: 'mock',
7069
name: `rag-${testUserIdx}`,
71-
azureVectorStoreId: 'mock',
70+
language: 'English',
7271
},
7372
},
7473
})

src/server/services/rag/ingestion.ts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,10 @@ export const ingestRagFiles = async (ragIndex: RagIndex) => {
4242

4343
chunkDocuments.forEach((chunkDocument, idx) => {
4444
chunkDocument.id = `ragIndex-${ragFile.ragIndexId}-${ragFile.filename}-${idx}`
45+
chunkDocument.metadata = {
46+
...chunkDocument.metadata,
47+
ragFileName: ragFile.filename,
48+
}
4549
})
4650

4751
// console.log(await redisClient.ft.info(vectorStore.indexName))

src/server/services/rag/search.ts

Lines changed: 51 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,68 @@
1-
import { RagChunk } from '../../../shared/rag'
1+
import { RagChunk, SearchParams } from '../../../shared/rag'
22
import type { RagIndex } from '../../db/models'
33
import { getRedisVectorStore } from './vectorStore'
44
import { FTSearchRetriever } from './retrievers'
55
import { EnsembleRetriever } from 'langchain/retrievers/ensemble'
66
import { BM25Retriever } from '@langchain/community/retrievers/bm25'
7+
import { BaseRetriever } from '@langchain/core/retrievers'
8+
9+
/**
 * Searches a RAG index with the enabled retrievers and optionally reranks the hits.
 *
 * - `searchParams.vector` adds the vector-store retriever (top `vectorK`, weight 0.3).
 * - `searchParams.ft` adds the full-text retriever (weight 0.7).
 * - `searchParams.rerank` runs the ensemble output through a BM25 reranker
 *   that keeps `rerankK` documents.
 *
 * @param ragIndex - index to search; its id selects the Redis vector store.
 * @param searchParams - parsed search parameters (defaults already applied).
 * @returns the matching chunks and the wall-clock duration of each executed
 *   stage in milliseconds (`search`, and `rerank` when enabled). When neither
 *   retriever is enabled the result is empty and no stage is timed.
 */
export const search = async (ragIndex: RagIndex, searchParams: SearchParams): Promise<{ results: RagChunk[]; timings: Record<string, number> }> => {
  console.log('Searching', ragIndex.metadata.name, 'for query:', searchParams.query)
  const timings: Record<string, number> = {}

  const vectorStore = getRedisVectorStore(ragIndex.id)

  const vectorstoreRetriever = vectorStore.asRetriever(searchParams.vectorK)
  const ftSearchRetriever = new FTSearchRetriever(vectorStore.indexName)

  const retrievers: BaseRetriever[] = []
  const weights: number[] = []

  if (searchParams.vector) {
    retrievers.push(vectorstoreRetriever)
    weights.push(0.3)
  }

  if (searchParams.ft) {
    retrievers.push(ftSearchRetriever)
    weights.push(0.7)
  }

  // With every retriever switched off there is nothing to query; answer
  // explicitly instead of invoking an ensemble that has no members.
  if (retrievers.length === 0) {
    return { results: [], timings }
  }

  const retriever = new EnsembleRetriever({
    retrievers,
    weights,
  })

  const searchStartedAt = Date.now()
  let results = await retriever.invoke(searchParams.query)
  timings.search = Date.now() - searchStartedAt

  if (searchParams.rerank) {
    const rerankStartedAt = Date.now()
    const reranker = BM25Retriever.fromDocuments(results, { k: searchParams.rerankK })
    results = await reranker.invoke(searchParams.query)
    timings.rerank = Date.now() - rerankStartedAt
  }

  return {
    results: results.map((doc) => ({
      id: doc.id,
      content: doc.pageContent,
      // Metadata may come back as a serialized string; normalise it to an object.
      metadata: parseMetadata(doc.metadata),
    })),
    timings,
  }
}
2556

26-
return results.map((doc) => ({
27-
id: doc.id,
28-
content: doc.pageContent,
29-
metadata: doc.metadata,
30-
}))
57+
/**
 * Normalises document metadata to a plain object.
 *
 * Objects are returned untouched. Strings are treated as serialized JSON:
 * the raw string is parsed first, and only if that fails is it parsed again
 * with every backslash removed (presumably the storage layer backslash-escapes
 * the serialized metadata — TODO confirm the exact escaping rules).
 * Trying the raw string first keeps legitimate JSON escapes such as `\"`
 * intact; stripping them unconditionally turned valid metadata into `{}`.
 *
 * @param metadata - metadata object, or its (possibly escaped) JSON string.
 * @returns the metadata object, or `{}` when the string cannot be parsed
 *   into a non-array object.
 */
const parseMetadata = (metadata: Record<string, any> | string): Record<string, any> => {
  if (typeof metadata !== 'string') {
    return metadata
  }

  const unescaped = metadata.replace(/\\/g, '')
  let lastError: unknown

  for (const candidate of [metadata, unescaped]) {
    try {
      const parsed: unknown = JSON.parse(candidate)
      // Only an object map is usable as metadata; ignore scalars, null and arrays.
      if (typeof parsed === 'object' && parsed !== null && !Array.isArray(parsed)) {
        return parsed as Record<string, any>
      }
      lastError = new Error('Parsed metadata is not an object')
    } catch (error) {
      lastError = error
    }
  }

  console.error('Error parsing metadata:', unescaped, lastError)
  return {}
}

src/server/services/rag/searchTool.ts

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,19 @@ import { tool } from '@langchain/core/tools'
22
import { z } from 'zod/v4'
33
import { RagIndex } from '../../db/models'
44
import { search } from './search'
5+
import { SearchSchema } from '../../../shared/rag'
56

67
export const getRagIndexSearchTool = (ragIndex: RagIndex) =>
78
tool(
89
async ({ query }: { query: string }) => {
910
console.log('Search tool invoked with query:', query)
10-
const documents = await search(query, ragIndex)
11+
const { results: documents } = await search(ragIndex, SearchSchema.parse({ query }))
1112
// With responseFormat: content_and_artifact, return content and artifact like this:
1213
return [documents.map((doc) => doc.content).join('\n\n'), documents]
1314
},
1415
{
1516
name: `document_search`, // Gotcha: function name must match '^[a-zA-Z0-9_\.-]+$' at least in AzureOpenAI. This name must satisfy the name in ChatToolDef type
16-
description: `Search documents in the materials (titled '${ragIndex.metadata.name}')`,
17+
description: `Search documents in the materials (titled '${ragIndex.metadata.name}'). Prefer ${ragIndex.metadata.language}, which is the language used in the documents.`,
1718
schema: z.object({
1819
query: z.string().describe('the query to search for'),
1920
}),

src/shared/rag.ts

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,23 @@
1+
import z from 'zod/v4'
2+
13
/**
 * A single retrieved piece of a document, as returned by the search service.
 */
export type RagChunk = {
  /** Identifier of the chunk document, when one is available. */
  id?: string
  /** Text content of the chunk. */
  content: string
  /** Metadata attached to the chunk at ingestion time. */
  metadata: {
    /**
     * Name of the file the chunk came from. Optional because metadata that
     * fails to parse is replaced by an empty object, and consumers already
     * fall back to a placeholder when it is missing.
     */
    ragFileName?: string
    // eslint-disable-next-line @typescript-eslint/no-explicit-any -- remaining metadata keys are free-form
    [key: string]: any
  }
  score?: number
}
12+
13+
/**
 * Validation schema for a search request.
 *
 * Only `query` is required; every other field has a default, so `{ query }`
 * alone parses to a fully populated parameter object.
 */
export const SearchSchema = z.object({
  /** Search text, 1–1000 characters. */
  query: z.string().min(1).max(1000),
  /** Enable the full-text (keyword) retriever. */
  ft: z.boolean().default(true),
  /** Enable the vector (semantic) retriever. */
  vector: z.boolean().default(true),
  /** Number of documents requested from the vector retriever; a count, so integers only. */
  vectorK: z.number().int().min(1).max(20).default(8),
  /** Enable reranking of the retrieved documents. */
  rerank: z.boolean().default(true),
  /** Number of documents kept by the reranker; a count, so integers only. */
  rerankK: z.number().int().min(1).max(20).default(5),
})

/** Shape accepted as input to `SearchSchema` (fields with defaults are optional). */
export type SearchInputParams = z.input<typeof SearchSchema>
/** Shape produced by `SearchSchema.parse` (defaults applied, every field present). */
export type SearchParams = z.infer<typeof SearchSchema>

src/shared/types.ts

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,13 @@
1-
import type { ResponseFileSearchToolCall } from 'openai/resources/responses/responses'
2-
import type { VectorStoreFile } from 'openai/resources/vector-stores/files'
31
import type { IngestionPipelineStageKey } from './constants'
4-
import { ChatToolDef } from './tools'
2+
import type { ChatToolDef } from './tools'
53

64
/**
 * Metadata stored alongside a RAG index.
 */
export type RagIndexMetadata = {
  /** Human-readable name of the index. */
  name: string
  /** Language of the indexed documents, when known. */
  language?: 'Finnish' | 'English'
  // NOTE(review): purpose of `instructions` is not visible here — presumably prompt instructions; confirm.
  instructions?: string
}
149

1510
/**
 * Metadata stored alongside a single file of a RAG index.
 */
export type RagFileMetadata = {
  // NOTE(review): presumably the size the file occupies, in bytes — confirm against the writer.
  usageBytes?: number
}
2013

0 commit comments

Comments
 (0)