Skip to content

Commit 0f5eaea

Browse files
committed
implement mock model tool calling so rag tests work again
1 parent e115ec7 commit 0f5eaea

File tree

7 files changed

+105
-33
lines changed

7 files changed

+105
-33
lines changed

e2e/courseChatRag.spec.ts

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ test.describe('Course Chat v2', () => {
2020
await expect(page.getByTestId('assistant-message')).toContainText('You are calling mock endpoint for streaming mock data')
2121
})
2222

23-
/* test('Course chat RAG feature', async ({ page }) => {
23+
test.only('Course chat RAG feature', async ({ page }) => {
2424
const ragName = `rag-${test.info().workerIndex}`
2525
await page.locator('#rag-index-selector').first().click()
2626
await page.getByRole('menuitem', { name: ragName }).click()
@@ -30,10 +30,10 @@ test.describe('Course Chat v2', () => {
3030
await chatInput.press('Shift+Enter')
3131

3232
// Shows file search loading indicator
33-
await expect(page.getByTestId('file-searching-message')).toBeVisible()
33+
await expect(page.getByTestId('tool-call-message')).toBeVisible()
3434

35-
// Responds with RAG mock text
36-
await expect(page.getByTestId('assistant-message')).toContainText('This is a mock response for file search stream.')
35+
// Responds with RAG mock document text
36+
await expect(page.getByTestId('assistant-message')).toContainText('This is the first mock document')
3737

3838
// Source button is visible
3939
await expect(page.getByTestId('file-search-sources')).toBeVisible()
@@ -43,5 +43,5 @@ test.describe('Course Chat v2', () => {
4343

4444
// Three source items should be visible
4545
await expect(page.getByTestId('sources-truncated-item')).toHaveCount(3)
46-
}) */
46+
})
4747
})

src/server/routes/ai/v3.ts

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import { PostStreamSchemaV3 } from './types'
1515
import { StructuredTool } from '@langchain/core/tools'
1616
import { getRagIndexSearchTool } from '../../services/rag/searchTool'
1717
import { ChatEvent } from '../../../shared/chat'
18+
import { getMockRagIndexSearchTool } from '../../services/rag/mockSearchTool'
1819

1920
const router = express.Router()
2021

@@ -126,7 +127,11 @@ router.post('/stream', upload.single('file'), async (r, res) => {
126127
return
127128
}
128129

129-
tools.push(getRagIndexSearchTool(ragIndex))
130+
const searchTool = model === 'mock' ? getMockRagIndexSearchTool(ragIndex) : getRagIndexSearchTool(ragIndex)
131+
132+
console.log('Tool given: ' + searchTool.name)
133+
134+
tools.push(searchTool)
130135
}
131136

132137
// Prepare for streaming response

src/server/routes/testUtils.ts

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,9 @@ import { inProduction } from '../../config'
33
import { getTestUserHeaders, TEST_COURSES } from '../../shared/testData'
44
import { ChatInstanceRagIndex, Enrolment, Prompt, RagIndex, User, UserChatInstanceUsage } from '../db/models'
55
import { headersToUser } from '../middleware/user'
6-
import type { RequestWithUser } from '../types'
76
import { ApplicationError } from '../util/ApplicationError'
87
import { getCompletionEvents } from '../util/azure/client'
98
import logger from '../util/logger'
10-
import getEncoding from '../util/tiktoken'
119

1210
const router = Router()
1311

[file path heading missing from capture — likely src/server/services/langchain/MockModel.ts]

Lines changed: 36 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
1-
import { CallbackManagerForLLMRun } from '@langchain/core/callbacks/manager'
2-
import { isSystemMessage } from '@langchain/core/messages'
3-
import { BaseMessage } from '@langchain/core/messages'
4-
import { ChatGenerationChunk, ChatResult } from '@langchain/core/outputs'
1+
import type { CallbackManagerForLLMRun } from '@langchain/core/callbacks/manager'
2+
import { AIMessage, AIMessageChunk, type BaseMessage, isHumanMessage, isSystemMessage, isToolMessage } from '@langchain/core/messages'
3+
import type { ChatGenerationChunk, ChatResult } from '@langchain/core/outputs'
54
import { FakeStreamingChatModel } from '@langchain/core/utils/testing'
65
import { basicTestContent } from '../../util/azure/mocks/mockContent'
76

@@ -17,13 +16,26 @@ export class MockModel extends FakeStreamingChatModel {
1716
})
1817
}
1918

