Skip to content

Commit 3fb8d2f

Browse files
committed
Make ingestion work with local embeddings, plus assorted improvements
1 parent cebc3b8 commit 3fb8d2f

File tree

14 files changed

+196
-118
lines changed

14 files changed

+196
-118
lines changed
Lines changed: 122 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,65 +1,97 @@
1-
import { Box, Typography } from '@mui/material'
1+
import { Box, Table, TableBody, TableCell, TableHead, TableRow, Typography } from '@mui/material'
2+
import { orderBy } from 'lodash'
23
import { useEffect, useReducer } from 'react'
4+
import { IngestionPipelineStageKey, IngestionPipelineStageKeys, IngestionPipelineStages } from '../../../shared/constants'
35

46
type ProgressEvent = {
57
stage: string
6-
item?: string
8+
items?: string[]
79
done?: boolean
810
error?: string
911
}
1012

11-
type ProgressState = {
12-
[fileName: string]: {
13-
[stage: string]: {
14-
count: number
15-
done: boolean
13+
type ProgressState = Record<
14+
IngestionPipelineStageKey,
15+
{
16+
count: number
17+
done: boolean
18+
error: boolean
19+
files: {
20+
[fileName: string]: {
21+
count: number
22+
done: boolean
23+
error: boolean
24+
}
1625
}
1726
}
18-
}
27+
>
28+
29+
const getInitialProgressState = () =>
30+
Object.fromEntries(
31+
IngestionPipelineStageKeys.map((stage) => [
32+
stage,
33+
{
34+
count: 0,
35+
done: false,
36+
error: false,
37+
files: {},
38+
},
39+
]),
40+
) as ProgressState
1941

2042
type Action = { type: 'UPDATE'; payload: ProgressEvent } | { type: 'RESET' }
2143

2244
const progressReducer = (state: ProgressState, action: Action): ProgressState => {
2345
switch (action.type) {
2446
case 'UPDATE': {
25-
const { stage, item, done } = action.payload
47+
const { stage, items, done, error } = action.payload
48+
2649
if (done) {
27-
// mark all items at this stage as done
28-
const newState = { ...state }
29-
for (const file in newState) {
30-
if (newState[file][stage]) {
31-
newState[file][stage].done = true
32-
}
50+
return {
51+
...state,
52+
[stage]: {
53+
...state[stage],
54+
done: true,
55+
error: !!error || false,
56+
},
3357
}
34-
return newState
3558
}
36-
if (!item) return state
3759

38-
const fileStages = state[item] || {}
39-
const stageData = fileStages[stage] || { count: 0, done: false }
60+
if (!items) return state
61+
62+
const updatedFiles = items.reduce(
63+
(acc, fileName) => {
64+
if (!acc[fileName]) {
65+
acc[fileName] = { count: 0, done: done || false, error: !!error || false }
66+
}
67+
acc[fileName].count += 1
68+
return acc
69+
},
70+
{ ...state[stage].files },
71+
)
4072

4173
return {
4274
...state,
43-
[item]: {
44-
...fileStages,
45-
[stage]: {
46-
count: stageData.count + 1,
47-
done: stageData.done || false,
48-
},
75+
[stage]: {
76+
...state[stage],
77+
done: done || state[stage]?.done,
78+
error: !!error || state[stage]?.error,
79+
count: state[stage]?.count + items.length,
80+
files: updatedFiles,
4981
},
5082
}
5183
}
5284

5385
case 'RESET':
54-
return {}
86+
return getInitialProgressState()
5587

5688
default:
5789
return state
5890
}
5991
}
6092

61-
export const ProgressReporter: React.FC<{ stream: ReadableStream | null }> = ({ stream }) => {
62-
const [progress, dispatch] = useReducer(progressReducer, {})
93+
export const ProgressReporter: React.FC<{ filenames: string[]; stream: ReadableStream | null; onError: () => void }> = ({ filenames, stream, onError }) => {
94+
const [progress, dispatch] = useReducer(progressReducer, getInitialProgressState())
6395

6496
useEffect(() => {
6597
let reader: ReadableStreamDefaultReader<Uint8Array> | null = null
@@ -69,6 +101,8 @@ export const ProgressReporter: React.FC<{ stream: ReadableStream | null }> = ({
69101
reader = stream.getReader()
70102
} catch (error) {
71103
console.error('Error getting reader from stream:', error)
104+
dispatch({ type: 'RESET' })
105+
onError()
72106
return
73107
}
74108

@@ -90,36 +124,79 @@ export const ProgressReporter: React.FC<{ stream: ReadableStream | null }> = ({
90124
dispatch({ type: 'UPDATE', payload: jsonChunk })
91125
} catch (err) {
92126
console.error('Invalid chunk:', line, err)
127+
onError()
93128
}
94129
}
95130
}
96131
}
97132
}
98133

99134
readStream()
135+
} else {
136+
dispatch({ type: 'RESET' })
137+
}
100138

101-
return () => {
102-
if (reader) {
103-
reader.releaseLock()
104-
}
139+
return () => {
140+
if (reader) {
141+
reader.releaseLock()
142+
dispatch({ type: 'RESET' })
105143
}
106144
}
107145
}, [stream])
108146

147+
console.log('Progress state:', progress)
148+
109149
return (
110-
<Box>
111-
{Object.entries(progress).map(([file, stages]) => (
112-
<div key={file} className="border p-2 rounded shadow">
113-
<h3 className="font-bold text-lg">{file}</h3>
114-
<ul className="ml-4 list-disc">
115-
{Object.entries(stages).map(([stage, { count, done }]) => (
116-
<li key={stage}>
117-
<span className="font-medium">{stage}:</span> {done ? "✅ Done" : `🔄 Processing (${count}x)`}
118-
</li>
150+
<Table size="small">
151+
<TableHead>
152+
<TableRow>
153+
<TableCell>File</TableCell>
154+
{IngestionPipelineStageKeys.map((stage) => (
155+
<TableCell key={stage}>
156+
<Typography variant="body2">{IngestionPipelineStages[stage].name}</Typography>
157+
<Typography variant="caption" color="textSecondary">
158+
{progress[stage].count}{' '}
159+
</Typography>
160+
<Typography variant="caption" color="textSecondary">
161+
{progress[stage].done ? 'Done' : progress[stage].error ? 'Error' : 'In Progress'}
162+
</Typography>
163+
</TableCell>
164+
))}
165+
</TableRow>
166+
</TableHead>
167+
<TableBody>
168+
{filenames.map((filename) => (
169+
<TableRow key={filename}>
170+
<TableCell component="th" scope="row">
171+
{filename}
172+
</TableCell>
173+
{IngestionPipelineStageKeys.map((stage) => (
174+
<TableCell
175+
key={stage}
176+
sx={{
177+
transition: 'background-color 0.3s',
178+
backgroundColor: progress[stage]?.error
179+
? 'error.light'
180+
: progress[stage]?.done
181+
? 'success.light'
182+
: progress[stage]?.files[filename]?.error
183+
? 'error.light'
184+
: progress[stage]?.files[filename]?.count
185+
? 'info.light'
186+
: 'inherit',
187+
}}
188+
>
189+
<Box display="flex" gap={2}>
190+
<Typography variant="body2">{progress[stage].files[filename]?.count > 1 || ''}</Typography>
191+
<Typography variant="caption" color="textSecondary">
192+
{progress[stage]?.files[filename]?.done ? 'Done' : progress[stage].files[filename]?.error ? 'Error' : ''}
193+
</Typography>
194+
</Box>
195+
</TableCell>
119196
))}
120-
</ul>
121-
</div>
122-
))}
123-
</Box>
197+
</TableRow>
198+
))}
199+
</TableBody>
200+
</Table>
124201
)
125202
}

src/client/components/Rag/Rag.tsx

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ const Rag: React.FC = () => {
106106
const uploadMutation = useUploadMutation(selectedIndex)
107107
const [modalOpen, setModalOpen] = useState(false)
108108
const [stream, setStream] = useState<ReadableStream | null>(null)
109+
const [filenames, setFilenames] = useState<string[]>([])
109110

110111
const handleSubmit = async (event: React.FormEvent) => {
111112
event.preventDefault()
@@ -128,9 +129,13 @@ const Rag: React.FC = () => {
128129
setInputValue('')
129130
}
130131

132+
const handleUploadError = () => {
133+
setStream(null)
134+
}
135+
131136
return (
132137
<Box sx={{ display: 'flex', gap: 2 }}>
133-
<Dialog open={!!selectedIndex && modalOpen} onClose={() => setModalOpen(false)}>
138+
<Dialog open={!!selectedIndex && modalOpen} onClose={() => setModalOpen(false)} fullWidth maxWidth="md">
134139
<DialogTitle>Edit {selectedIndex?.metadata?.name}</DialogTitle>
135140
<Box sx={{ padding: 2 }}>
136141
<Box sx={{ display: 'flex', gap: 2 }}>
@@ -143,6 +148,7 @@ const Rag: React.FC = () => {
143148
console.log('Files selected:', files)
144149
if (files && files.length > 0) {
145150
const stream = await uploadMutation.mutateAsync(files)
151+
setFilenames(Array.from(files).map((file) => file.name))
146152
setStream(stream)
147153
}
148154
}}
@@ -164,7 +170,7 @@ const Rag: React.FC = () => {
164170
</Button>
165171
</Box>
166172
<Box mt={2}>
167-
<ProgressReporter stream={stream} />
173+
<ProgressReporter filenames={filenames} stream={stream} onError={handleUploadError} />
168174
</Box>
169175
</Box>
170176
</Dialog>

src/config.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ export const PUBLIC_URL = process.env.PUBLIC_URL || ''
1212

1313
export const UPDATER_CRON_ENABLED = process.env.UPDATER_CRON_ENABLED === 'true'
1414

15-
export const OLLAMA_HOST = process.env.OLLAMA_HOST || 'http://ollama:11434'
15+
// Base URL used to reach Ollama. Despite the constant's name, the override is
// still read from the OLLAMA_HOST environment variable.
// NOTE(review): an OLLAMA_HOST override is used verbatim, so it must already
// include the `/v1/` path that the built-in default carries — confirm that
// existing deployments set it that way.
export const OLLAMA_URL = process.env.OLLAMA_HOST || 'http://ollama:11434/v1/'
1616
export const RAG_ENABLED = process.env.RAG_ENABLED === 'true'
1717

1818
export const DEFAULT_TOKEN_LIMIT = Number(process.env.DEFAULT_TOKEN_LIMIT) || 150_000
@@ -24,7 +24,7 @@ export const DEFAUL_CONTEXT_LIMIT = Number(process.env.DEFAUL_CONTEXT_LIMIT) ||
2424
export const DEFAULT_RESET_CRON = process.env.DEFAULT_RESET_CRON || '0 0 1 */3 *'
2525

2626
export const EMBED_MODEL = process.env.EMBED_MODEL ?? 'text-embedding-small'
27-
export const EMBED_DIM = 1024
27+
/**
 * Dimension of the embedding vectors, overridable through the EMBED_DIM
 * environment variable.
 *
 * Falls back to 1024 when the variable is unset, empty, or not numeric. The
 * previous truthiness check let a non-numeric value through as NaN. This uses
 * the same `Number(env) || default` form as DEFAULT_TOKEN_LIMIT.
 */
export const EMBED_DIM = Number(process.env.EMBED_DIM) || 1024
2828

2929
export const validModels = [
3030
{

src/server/routes/rag.ts

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,15 @@ import { getAzureOpenAIClient } from '../util/azure'
1010
import multer from 'multer'
1111
import { mkdir, rm, stat } from 'fs/promises'
1212
import { Readable } from 'stream'
13+
import { getOllamaOpenAIClient } from '../util/ollama'
1314

1415
const router = Router()
1516

1617
const UPLOAD_DIR = 'uploads/rag'
1718

1819
const IndexCreationSchema = z.object({
1920
name: z.string().min(1).max(100),
20-
dim: z.number().min(1024).max(1024).default(EMBED_DIM),
21+
dim: z.number().min(EMBED_DIM).max(EMBED_DIM).default(EMBED_DIM),
2122
})
2223

2324
router.post('/indices', async (req, res) => {
@@ -136,7 +137,7 @@ router.post('/indices/:id/upload', [indexUploadDirMiddleware, uploadMiddleware],
136137
res.setHeader('Content-Type', 'application/x-ndjson')
137138
res.setHeader('Transfer-Encoding', 'chunked')
138139

139-
const openAiClient = getAzureOpenAIClient(EMBED_MODEL)
140+
const openAiClient = getOllamaOpenAIClient() // getAzureOpenAIClient(EMBED_MODEL)
140141

141142
const progressReporter = await ingestionPipeline(openAiClient, `uploads/rag/${id}`, ragIndex)
142143

src/server/services/rag/ingestion/chunker.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ export class Chunker extends Transform {
3535
}),
3636
)
3737

38-
this.progressReporter.reportProgress(data.fileName)
38+
this.progressReporter.reportProgress([data.fileName])
3939

4040
callback()
4141
}

src/server/services/rag/ingestion/embedder.ts

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ export class Embedder extends Transform {
4747
}
4848

4949
private async embedBatch(callback: (error?: Error | null) => void) {
50+
const currentBatchFilenames = this.currentBatch.map((chunk) => chunk.metadata.filename)
5051
try {
5152
const chunkContents = this.currentBatch.map((chunk) => chunk.content.join('\n'))
5253
const startedAt = Date.now()
@@ -65,15 +66,16 @@ export class Embedder extends Transform {
6566
// Save embedded chunk to cache
6667
const path = `${this.cachePath}/${embeddedChunk.id}.json`
6768
await writeFile(path, JSON.stringify(embeddedChunk, null, 2), 'utf-8')
68-
this.progressReporter.reportProgress(embeddedChunk.metadata.filename)
6969
})
70+
this.progressReporter.reportProgress(currentBatchFilenames)
7071

7172
// Reset the current batch
7273
this.currentBatch = []
7374

7475
callback()
7576
} catch (error) {
76-
console.error(`Error saving chunk to cache: ${error}`)
77+
console.error(`Embedding stage ${error}`)
78+
this.progressReporter.reportError('Embedding chunk failed', currentBatchFilenames)
7779
callback(error as Error)
7880
}
7981
}

src/server/services/rag/ingestion/loader.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ export class FileLoader extends Readable {
6363
this.progressReporter.reportDone()
6464
} else {
6565
this.push(file.value)
66-
this.progressReporter.reportProgress(file.value.fileName)
66+
this.progressReporter.reportProgress([file.value.fileName])
6767
}
6868
})
6969
}

src/server/services/rag/ingestion/pipeline.ts

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@ import { RedisStorer } from './storer.ts'
77
import type OpenAI from 'openai'
88
import RagIndex from '../../../db/models/ragIndex.ts'
99
import { TextExtractor } from './textExtractor.ts'
10-
import { Readable } from 'node:stream'
1110
import { ProgressReporter } from './progressReporter.ts'
11+
import { IngestionPipelineStageKeys } from '../../../../shared/constants.ts'
1212

1313
// Pipeline debug cache in pipeline/
1414
// Check if exists, if not create it.
@@ -37,8 +37,8 @@ export const ingestionPipeline = async (client: OpenAI, loadpath: string, ragInd
3737
new RedisStorer(ragIndex),
3838
]
3939

40-
stages.forEach((stage) => {
41-
stage.progressReporter = progressReporter.getStageReporter(stage.constructor.name)
40+
stages.forEach((stage, idx) => {
41+
stage.progressReporter = progressReporter.getStageReporter(IngestionPipelineStageKeys[idx], idx)
4242
})
4343

4444
pipeline(stages)
@@ -47,8 +47,8 @@ export const ingestionPipeline = async (client: OpenAI, loadpath: string, ragInd
4747
progressReporter.emit('end')
4848
})
4949
.catch((error) => {
50-
console.error('Pipeline error:', error)
51-
progressReporter.emit('error', 'Pipeline error')
50+
console.error('Unhandled pipeline error:', error)
51+
progressReporter.reportError(error)
5252
})
5353

5454
return progressReporter

0 commit comments

Comments
 (0)