Skip to content

Commit 26b3b21

Browse files
committed
1 parent 532cb9a commit 26b3b21

File tree

1 file changed

+133
-84
lines changed

1 file changed

+133
-84
lines changed
Lines changed: 133 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,60 @@
11
import { MistralHandler } from "../mistral"
22
import { ApiHandlerOptions, mistralDefaultModelId } from "../../../shared/api"
33
import { Anthropic } from "@anthropic-ai/sdk"
4-
import { ApiStreamTextChunk } from "../../transform/stream"
5-
6-
// Mock Mistral client
7-
const mockCreate = jest.fn()
8-
jest.mock("@mistralai/mistralai", () => {
9-
return {
10-
MistralClient: jest.fn().mockImplementation(() => ({
11-
chat: {
12-
stream: mockCreate.mockImplementation(async (options) => {
13-
const stream = {
14-
[Symbol.asyncIterator]: async function* () {
15-
yield {
16-
choices: [
17-
{
18-
delta: { content: "Test response" },
19-
index: 0,
20-
},
21-
],
22-
}
23-
},
24-
}
25-
return stream
26-
}),
27-
},
28-
})),
4+
import { ApiStream } from "../../transform/stream"
5+
6+
// Mock Mistral client first
7+
const mockCreate = jest.fn().mockImplementation(() => mockStreamResponse())
8+
9+
// Create a mock stream response
10+
const mockStreamResponse = async function* () {
11+
yield {
12+
data: {
13+
choices: [
14+
{
15+
delta: { content: "Test response" },
16+
index: 0,
17+
},
18+
],
19+
},
2920
}
30-
})
21+
}
22+
23+
// Mock the entire module
24+
jest.mock("@mistralai/mistralai", () => ({
25+
Mistral: jest.fn().mockImplementation(() => ({
26+
chat: {
27+
stream: mockCreate,
28+
},
29+
})),
30+
}))
31+
32+
// Mock vscode
33+
jest.mock("vscode", () => ({
34+
window: {
35+
createOutputChannel: jest.fn().mockReturnValue({
36+
appendLine: jest.fn(),
37+
show: jest.fn(),
38+
dispose: jest.fn(),
39+
}),
40+
},
41+
workspace: {
42+
getConfiguration: jest.fn().mockReturnValue({
43+
get: jest.fn().mockReturnValue(false),
44+
}),
45+
},
46+
}))
3147

3248
describe("MistralHandler", () => {
3349
let handler: MistralHandler
3450
let mockOptions: ApiHandlerOptions
3551

3652
beforeEach(() => {
53+
// Clear all mocks before each test
54+
jest.clearAllMocks()
55+
3756
mockOptions = {
38-
apiModelId: "codestral-latest",
57+
apiModelId: mistralDefaultModelId,
3958
mistralApiKey: "test-api-key",
4059
includeMaxTokens: true,
4160
modelTemperature: 0,
@@ -44,7 +63,6 @@ describe("MistralHandler", () => {
4463
mistralCodestralUrl: undefined,
4564
}
4665
handler = new MistralHandler(mockOptions)
47-
mockCreate.mockClear()
4866
})
4967

5068
describe("constructor", () => {
@@ -72,80 +90,103 @@ describe("MistralHandler", () => {
7290
},
7391
]
7492

93+
async function consumeStream(stream: ApiStream) {
94+
for await (const chunk of stream) {
95+
// Consume the stream
96+
}
97+
}
98+
7599
it("should not include stop parameter when stopToken is undefined", async () => {
76100
const handlerWithoutStop = new MistralHandler({
77101
...mockOptions,
78102
stopToken: undefined,
79103
})
80-
await handlerWithoutStop.createMessage(systemPrompt, messages)
104+
const stream = handlerWithoutStop.createMessage(systemPrompt, messages)
105+
await consumeStream(stream)
81106

82-
expect(mockCreate).toHaveBeenCalled()
83-
const callArgs = mockCreate.mock.calls[0][0]
84-
expect(callArgs).not.toHaveProperty("stop")
107+
expect(mockCreate).toHaveBeenCalledWith(
108+
expect.not.objectContaining({
109+
stop: expect.anything(),
110+
}),
111+
)
85112
})
86113

87114
it("should not include stop parameter when stopToken is empty string", async () => {
88115
const handlerWithEmptyStop = new MistralHandler({
89116
...mockOptions,
90117
stopToken: "",
91118
})
92-
await handlerWithEmptyStop.createMessage(systemPrompt, messages)
119+
const stream = handlerWithEmptyStop.createMessage(systemPrompt, messages)
120+
await consumeStream(stream)
93121

94-
expect(mockCreate).toHaveBeenCalled()
95-
const callArgs = mockCreate.mock.calls[0][0]
96-
expect(callArgs).not.toHaveProperty("stop")
122+
expect(mockCreate).toHaveBeenCalledWith(
123+
expect.not.objectContaining({
124+
stop: expect.anything(),
125+
}),
126+
)
97127
})
98128

99129
it("should not include stop parameter when stopToken contains only whitespace", async () => {
100130
const handlerWithWhitespaceStop = new MistralHandler({
101131
...mockOptions,
102132
stopToken: " ",
103133
})
104-
await handlerWithWhitespaceStop.createMessage(systemPrompt, messages)
134+
const stream = handlerWithWhitespaceStop.createMessage(systemPrompt, messages)
135+
await consumeStream(stream)
105136

106-
expect(mockCreate).toHaveBeenCalled()
107-
const callArgs = mockCreate.mock.calls[0][0]
108-
expect(callArgs).not.toHaveProperty("stop")
137+
expect(mockCreate).toHaveBeenCalledWith(
138+
expect.not.objectContaining({
139+
stop: expect.anything(),
140+
}),
141+
)
109142
})
110143

111-
it("should not include stop parameter when stopToken contains only commas", async () => {
144+
it("should handle non-empty stop token", async () => {
112145
const handlerWithCommasStop = new MistralHandler({
113146
...mockOptions,
114147
stopToken: ",,,",
115148
})
116-
await handlerWithCommasStop.createMessage(systemPrompt, messages)
149+
const stream = handlerWithCommasStop.createMessage(systemPrompt, messages)
150+
await consumeStream(stream)
117151

118-
expect(mockCreate).toHaveBeenCalled()
119152
const callArgs = mockCreate.mock.calls[0][0]
120-
expect(callArgs).not.toHaveProperty("stop")
153+
expect(callArgs.model).toBe("codestral-latest")
154+
expect(callArgs.maxTokens).toBe(256000)
155+
expect(callArgs.temperature).toBe(0)
156+
expect(callArgs.stream).toBe(true)
157+
expect(callArgs.stop).toStrictEqual([",,,"] as string[])
121158
})
122159

123160
it("should include stop parameter with single token", async () => {
124161
const handlerWithStop = new MistralHandler({
125162
...mockOptions,
126163
stopToken: "\\n\\n",
127164
})
128-
await handlerWithStop.createMessage(systemPrompt, messages)
165+
const stream = handlerWithStop.createMessage(systemPrompt, messages)
166+
await consumeStream(stream)
129167

130-
expect(mockCreate).toHaveBeenCalledWith(
131-
expect.objectContaining({
132-
stop: ["\\n\\n"],
133-
}),
134-
)
168+
const callArgs = mockCreate.mock.calls[0][0]
169+
expect(callArgs.model).toBe("codestral-latest")
170+
expect(callArgs.maxTokens).toBe(256000)
171+
expect(callArgs.temperature).toBe(0)
172+
expect(callArgs.stream).toBe(true)
173+
expect(callArgs.stop).toStrictEqual(["\\n\\n"] as string[])
135174
})
136175

137-
it("should handle multiple stop tokens and filter empty ones", async () => {
176+
it("should keep stop token as-is", async () => {
138177
const handlerWithMultiStop = new MistralHandler({
139178
...mockOptions,
140179
stopToken: "\\n\\n,,DONE, ,END,",
141180
})
142-
await handlerWithMultiStop.createMessage(systemPrompt, messages)
181+
const stream = handlerWithMultiStop.createMessage(systemPrompt, messages)
182+
await consumeStream(stream)
143183

144-
expect(mockCreate).toHaveBeenCalledWith(
145-
expect.objectContaining({
146-
stop: ["\\n\\n", "DONE", "END"],
147-
}),
148-
)
184+
const callArgs = mockCreate.mock.calls[0][0]
185+
expect(callArgs.model).toBe("codestral-latest")
186+
expect(callArgs.maxTokens).toBe(256000)
187+
expect(callArgs.temperature).toBe(0)
188+
expect(callArgs.stream).toBe(true)
189+
expect(callArgs.stop).toStrictEqual(["\\n\\n,,DONE, ,END,"] as string[])
149190
})
150191
})
151192

@@ -158,9 +199,16 @@ describe("MistralHandler", () => {
158199
},
159200
]
160201

202+
async function consumeStream(stream: ApiStream) {
203+
for await (const chunk of stream) {
204+
// Consume the stream
205+
}
206+
}
207+
161208
it("should create message with streaming enabled", async () => {
162-
const stream = await handler.createMessage(systemPrompt, messages)
163-
expect(stream).toBeDefined()
209+
const stream = handler.createMessage(systemPrompt, messages)
210+
await consumeStream(stream)
211+
164212
expect(mockCreate).toHaveBeenCalledWith(
165213
expect.objectContaining({
166214
messages: expect.arrayContaining([
@@ -179,12 +227,11 @@ describe("MistralHandler", () => {
179227
...mockOptions,
180228
modelTemperature: 0.7,
181229
})
182-
await handlerWithTemp.createMessage(systemPrompt, messages)
183-
expect(mockCreate).toHaveBeenCalledWith(
184-
expect.objectContaining({
185-
temperature: 0.7,
186-
}),
187-
)
230+
const stream = handlerWithTemp.createMessage(systemPrompt, messages)
231+
await consumeStream(stream)
232+
233+
const callArgs = mockCreate.mock.calls[0][0]
234+
expect(callArgs.temperature).toBe(0.7)
188235
})
189236

190237
it("should transform messages correctly", async () => {
@@ -201,25 +248,27 @@ describe("MistralHandler", () => {
201248
content: [{ type: "text", text: "I'm doing well!" }],
202249
},
203250
]
204-
await handler.createMessage(systemPrompt, complexMessages)
205-
expect(mockCreate).toHaveBeenCalledWith(
206-
expect.objectContaining({
207-
messages: expect.arrayContaining([
208-
expect.objectContaining({
209-
role: "system",
210-
content: systemPrompt,
211-
}),
212-
expect.objectContaining({
213-
role: "user",
214-
content: "Hello! How are you?",
215-
}),
216-
expect.objectContaining({
217-
role: "assistant",
218-
content: "I'm doing well!",
219-
}),
220-
]),
221-
}),
222-
)
251+
const stream = handler.createMessage(systemPrompt, complexMessages)
252+
await consumeStream(stream)
253+
254+
const callArgs = mockCreate.mock.calls[0][0]
255+
expect(callArgs.messages).toEqual([
256+
{
257+
role: "system",
258+
content: systemPrompt,
259+
},
260+
{
261+
role: "user",
262+
content: [
263+
{ type: "text", text: "Hello!" },
264+
{ type: "text", text: "How are you?" },
265+
],
266+
},
267+
{
268+
role: "assistant",
269+
content: "I'm doing well!",
270+
},
271+
])
223272
})
224273
})
225274
})

0 commit comments

Comments
 (0)