20-
async _generate(messages: BaseMessage[], _options: this['ParsedCallOptions'], _runManager?: CallbackManagerForLLMRun): Promise<ChatResult> {
21-
const firstMessage = messages[0]
22-
if (isSystemMessage(firstMessage) && (firstMessage.content as string).startsWith('mocktest')) {
19+
setupTestResponse(messages: BaseMessage[]) {
20+
const firstSystemMessage = messages.find(isSystemMessage)
21+
const lastHumanMessage = messages.findLast(isHumanMessage)
22+
const toolMessage = isToolMessage(messages[messages.length - 1]) ? messages[messages.length - 1] : null
23+
24+
if (toolMessage) {
25+
this.chunks = [new AIMessageChunk(`Ok! Got some great results from that mock tool call!: "${toolMessage.content}"`)]
26+
} else if (firstSystemMessage && (firstSystemMessage.content as string).startsWith('mocktest')) {
27+
// testing a system message
2328
// Do nothing. FakeStreamingChatModel echoes the first message.
29+
} else if (((lastHumanMessage?.content ?? '') as string).startsWith('rag')) {
30+
// Do a tool call
31+
this.chunks = toolCallChunks
2432
} else {
25-
firstMessage.content = basicTestContent
33+
this.responses = defaultResponse
2634
}
35+
}
36+
37+
async _generate(messages: BaseMessage[], _options: this['ParsedCallOptions'], _runManager?: CallbackManagerForLLMRun): Promise<ChatResult> {
38+
this.setupTestResponse(messages)
2739
return super._generate(messages, _options, _runManager)
2840
}
2941

@@ -32,12 +44,22 @@ export class MockModel extends FakeStreamingChatModel {
3244
_options: this['ParsedCallOptions'],
3345
runManager?: CallbackManagerForLLMRun,
3446
): AsyncGenerator<ChatGenerationChunk> {
35-
const firstMessage = messages[0]
36-
if (isSystemMessage(firstMessage) && (firstMessage.content as string).startsWith('mocktest')) {
37-
// Do nothing. FakeStreamingChatModel echoes the first message.
38-
} else {
39-
firstMessage.content = basicTestContent
40-
}
47+
this.setupTestResponse(messages)
4148
yield* super._streamResponseChunks(messages, _options, runManager)
4249
}
4350
}
51+
52+
const defaultResponse = [new AIMessage(basicTestContent)]
53+
54+
const toolCallChunks = [
55+
new AIMessageChunk({
56+
content: '',
57+
tool_call_chunks: [
58+
{
59+
name: 'mock_document_search',
60+
args: JSON.stringify({ query: 'mock test query' }),
61+
id: 'mock_document_search_id',
62+
},
63+
],
64+
}),
65+
]

src/server/services/langchain/chat.ts

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,20 @@ const getChatModel = (model: string, tools: StructuredTool[]): BaseChatModel =>
1818
throw new Error(`Invalid model: ${model}`)
1919
}
2020

21-
if (deploymentName === 'mock') {
22-
return new MockModel()
23-
}
21+
const chatModel =
22+
deploymentName === 'mock'
23+
? new MockModel()
24+
: new AzureChatOpenAI({
25+
model,
26+
azureOpenAIApiKey: AZURE_API_KEY,
27+
azureOpenAIApiVersion: '2023-05-15',
28+
azureOpenAIApiDeploymentName: deploymentName,
29+
azureOpenAIApiInstanceName: AZURE_RESOURCE,
30+
})
31+
32+
chatModel.bindTools(tools)
2433

25-
return new AzureChatOpenAI({
26-
model,
27-
azureOpenAIApiKey: AZURE_API_KEY,
28-
azureOpenAIApiVersion: '2023-05-15',
29-
azureOpenAIApiDeploymentName: deploymentName,
30-
azureOpenAIApiInstanceName: AZURE_RESOURCE,
31-
}).bindTools(tools) as BaseChatModel
34+
return chatModel
3235
}
3336

3437
type WriteEventFunction = (data: ChatEvent) => Promise<void>
src/server/services/rag/mockSearchTool.ts

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import { tool } from '@langchain/core/tools'
2+
import { z } from 'zod/v4'
3+
import type { RagIndex } from '../../db/models'
4+
import { SearchSchema } from '../../../shared/rag'
5+
import { Document } from '@langchain/core/documents'
6+
import type { search } from './search'
7+
import type { getRagIndexSearchTool } from './searchTool'
8+
9+
const mockDocuments = [
10+
new Document({ pageContent: 'This is the first mock document.', metadata: { ragFileName: 'mock_document1.pdf' } }),
11+
new Document({ pageContent: 'This is the second mock document.', metadata: { ragFileName: 'mock_document2.pdf' } }),
12+
new Document({ pageContent: 'This is the third mock document.', metadata: { ragFileName: 'mock_document3.pdf' } }),
13+
]
14+
15+
const mockSearch: typeof search = async (_index: RagIndex, _params: { query: string }) => {
16+
await new Promise((resolve) => setTimeout(resolve, 300))
17+
18+
return {
19+
results: mockDocuments.map((doc) => ({
20+
id: doc.id,
21+
content: doc.pageContent,
22+
metadata: doc.metadata,
23+
})),
24+
timings: { search: 1000 },
25+
}
26+
}
27+
28+
export const getMockRagIndexSearchTool: typeof getRagIndexSearchTool = (ragIndex: RagIndex) =>
29+
tool(
30+
async ({ query }: { query: string }) => {
31+
console.log('Mock search tool invoked with query:', query)
32+
const { results: documents } = await mockSearch(ragIndex, SearchSchema.parse({ query }))
33+
// With responseFormat: content_and_artifact, return content and artifact like this:
34+
return [documents.map((doc) => doc.content).join('\n\n'), documents]
35+
},
36+
{
37+
name: `mock_document_search`, // Gotcha: function name must match '^[a-zA-Z0-9_\.-]+$' at least in AzureOpenAI. This name must satisfy the name in ChatToolDef type
38+
description: `Search documents in the materials (titled '${ragIndex.metadata.name}'). Prefer ${ragIndex.metadata.language}, which is the language used in the documents.`,
39+
schema: z.object({
40+
query: z.string().describe('the query to search for'),
41+
}),
42+
responseFormat: 'content_and_artifact',
43+
},
44+
)

src/shared/tools.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import type { RagChunk } from './rag'
22

33
export type ChatToolDef = {
4-
name: 'document_search'
4+
name: 'document_search' | 'mock_document_search'
55
input: { query: string }
66
result: { files: { fileName: string; score?: number }[] }
77
output: RagChunk[]

0 commit comments

Comments
 (0)