Skip to content

Commit 532cb9a

Browse files
committed
feat: enhance stop token handling and add streaming support in Mistral API tests
1 parent dc97dfb commit 532cb9a

File tree

1 file changed

+149
-50
lines changed

1 file changed

+149
-50
lines changed

src/api/providers/__tests__/mistral.test.ts

Lines changed: 149 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -7,20 +7,18 @@ import { ApiStreamTextChunk } from "../../transform/stream"
77
const mockCreate = jest.fn()
88
jest.mock("@mistralai/mistralai", () => {
99
return {
10-
Mistral: jest.fn().mockImplementation(() => ({
10+
MistralClient: jest.fn().mockImplementation(() => ({
1111
chat: {
1212
stream: mockCreate.mockImplementation(async (options) => {
1313
const stream = {
1414
[Symbol.asyncIterator]: async function* () {
1515
yield {
16-
data: {
17-
choices: [
18-
{
19-
delta: { content: "Test response" },
20-
index: 0,
21-
},
22-
],
23-
},
16+
choices: [
17+
{
18+
delta: { content: "Test response" },
19+
index: 0,
20+
},
21+
],
2422
}
2523
},
2624
}
@@ -37,10 +35,13 @@ describe("MistralHandler", () => {
3735

3836
beforeEach(() => {
3937
mockOptions = {
40-
apiModelId: "codestral-latest", // Update to match the actual model ID
38+
apiModelId: "codestral-latest",
4139
mistralApiKey: "test-api-key",
4240
includeMaxTokens: true,
4341
modelTemperature: 0,
42+
mistralModelStreamingEnabled: true,
43+
stopToken: undefined,
44+
mistralCodestralUrl: undefined,
4445
}
4546
handler = new MistralHandler(mockOptions)
4647
mockCreate.mockClear()
@@ -60,23 +61,91 @@ describe("MistralHandler", () => {
6061
})
6162
}).toThrow("Mistral API key is required")
6263
})
64+
})
65+
66+
describe("stopToken handling", () => {
67+
const systemPrompt = "You are a helpful assistant."
68+
const messages: Anthropic.Messages.MessageParam[] = [
69+
{
70+
role: "user",
71+
content: [{ type: "text", text: "Hello!" }],
72+
},
73+
]
6374

64-
it("should use custom base URL if provided", () => {
65-
const customBaseUrl = "https://custom.mistral.ai/v1"
66-
const handlerWithCustomUrl = new MistralHandler({
75+
it("should not include stop parameter when stopToken is undefined", async () => {
76+
const handlerWithoutStop = new MistralHandler({
6777
...mockOptions,
68-
mistralCodestralUrl: customBaseUrl,
78+
stopToken: undefined,
6979
})
70-
expect(handlerWithCustomUrl).toBeInstanceOf(MistralHandler)
80+
await handlerWithoutStop.createMessage(systemPrompt, messages)
81+
82+
expect(mockCreate).toHaveBeenCalled()
83+
const callArgs = mockCreate.mock.calls[0][0]
84+
expect(callArgs).not.toHaveProperty("stop")
7185
})
72-
})
7386

74-
describe("getModel", () => {
75-
it("should return correct model info", () => {
76-
const model = handler.getModel()
77-
expect(model.id).toBe(mockOptions.apiModelId)
78-
expect(model.info).toBeDefined()
79-
expect(model.info.supportsPromptCache).toBe(false)
87+
it("should not include stop parameter when stopToken is empty string", async () => {
88+
const handlerWithEmptyStop = new MistralHandler({
89+
...mockOptions,
90+
stopToken: "",
91+
})
92+
await handlerWithEmptyStop.createMessage(systemPrompt, messages)
93+
94+
expect(mockCreate).toHaveBeenCalled()
95+
const callArgs = mockCreate.mock.calls[0][0]
96+
expect(callArgs).not.toHaveProperty("stop")
97+
})
98+
99+
it("should not include stop parameter when stopToken contains only whitespace", async () => {
100+
const handlerWithWhitespaceStop = new MistralHandler({
101+
...mockOptions,
102+
stopToken: " ",
103+
})
104+
await handlerWithWhitespaceStop.createMessage(systemPrompt, messages)
105+
106+
expect(mockCreate).toHaveBeenCalled()
107+
const callArgs = mockCreate.mock.calls[0][0]
108+
expect(callArgs).not.toHaveProperty("stop")
109+
})
110+
111+
it("should not include stop parameter when stopToken contains only commas", async () => {
112+
const handlerWithCommasStop = new MistralHandler({
113+
...mockOptions,
114+
stopToken: ",,,",
115+
})
116+
await handlerWithCommasStop.createMessage(systemPrompt, messages)
117+
118+
expect(mockCreate).toHaveBeenCalled()
119+
const callArgs = mockCreate.mock.calls[0][0]
120+
expect(callArgs).not.toHaveProperty("stop")
121+
})
122+
123+
it("should include stop parameter with single token", async () => {
124+
const handlerWithStop = new MistralHandler({
125+
...mockOptions,
126+
stopToken: "\\n\\n",
127+
})
128+
await handlerWithStop.createMessage(systemPrompt, messages)
129+
130+
expect(mockCreate).toHaveBeenCalledWith(
131+
expect.objectContaining({
132+
stop: ["\\n\\n"],
133+
}),
134+
)
135+
})
136+
137+
it("should handle multiple stop tokens and filter empty ones", async () => {
138+
const handlerWithMultiStop = new MistralHandler({
139+
...mockOptions,
140+
stopToken: "\\n\\n,,DONE, ,END,",
141+
})
142+
await handlerWithMultiStop.createMessage(systemPrompt, messages)
143+
144+
expect(mockCreate).toHaveBeenCalledWith(
145+
expect.objectContaining({
146+
stop: ["\\n\\n", "DONE", "END"],
147+
}),
148+
)
80149
})
81150
})
82151

@@ -89,38 +158,68 @@ describe("MistralHandler", () => {
89158
},
90159
]
91160

92-
it("should create message successfully", async () => {
93-
const iterator = handler.createMessage(systemPrompt, messages)
94-
const result = await iterator.next()
95-
96-
expect(mockCreate).toHaveBeenCalledWith({
97-
model: mockOptions.apiModelId,
98-
messages: expect.any(Array),
99-
maxTokens: expect.any(Number),
100-
temperature: 0,
101-
})
102-
103-
expect(result.value).toBeDefined()
104-
expect(result.done).toBe(false)
161+
it("should create message with streaming enabled", async () => {
162+
const stream = await handler.createMessage(systemPrompt, messages)
163+
expect(stream).toBeDefined()
164+
expect(mockCreate).toHaveBeenCalledWith(
165+
expect.objectContaining({
166+
messages: expect.arrayContaining([
167+
expect.objectContaining({
168+
role: "system",
169+
content: systemPrompt,
170+
}),
171+
]),
172+
stream: true,
173+
}),
174+
)
105175
})
106176

107-
it("should handle streaming response correctly", async () => {
108-
const iterator = handler.createMessage(systemPrompt, messages)
109-
const results: ApiStreamTextChunk[] = []
110-
111-
for await (const chunk of iterator) {
112-
if ("text" in chunk) {
113-
results.push(chunk as ApiStreamTextChunk)
114-
}
115-
}
116-
117-
expect(results.length).toBeGreaterThan(0)
118-
expect(results[0].text).toBe("Test response")
177+
it("should handle temperature settings", async () => {
178+
const handlerWithTemp = new MistralHandler({
179+
...mockOptions,
180+
modelTemperature: 0.7,
181+
})
182+
await handlerWithTemp.createMessage(systemPrompt, messages)
183+
expect(mockCreate).toHaveBeenCalledWith(
184+
expect.objectContaining({
185+
temperature: 0.7,
186+
}),
187+
)
119188
})
120189

121-
it("should handle errors gracefully", async () => {
122-
mockCreate.mockRejectedValueOnce(new Error("API Error"))
123-
await expect(handler.createMessage(systemPrompt, messages).next()).rejects.toThrow("API Error")
190+
it("should transform messages correctly", async () => {
191+
const complexMessages: Anthropic.Messages.MessageParam[] = [
192+
{
193+
role: "user",
194+
content: [
195+
{ type: "text", text: "Hello!" },
196+
{ type: "text", text: "How are you?" },
197+
],
198+
},
199+
{
200+
role: "assistant",
201+
content: [{ type: "text", text: "I'm doing well!" }],
202+
},
203+
]
204+
await handler.createMessage(systemPrompt, complexMessages)
205+
expect(mockCreate).toHaveBeenCalledWith(
206+
expect.objectContaining({
207+
messages: expect.arrayContaining([
208+
expect.objectContaining({
209+
role: "system",
210+
content: systemPrompt,
211+
}),
212+
expect.objectContaining({
213+
role: "user",
214+
content: "Hello! How are you?",
215+
}),
216+
expect.objectContaining({
217+
role: "assistant",
218+
content: "I'm doing well!",
219+
}),
220+
]),
221+
}),
222+
)
124223
})
125224
})
126225
})

0 commit comments

Comments
 (0)