-
Notifications
You must be signed in to change notification settings - Fork 2.6k
fix: handle Mistral thinking content as reasoning chunks #7106
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -1,5 +1,6 @@ | ||||||||||||||||||||||||||||||||||||||||||||||
| // Mock Mistral client - must come before other imports | ||||||||||||||||||||||||||||||||||||||||||||||
| const mockCreate = vi.fn() | ||||||||||||||||||||||||||||||||||||||||||||||
| const mockComplete = vi.fn() | ||||||||||||||||||||||||||||||||||||||||||||||
| vi.mock("@mistralai/mistralai", () => { | ||||||||||||||||||||||||||||||||||||||||||||||
| return { | ||||||||||||||||||||||||||||||||||||||||||||||
| Mistral: vi.fn().mockImplementation(() => ({ | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -21,6 +22,17 @@ vi.mock("@mistralai/mistralai", () => { | |||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||
| return stream | ||||||||||||||||||||||||||||||||||||||||||||||
| }), | ||||||||||||||||||||||||||||||||||||||||||||||
| complete: mockComplete.mockImplementation(async (_options) => { | ||||||||||||||||||||||||||||||||||||||||||||||
| return { | ||||||||||||||||||||||||||||||||||||||||||||||
| choices: [ | ||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||
| message: { | ||||||||||||||||||||||||||||||||||||||||||||||
| content: "Test response", | ||||||||||||||||||||||||||||||||||||||||||||||
| }, | ||||||||||||||||||||||||||||||||||||||||||||||
| }, | ||||||||||||||||||||||||||||||||||||||||||||||
| ], | ||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||
| }), | ||||||||||||||||||||||||||||||||||||||||||||||
| }, | ||||||||||||||||||||||||||||||||||||||||||||||
| })), | ||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -29,7 +41,7 @@ vi.mock("@mistralai/mistralai", () => { | |||||||||||||||||||||||||||||||||||||||||||||
| import type { Anthropic } from "@anthropic-ai/sdk" | ||||||||||||||||||||||||||||||||||||||||||||||
| import { MistralHandler } from "../mistral" | ||||||||||||||||||||||||||||||||||||||||||||||
| import type { ApiHandlerOptions } from "../../../shared/api" | ||||||||||||||||||||||||||||||||||||||||||||||
| import type { ApiStreamTextChunk } from "../../transform/stream" | ||||||||||||||||||||||||||||||||||||||||||||||
| import type { ApiStreamTextChunk, ApiStreamReasoningChunk } from "../../transform/stream" | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| describe("MistralHandler", () => { | ||||||||||||||||||||||||||||||||||||||||||||||
| let handler: MistralHandler | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -44,6 +56,7 @@ describe("MistralHandler", () => { | |||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||
| handler = new MistralHandler(mockOptions) | ||||||||||||||||||||||||||||||||||||||||||||||
| mockCreate.mockClear() | ||||||||||||||||||||||||||||||||||||||||||||||
| mockComplete.mockClear() | ||||||||||||||||||||||||||||||||||||||||||||||
| }) | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| describe("constructor", () => { | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -122,5 +135,128 @@ describe("MistralHandler", () => { | |||||||||||||||||||||||||||||||||||||||||||||
| mockCreate.mockRejectedValueOnce(new Error("API Error")) | ||||||||||||||||||||||||||||||||||||||||||||||
| await expect(handler.createMessage(systemPrompt, messages).next()).rejects.toThrow("API Error") | ||||||||||||||||||||||||||||||||||||||||||||||
| }) | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| it("should handle thinking content as reasoning chunks", async () => { | ||||||||||||||||||||||||||||||||||||||||||||||
| // Mock stream with thinking content | ||||||||||||||||||||||||||||||||||||||||||||||
| mockCreate.mockImplementationOnce(async (_options) => { | ||||||||||||||||||||||||||||||||||||||||||||||
| const stream = { | ||||||||||||||||||||||||||||||||||||||||||||||
| [Symbol.asyncIterator]: async function* () { | ||||||||||||||||||||||||||||||||||||||||||||||
| yield { | ||||||||||||||||||||||||||||||||||||||||||||||
| data: { | ||||||||||||||||||||||||||||||||||||||||||||||
| choices: [ | ||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||
| delta: { | ||||||||||||||||||||||||||||||||||||||||||||||
| content: [ | ||||||||||||||||||||||||||||||||||||||||||||||
| { type: "thinking", text: "Let me think about this..." }, | ||||||||||||||||||||||||||||||||||||||||||||||
| { type: "text", text: "Here's the answer" }, | ||||||||||||||||||||||||||||||||||||||||||||||
| ], | ||||||||||||||||||||||||||||||||||||||||||||||
| }, | ||||||||||||||||||||||||||||||||||||||||||||||
| index: 0, | ||||||||||||||||||||||||||||||||||||||||||||||
| }, | ||||||||||||||||||||||||||||||||||||||||||||||
| ], | ||||||||||||||||||||||||||||||||||||||||||||||
| }, | ||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||
| }, | ||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||
| return stream | ||||||||||||||||||||||||||||||||||||||||||||||
| }) | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| const iterator = handler.createMessage(systemPrompt, messages) | ||||||||||||||||||||||||||||||||||||||||||||||
| const results: (ApiStreamTextChunk | ApiStreamReasoningChunk)[] = [] | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| for await (const chunk of iterator) { | ||||||||||||||||||||||||||||||||||||||||||||||
| if ("text" in chunk) { | ||||||||||||||||||||||||||||||||||||||||||||||
| results.push(chunk as ApiStreamTextChunk | ApiStreamReasoningChunk) | ||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| expect(results).toHaveLength(2) | ||||||||||||||||||||||||||||||||||||||||||||||
| expect(results[0]).toEqual({ type: "reasoning", text: "Let me think about this..." }) | ||||||||||||||||||||||||||||||||||||||||||||||
| expect(results[1]).toEqual({ type: "text", text: "Here's the answer" }) | ||||||||||||||||||||||||||||||||||||||||||||||
| }) | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| it("should handle mixed content arrays correctly", async () => { | ||||||||||||||||||||||||||||||||||||||||||||||
| // Mock stream with mixed content | ||||||||||||||||||||||||||||||||||||||||||||||
| mockCreate.mockImplementationOnce(async (_options) => { | ||||||||||||||||||||||||||||||||||||||||||||||
| const stream = { | ||||||||||||||||||||||||||||||||||||||||||||||
| [Symbol.asyncIterator]: async function* () { | ||||||||||||||||||||||||||||||||||||||||||||||
| yield { | ||||||||||||||||||||||||||||||||||||||||||||||
| data: { | ||||||||||||||||||||||||||||||||||||||||||||||
| choices: [ | ||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||
| delta: { | ||||||||||||||||||||||||||||||||||||||||||||||
| content: [ | ||||||||||||||||||||||||||||||||||||||||||||||
| { type: "text", text: "First text" }, | ||||||||||||||||||||||||||||||||||||||||||||||
| { type: "thinking", text: "Some reasoning" }, | ||||||||||||||||||||||||||||||||||||||||||||||
| { type: "text", text: "Second text" }, | ||||||||||||||||||||||||||||||||||||||||||||||
| ], | ||||||||||||||||||||||||||||||||||||||||||||||
| }, | ||||||||||||||||||||||||||||||||||||||||||||||
| index: 0, | ||||||||||||||||||||||||||||||||||||||||||||||
| }, | ||||||||||||||||||||||||||||||||||||||||||||||
| ], | ||||||||||||||||||||||||||||||||||||||||||||||
| }, | ||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||
| }, | ||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||
| return stream | ||||||||||||||||||||||||||||||||||||||||||||||
| }) | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| const iterator = handler.createMessage(systemPrompt, messages) | ||||||||||||||||||||||||||||||||||||||||||||||
| const results: (ApiStreamTextChunk | ApiStreamReasoningChunk)[] = [] | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| for await (const chunk of iterator) { | ||||||||||||||||||||||||||||||||||||||||||||||
| if ("text" in chunk) { | ||||||||||||||||||||||||||||||||||||||||||||||
| results.push(chunk as ApiStreamTextChunk | ApiStreamReasoningChunk) | ||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| expect(results).toHaveLength(3) | ||||||||||||||||||||||||||||||||||||||||||||||
| expect(results[0]).toEqual({ type: "text", text: "First text" }) | ||||||||||||||||||||||||||||||||||||||||||||||
| expect(results[1]).toEqual({ type: "reasoning", text: "Some reasoning" }) | ||||||||||||||||||||||||||||||||||||||||||||||
| expect(results[2]).toEqual({ type: "text", text: "Second text" }) | ||||||||||||||||||||||||||||||||||||||||||||||
| }) | ||||||||||||||||||||||||||||||||||||||||||||||
| }) | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| describe("completePrompt", () => { | ||||||||||||||||||||||||||||||||||||||||||||||
| it("should complete prompt successfully", async () => { | ||||||||||||||||||||||||||||||||||||||||||||||
| const prompt = "Test prompt" | ||||||||||||||||||||||||||||||||||||||||||||||
| const result = await handler.completePrompt(prompt) | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| expect(mockComplete).toHaveBeenCalledWith({ | ||||||||||||||||||||||||||||||||||||||||||||||
| model: mockOptions.apiModelId, | ||||||||||||||||||||||||||||||||||||||||||||||
| messages: [{ role: "user", content: prompt }], | ||||||||||||||||||||||||||||||||||||||||||||||
| temperature: 0, | ||||||||||||||||||||||||||||||||||||||||||||||
| }) | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| expect(result).toBe("Test response") | ||||||||||||||||||||||||||||||||||||||||||||||
| }) | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| it("should filter out thinking content in completePrompt", async () => { | ||||||||||||||||||||||||||||||||||||||||||||||
| mockComplete.mockImplementationOnce(async (_options) => { | ||||||||||||||||||||||||||||||||||||||||||||||
| return { | ||||||||||||||||||||||||||||||||||||||||||||||
| choices: [ | ||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||
| message: { | ||||||||||||||||||||||||||||||||||||||||||||||
| content: [ | ||||||||||||||||||||||||||||||||||||||||||||||
| { type: "thinking", text: "Let me think..." }, | ||||||||||||||||||||||||||||||||||||||||||||||
| { type: "text", text: "Answer part 1" }, | ||||||||||||||||||||||||||||||||||||||||||||||
| { type: "text", text: "Answer part 2" }, | ||||||||||||||||||||||||||||||||||||||||||||||
| ], | ||||||||||||||||||||||||||||||||||||||||||||||
| }, | ||||||||||||||||||||||||||||||||||||||||||||||
| }, | ||||||||||||||||||||||||||||||||||||||||||||||
| ], | ||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||
| }) | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| const prompt = "Test prompt" | ||||||||||||||||||||||||||||||||||||||||||||||
| const result = await handler.completePrompt(prompt) | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| expect(result).toBe("Answer part 1Answer part 2") | ||||||||||||||||||||||||||||||||||||||||||||||
| }) | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we add a test for the edge case where ALL content is thinking content? This would verify the function returns an empty string correctly:
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||
| it("should handle errors in completePrompt", async () => { | ||||||||||||||||||||||||||||||||||||||||||||||
| mockComplete.mockRejectedValueOnce(new Error("API Error")) | ||||||||||||||||||||||||||||||||||||||||||||||
| await expect(handler.completePrompt("Test prompt")).rejects.toThrow("Mistral completion error: API Error") | ||||||||||||||||||||||||||||||||||||||||||||||
| }) | ||||||||||||||||||||||||||||||||||||||||||||||
| }) | ||||||||||||||||||||||||||||||||||||||||||||||
| }) | ||||||||||||||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change | ||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -11,6 +11,19 @@ import { ApiStream } from "../transform/stream" | |||||||||||
| import { BaseProvider } from "./base-provider" | ||||||||||||
| import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index" | ||||||||||||
|
|
||||||||||||
| // Define TypeScript interfaces for Mistral content types | ||||||||||||
| interface MistralTextContent { | ||||||||||||
|
||||||||||||
| type: "text" | ||||||||||||
| text: string | ||||||||||||
| } | ||||||||||||
|
|
||||||||||||
| interface MistralThinkingContent { | ||||||||||||
| type: "thinking" | ||||||||||||
| text: string | ||||||||||||
| } | ||||||||||||
|
|
||||||||||||
| type MistralContent = MistralTextContent | MistralThinkingContent | string | ||||||||||||
|
|
||||||||||||
| export class MistralHandler extends BaseProvider implements SingleCompletionHandler { | ||||||||||||
| protected options: ApiHandlerOptions | ||||||||||||
| private client: Mistral | ||||||||||||
|
|
@@ -52,15 +65,23 @@ export class MistralHandler extends BaseProvider implements SingleCompletionHand | |||||||||||
| const delta = chunk.data.choices[0]?.delta | ||||||||||||
|
|
||||||||||||
| if (delta?.content) { | ||||||||||||
| let content: string = "" | ||||||||||||
|
|
||||||||||||
| if (typeof delta.content === "string") { | ||||||||||||
| content = delta.content | ||||||||||||
| // Handle string content as text | ||||||||||||
| yield { type: "text", text: delta.content } | ||||||||||||
| } else if (Array.isArray(delta.content)) { | ||||||||||||
| content = delta.content.map((c) => (c.type === "text" ? c.text : "")).join("") | ||||||||||||
| // Handle array of content blocks | ||||||||||||
| for (const c of delta.content as MistralContent[]) { | ||||||||||||
| if (typeof c === "object" && c !== null) { | ||||||||||||
| if (c.type === "thinking" && c.text) { | ||||||||||||
|
||||||||||||
| // Handle thinking content as reasoning chunks | ||||||||||||
| yield { type: "reasoning", text: c.text } | ||||||||||||
| } else if (c.type === "text" && c.text) { | ||||||||||||
| // Handle text content normally | ||||||||||||
| yield { type: "text", text: c.text } | ||||||||||||
| } | ||||||||||||
| } | ||||||||||||
| } | ||||||||||||
| } | ||||||||||||
|
|
||||||||||||
| yield { type: "text", text: content } | ||||||||||||
| } | ||||||||||||
|
|
||||||||||||
| if (chunk.data.usage) { | ||||||||||||
|
|
@@ -97,7 +118,11 @@ export class MistralHandler extends BaseProvider implements SingleCompletionHand | |||||||||||
| const content = response.choices?.[0]?.message.content | ||||||||||||
|
|
||||||||||||
| if (Array.isArray(content)) { | ||||||||||||
| return content.map((c) => (c.type === "text" ? c.text : "")).join("") | ||||||||||||
| // Only return text content, filter out thinking content for non-streaming | ||||||||||||
| return content | ||||||||||||
| .filter((c: any) => typeof c === "object" && c !== null && c.type === "text") | ||||||||||||
|
||||||||||||
| .filter((c: any) => typeof c === "object" && c !== null && c.type === "text") | |
| return content | |
| .filter((c: MistralContent) => typeof c === "object" && c !== null && c.type === "text") | |
| .map((c: MistralTextContent) => c.text || "") | |
| .join("") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
While we test array content filtering in completePrompt, could we add an explicit test for when content is already a string (covering line 128 in mistral.ts)?