Skip to content

Commit 0b1ec8e

Browse files
committed
toolcall typing & ingestion improvements
1 parent 4206ef3 commit 0b1ec8e

File tree

7 files changed

+49
-114
lines changed

7 files changed

+49
-114
lines changed

src/client/components/ChatV2/Conversation.tsx

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import { t } from 'i18next'
1818
import FormatQuoteIcon from '@mui/icons-material/FormatQuote'
1919
import useLocalStorageState from '../../hooks/useLocalStorageState'
2020
import { BlueButton } from './general/Buttons'
21+
import { ToolCallStatusEvent } from '../../../shared/chat'
2122

2223
const UserMessage = ({ content, attachements }: { content: string; attachements?: string }) => (
2324
<Box
@@ -324,7 +325,7 @@ export const Conversation = ({
324325
expandedNodeHeight: number
325326
messages: Message[]
326327
completion: string
327-
toolCalls: { id: string; name?: string }[]
328+
toolCalls: { [callId: string]: ToolCallStatusEvent }
328329
isStreaming: boolean
329330
setActiveFileSearchResult: (data: FileSearchCompletedData) => void
330331
setShowFileSearchResults: (show: boolean) => void
@@ -371,7 +372,7 @@ export const Conversation = ({
371372
setShowFileSearchResults={setShowFileSearchResults}
372373
/>
373374
) : (
374-
<LoadingMessage expandedNodeHeight={expandedNodeHeight} isFileSearching={toolCalls.length > 0} />
375+
<LoadingMessage expandedNodeHeight={expandedNodeHeight} isFileSearching={Object.values(toolCalls).some((call) => !!call.result)} />
375376
))}
376377
</Box>
377378
{!reminderSeen && !isStreaming && messages.length > 15 && (

src/client/components/ChatV2/useChatStream.ts

Lines changed: 8 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,9 @@
11
import { useState } from 'react'
2-
import type { FileSearchCompletedData, ResponseStreamEventData } from '../../../shared/types'
2+
import type { FileSearchCompletedData } from '../../../shared/types'
33
import type { Message } from '../../types'
4-
import { ChatToolDef } from '../../../shared/tools'
4+
import { ChatEvent, ToolCallStatusEvent } from '../../../shared/chat'
55

6-
type ToolCallState = { id: string; name: ChatToolDef['name'] } & (
7-
| {
8-
status: 'starting'
9-
}
10-
| {
11-
input: ChatToolDef['input']
12-
status: 'started'
13-
}
14-
| {
15-
status: 'completed'
16-
input?: ChatToolDef['input'] // Have to allow optional because it is possible "started" state was not reached before completed.
17-
artifacts: ChatToolDef['artifacts']
18-
}
19-
)
6+
type ToolCallState = ToolCallStatusEvent
207

218
export const useChatStream = ({
229
onFileSearchComplete,
@@ -31,7 +18,7 @@ export const useChatStream = ({
3118
}) => {
3219
const [completion, setCompletion] = useState('')
3320
const [isStreaming, setIsStreaming] = useState(false)
34-
const [toolCalls, setToolCalls] = useState<ToolCallState[]>([])
21+
const [toolCalls, setToolCalls] = useState<{ [callId: string]: ToolCallState }>({})
3522
const [streamController, setStreamController] = useState<AbortController>()
3623

3724
const decoder = new TextDecoder()
@@ -55,7 +42,7 @@ export const useChatStream = ({
5542
for (const chunk of data.split('\n')) {
5643
if (!chunk || chunk.trim().length === 0) continue
5744

58-
let parsedChunk: ResponseStreamEventData | undefined
45+
let parsedChunk: ChatEvent | undefined
5946
try {
6047
parsedChunk = JSON.parse(chunk)
6148
} catch (e: any) {
@@ -81,29 +68,8 @@ export const useChatStream = ({
8168
content += parsedChunk.text
8269
break
8370

84-
case 'toolCallStarting':
85-
setToolCalls((prev) => [
86-
...prev,
87-
{
88-
id: parsedChunk.id,
89-
name: parsedChunk.name,
90-
status: 'starting',
91-
},
92-
])
93-
break
94-
95-
case 'toolCallStarted':
96-
setToolCalls((prev) =>
97-
prev.map((call) =>
98-
call.id === parsedChunk.id && call.status === 'starting' ? { ...call, status: 'started', input: parsedChunk.input } : call,
99-
),
100-
)
101-
break
102-
103-
case 'toolCallCompleted':
104-
setToolCalls((prev) =>
105-
prev.map((call) => (call.id === parsedChunk.id ? { ...call, status: 'completed', artifacts: parsedChunk.artifacts } : call)),
106-
)
71+
case 'toolCallStatus':
72+
setToolCalls((prev) => ({ ...prev, [parsedChunk.callId]: parsedChunk }))
10773
break
10874

10975
case 'error':
@@ -125,7 +91,7 @@ export const useChatStream = ({
12591
} finally {
12692
setStreamController(undefined)
12793
setCompletion('')
128-
setToolCalls([])
94+
setToolCalls({})
12995
setIsStreaming(false)
13096

13197
onComplete({

src/server/routes/ai/v3.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import express from 'express'
22
import { DEFAULT_TOKEN_LIMIT, FREE_MODEL, inProduction } from '../../../config'
33
import type { ChatMessage } from '../../../shared/llmTypes'
4-
import type { ResponseStreamEventData } from '../../../shared/types'
54
import { ChatInstance, Discussion, RagIndex, UserChatInstanceUsage } from '../../db/models'
65
import { calculateUsage, checkCourseUsage, checkUsage, incrementCourseUsage, incrementUsage } from '../../services/chatInstances/usage'
76
import { streamChat } from '../../services/langchain/chat'
@@ -15,6 +14,7 @@ import { upload } from './multer'
1514
import { PostStreamSchemaV3 } from './types'
1615
import { StructuredTool } from '@langchain/core/tools'
1716
import { getRagIndexSearchTool } from '../../services/rag/searchTool'
17+
import { ChatEvent } from '../../../shared/chat'
1818

1919
const router = express.Router()
2020

@@ -137,7 +137,7 @@ router.post('/stream', upload.single('file'), async (r, res) => {
137137
systemMessage: options.systemMessage,
138138
model: options.model,
139139
tools,
140-
writeEvent: async (event: ResponseStreamEventData) => {
140+
writeEvent: async (event: ChatEvent) => {
141141
await new Promise((resolve) => {
142142
const success = res.write(`${JSON.stringify(event)}\n`, (err) => {
143143
if (err) {

src/server/routes/rag/ragIndex.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ const upload = multer({
181181
bucket: S3_BUCKET,
182182
acl: 'private',
183183
metadata: (req, file, cb) => {
184-
cb(null, { fieldName: file.fieldName })
184+
cb(null, { fieldName: file.fieldname })
185185
},
186186
key: (req, file, cb) => {
187187
const { ragIndex } = req as RagIndexRequest

src/server/services/langchain/chat.ts

Lines changed: 25 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,14 @@
1-
import { AzureChatOpenAI } from '@langchain/openai'
2-
import { AZURE_API_KEY, AZURE_RESOURCE } from '../../util/config'
3-
import { validModels } from '../../../config'
4-
import { ChatMessage, Message } from '../../../shared/llmTypes'
5-
import { Response } from 'express'
6-
import { AIMessageChunk, BaseMessage, BaseMessageLike } from '@langchain/core/messages'
7-
import { IterableReadableStream } from '@langchain/core/utils/stream'
8-
import { ResponseStreamEventData } from '../../../shared/types'
9-
import { Tiktoken } from '@dqbd/tiktoken'
10-
import { FakeStreamingChatModel } from '@langchain/core/utils/testing'
11-
import { MockModel } from './MockModel'
1+
import { BaseChatModel } from '@langchain/core/language_models/chat_models'
2+
import { AIMessageChunk, BaseMessageLike } from '@langchain/core/messages'
123
import { StructuredTool } from '@langchain/core/tools'
134
import { concat } from '@langchain/core/utils/stream'
14-
import { BaseChatModel } from '@langchain/core/language_models/chat_models'
15-
import { Runnable } from '@langchain/core/runnables'
16-
import logger from '../../util/logger'
17-
import { BaseLanguageModelInput } from '@langchain/core/language_models/base'
18-
import { ToolCall } from '@langchain/core/messages/tool'
5+
import { AzureChatOpenAI } from '@langchain/openai'
6+
import { validModels } from '../../../config'
7+
import { ChatEvent } from '../../../shared/chat'
8+
import { ChatMessage } from '../../../shared/llmTypes'
199
import { ChatToolDef } from '../../../shared/tools'
10+
import { AZURE_API_KEY, AZURE_RESOURCE } from '../../util/config'
11+
import { MockModel } from './MockModel'
2012

2113
const getChatModel = (model: string, tools: StructuredTool[]): BaseChatModel => {
2214
const deploymentName = validModels.find((m) => m.name === model)?.deployment
@@ -37,7 +29,7 @@ const getChatModel = (model: string, tools: StructuredTool[]): BaseChatModel =>
3729
}).bindTools(tools) as BaseChatModel
3830
}
3931

40-
type WriteEventFunction = (data: ResponseStreamEventData) => Promise<void>
32+
type WriteEventFunction = (data: ChatEvent) => Promise<void>
4133

4234
type ChatTool = StructuredTool<any, any, any, string>
4335

@@ -115,9 +107,10 @@ const chatTurn = async (model: BaseChatModel, messages: BaseMessageLike[], tools
115107
status: 'pending',
116108
}
117109
await writeEvent({
118-
type: 'toolCallStarting',
119-
name: toolCall.name as ChatToolDef['name'],
120-
id,
110+
type: 'toolCallStatus',
111+
toolName: toolCall.name as ChatToolDef['name'],
112+
callId: id,
113+
text: 'Starting',
121114
})
122115
}
123116
}
@@ -143,21 +136,25 @@ const chatTurn = async (model: BaseChatModel, messages: BaseMessageLike[], tools
143136
const input = toolCall.args as ChatToolDef['input']
144137
if (id && tool) {
145138
await writeEvent({
146-
type: 'toolCallStarted',
147-
name,
148-
input,
149-
id,
139+
type: 'toolCallStatus',
140+
toolName: name,
141+
callId: id,
142+
text: `Searching for '${input.query}'`,
150143
})
151144
const result = await tool.invoke(toolCall)
152145
messages.push(result)
153146
toolCallStatuses[id] = {
154147
status: 'completed',
155148
}
156149
await writeEvent({
157-
type: 'toolCallCompleted',
158-
name,
159-
artifacts: result.artifact,
160-
id,
150+
type: 'toolCallStatus',
151+
toolName: name,
152+
callId: id,
153+
text: 'Completed search',
154+
result: {
155+
artifacts: result.artifact,
156+
input,
157+
},
161158
})
162159
}
163160
}

src/server/services/rag/ingestion.ts

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ export const ingestRagFiles = async (ragIndex: RagIndex) => {
2626

2727
const vectorStore = getRedisVectorStore(ragFiles[0].ragIndexId, ragIndex.metadata.language)
2828
const allDocuments: Document[] = []
29+
const allEmbeddings: number[][] = []
2930

3031
for (const ragFile of ragFiles) {
3132
console.time(`Ingestion ${ragFile.filename}`)
@@ -40,24 +41,27 @@ export const ingestRagFiles = async (ragIndex: RagIndex) => {
4041

4142
const chunkDocuments = await splitter.splitDocuments([document])
4243

43-
chunkDocuments.forEach((chunkDocument, idx) => {
44+
let idx = 0
45+
for (const chunkDocument of chunkDocuments) {
4446
chunkDocument.id = `ragIndex-${ragFile.ragIndexId}-${ragFile.filename}-${idx}`
4547
chunkDocument.metadata = {
4648
...chunkDocument.metadata,
4749
ragFileName: ragFile.filename,
4850
}
49-
})
51+
idx++
52+
}
53+
54+
const embeddings = await vectorStore.embeddings.embedDocuments(chunkDocuments.map((d) => d.pageContent))
5055

51-
// console.log(await redisClient.ft.info(vectorStore.indexName))
5256
allDocuments.push(...chunkDocuments)
57+
allEmbeddings.push(...embeddings)
5358

5459
console.timeEnd(`Ingestion ${ragFile.filename}`)
55-
// console.log(await redisClient.ft.info(vectorStore.indexName))
56-
//
60+
5761
ragFile.pipelineStage = 'completed'
5862
await ragFile.save()
5963
}
6064

6165
// @todo we can only call this once. How to handle new documents?
62-
await vectorStore.addDocuments(allDocuments)
66+
await vectorStore.addVectors(allEmbeddings, allDocuments)
6367
}

src/shared/types.ts

Lines changed: 0 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -52,39 +52,6 @@ export type FileSearchCompletedData = {
5252

5353
export type FileSearchResultData = Record<string, unknown>
5454

55-
export type ResponseStreamEventData =
56-
| {
57-
type: 'writing'
58-
text: string
59-
}
60-
| {
61-
type: 'complete'
62-
prevResponseId: string // Not yet sure what to put here in v3
63-
}
64-
| ({
65-
type: 'toolCallStarting'
66-
id: string
67-
} & Pick<ChatToolDef, 'name'>)
68-
| ({
69-
type: 'toolCallStarted'
70-
id: string
71-
} & Pick<ChatToolDef, 'name' | 'input'>)
72-
| ({
73-
type: 'toolCallCompleted'
74-
id: string
75-
} & Pick<ChatToolDef, 'name' | 'artifacts'>)
76-
| {
77-
type: 'error'
78-
error: string
79-
}
80-
81-
export interface CourseAssistant {
82-
course_id: string | null
83-
name: string
84-
assistant_instruction: string
85-
vector_store_id: string | null
86-
}
87-
8855
export type Locale = {
8956
fi?: string
9057
en?: string

0 commit comments

Comments
 (0)