diff --git a/src/api/providers/__tests__/base-openai-compatible-provider.spec.ts b/src/api/providers/__tests__/base-openai-compatible-provider.spec.ts
new file mode 100644
index 00000000000..548019020db
--- /dev/null
+++ b/src/api/providers/__tests__/base-openai-compatible-provider.spec.ts
@@ -0,0 +1,286 @@
+// npx vitest run api/providers/__tests__/base-openai-compatible-provider.spec.ts
+
+import { Anthropic } from "@anthropic-ai/sdk"
+import OpenAI from "openai"
+
+import type { ModelInfo } from "@roo-code/types"
+
+import { BaseOpenAiCompatibleProvider } from "../base-openai-compatible-provider"
+
+// Create mock functions
+const mockCreate = vi.fn()
+
+// Mock OpenAI module
+vi.mock("openai", () => ({
+	default: vi.fn(() => ({
+		chat: {
+			completions: {
+				create: mockCreate,
+			},
+		},
+	})),
+}))
+
+// Create a concrete test implementation of the abstract base class
+class TestOpenAiCompatibleProvider extends BaseOpenAiCompatibleProvider<"test-model"> {
+	constructor(apiKey: string) {
+		const testModels: Record<"test-model", ModelInfo> = {
+			"test-model": {
+				maxTokens: 4096,
+				contextWindow: 128000,
+				supportsImages: false,
+				supportsPromptCache: false,
+				inputPrice: 0.5,
+				outputPrice: 1.5,
+			},
+		}
+
+		super({
+			providerName: "TestProvider",
+			baseURL: "https://test.example.com/v1",
+			defaultProviderModelId: "test-model",
+			providerModels: testModels,
+			apiKey,
+		})
+	}
+}
+
+describe("BaseOpenAiCompatibleProvider", () => {
+	let handler: TestOpenAiCompatibleProvider
+
+	beforeEach(() => {
+		vi.clearAllMocks()
+		handler = new TestOpenAiCompatibleProvider("test-api-key")
+	})
+
+	afterEach(() => {
+		vi.restoreAllMocks()
+	})
+
+	describe("XmlMatcher reasoning tags", () => {
+		it("should handle reasoning tags (<think>) from stream", async () => {
+			mockCreate.mockImplementationOnce(() => {
+				return {
+					[Symbol.asyncIterator]: () => ({
+						next: vi
+							.fn()
+							.mockResolvedValueOnce({
+								done: false,
+								value: { choices: [{ delta: { content: "<think>Let me think" } }] },
+							})
+							.mockResolvedValueOnce({
+								done: false,
+								value: { choices: [{ delta: { content: " about this" } }] },
+							})
+							.mockResolvedValueOnce({
+								done: false,
+								value: { choices: [{ delta: { content: "</think>The answer is 42" } }] },
+							})
+							.mockResolvedValueOnce({ done: true }),
+					}),
+				}
+			})
+
+			const stream = handler.createMessage("system prompt", [])
+			const chunks = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			// XmlMatcher yields chunks as they're processed
+			expect(chunks).toEqual([
+				{ type: "reasoning", text: "Let me think" },
+				{ type: "reasoning", text: " about this" },
+				{ type: "text", text: "The answer is 42" },
+			])
+		})
+
+		it("should handle complete <think> tag in a single chunk", async () => {
+			mockCreate.mockImplementationOnce(() => {
+				return {
+					[Symbol.asyncIterator]: () => ({
+						next: vi
+							.fn()
+							.mockResolvedValueOnce({
+								done: false,
+								value: { choices: [{ delta: { content: "Regular text before " } }] },
+							})
+							.mockResolvedValueOnce({
+								done: false,
+								value: { choices: [{ delta: { content: "<think>Complete thought</think>" } }] },
+							})
+							.mockResolvedValueOnce({
+								done: false,
+								value: { choices: [{ delta: { content: " regular text after" } }] },
+							})
+							.mockResolvedValueOnce({ done: true }),
+					}),
+				}
+			})
+
+			const stream = handler.createMessage("system prompt", [])
+			const chunks = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			// When a complete <think> tag arrives in one chunk mid-stream, XmlMatcher
+			// may not parse it as reasoning. This test documents the actual behavior.
+			expect(chunks.length).toBeGreaterThan(0)
+			expect(chunks[0]).toEqual({ type: "text", text: "Regular text before " })
+		})
+
+		it("should handle incomplete <think> tag at end of stream", async () => {
+			mockCreate.mockImplementationOnce(() => {
+				return {
+					[Symbol.asyncIterator]: () => ({
+						next: vi
+							.fn()
+							.mockResolvedValueOnce({
+								done: false,
+								value: { choices: [{ delta: { content: "<think>Incomplete thought" } }] },
+							})
+							.mockResolvedValueOnce({ done: true }),
+					}),
+				}
+			})
+
+			const stream = handler.createMessage("system prompt", [])
+			const chunks = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			// XmlMatcher should handle incomplete tags and flush remaining content
+			expect(chunks.length).toBeGreaterThan(0)
+			expect(
+				chunks.some(
+					(c) => (c.type === "text" || c.type === "reasoning") && c.text.includes("Incomplete thought"),
+				),
+			).toBe(true)
+		})
+
+		it("should handle text without any tags", async () => {
+			mockCreate.mockImplementationOnce(() => {
+				return {
+					[Symbol.asyncIterator]: () => ({
+						next: vi
+							.fn()
+							.mockResolvedValueOnce({
+								done: false,
+								value: { choices: [{ delta: { content: "Just regular text" } }] },
+							})
+							.mockResolvedValueOnce({
+								done: false,
+								value: { choices: [{ delta: { content: " without reasoning" } }] },
+							})
+							.mockResolvedValueOnce({ done: true }),
+					}),
+				}
+			})
+
+			const stream = handler.createMessage("system prompt", [])
+			const chunks = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			expect(chunks).toEqual([
+				{ type: "text", text: "Just regular text" },
+				{ type: "text", text: " without reasoning" },
+			])
+		})
+
+		it("should handle <think> tags that start at beginning of stream", async () => {
+			mockCreate.mockImplementationOnce(() => {
+				return {
+					[Symbol.asyncIterator]: () => ({
+						next: vi
+							.fn()
+							.mockResolvedValueOnce({
+								done: false,
+								value: { choices: [{ delta: { content: "<think>reasoning" } }] },
+							})
+							.mockResolvedValueOnce({
+								done: false,
+								value: { choices: [{ delta: { content: " content" } }] },
+							})
+							.mockResolvedValueOnce({
+								done: false,
+								value: { choices: [{ delta: { content: "</think> normal text" } }] },
+							})
+							.mockResolvedValueOnce({ done: true }),
+					}),
+				}
+			})
+
+			const stream = handler.createMessage("system prompt", [])
+			const chunks = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			expect(chunks).toEqual([
+				{ type: "reasoning", text: "reasoning" },
+				{ type: "reasoning", text: " content" },
+				{ type: "text", text: " normal text" },
+			])
+		})
+	})
+
+	describe("Basic functionality", () => {
+		it("should create stream with correct parameters", async () => {
+			mockCreate.mockImplementationOnce(() => {
+				return {
+					[Symbol.asyncIterator]: () => ({
+						async next() {
+							return { done: true }
+						},
+					}),
+				}
+			})
+
+			const systemPrompt = "Test system prompt"
+			const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Test message" }]
+
+			const messageGenerator = handler.createMessage(systemPrompt, messages)
+			await messageGenerator.next()
+
+			expect(mockCreate).toHaveBeenCalledWith(
+				expect.objectContaining({
+					model: "test-model",
+					temperature: 0,
+					messages: expect.arrayContaining([{ role: "system", content: systemPrompt }]),
+					stream: true,
+					stream_options: { include_usage: true },
+				}),
+				undefined,
+			)
+		})
+
+		it("should yield usage data from stream", async () => {
+			mockCreate.mockImplementationOnce(() => {
+				return {
+					[Symbol.asyncIterator]: () => ({
+						next: vi
+							.fn()
+							.mockResolvedValueOnce({
+								done: false,
+								value: {
+									choices: [{ delta: {} }],
+									usage: { prompt_tokens: 100, completion_tokens: 50 },
+								},
+							})
+							.mockResolvedValueOnce({ done: true }),
+					}),
+				}
+			})
+
+			const stream = handler.createMessage("system prompt", [])
+			const firstChunk = await stream.next()
+
+			expect(firstChunk.done).toBe(false)
+			expect(firstChunk.value).toEqual({ type: "usage", inputTokens: 100, outputTokens: 50 })
+		})
+	})
+})
diff --git a/src/api/providers/__tests__/minimax.spec.ts b/src/api/providers/__tests__/minimax.spec.ts
index c488aea8812..1033626d0ea 100644
--- a/src/api/providers/__tests__/minimax.spec.ts
+++ b/src/api/providers/__tests__/minimax.spec.ts
@@ -178,43 +178,6 @@ describe("MiniMaxHandler", () => {
 		expect(firstChunk.value).toEqual({ type: "text", text: testContent })
 	})
 
-	it("should handle reasoning tags (<think>) from stream", async () => {
-		mockCreate.mockImplementationOnce(() => {
-			return {
-				[Symbol.asyncIterator]: () => ({
-					next: vitest
-						.fn()
-						.mockResolvedValueOnce({
-							done: false,
-							value: { choices: [{ delta: { content: "<think>Let me think" } }] },
-						})
-						.mockResolvedValueOnce({
-							done: false,
-							value: { choices: [{ delta: { content: " about this" } }] },
-						})
-						.mockResolvedValueOnce({
-							done: false,
-							value: { choices: [{ delta: { content: "</think>The answer is 42" } }] },
-						})
-						.mockResolvedValueOnce({ done: true }),
-				}),
-			}
-		})
-
-		const stream = handler.createMessage("system prompt", [])
-		const chunks = []
-		for await (const chunk of stream) {
-			chunks.push(chunk)
-		}
-
-		// XmlMatcher yields chunks as they're processed
-		expect(chunks).toEqual([
-			{ type: "reasoning", text: "Let me think" },
-			{ type: "reasoning", text: " about this" },
-			{ type: "text", text: "The answer is 42" },
-		])
-	})
-
 	it("createMessage should yield usage data from stream", async () => {
 		mockCreate.mockImplementationOnce(() => {
 			return {
diff --git a/src/api/providers/base-openai-compatible-provider.ts b/src/api/providers/base-openai-compatible-provider.ts
index ec62e7d7ddd..9ac00b0293c 100644
--- a/src/api/providers/base-openai-compatible-provider.ts
+++ b/src/api/providers/base-openai-compatible-provider.ts
@@ -4,6 +4,7 @@ import OpenAI from "openai"
 import type { ModelInfo } from "@roo-code/types"
 
 import { type ApiHandlerOptions, getModelMaxOutputTokens } from "../../shared/api"
+import { XmlMatcher } from "../../utils/xml-matcher"
 import { ApiStream } from "../transform/stream"
 import { convertToOpenAiMessages } from "../transform/openai-format"
@@ -105,13 +106,21 @@ export abstract class BaseOpenAiCompatibleProvider<ModelName extends string>
 	): ApiStream {
 		const stream = await this.createStream(systemPrompt, messages, metadata)
 
+		const matcher = new XmlMatcher(
+			"think",
+			(chunk) =>
+				({
+					type: chunk.matched ? "reasoning" : "text",
"reasoning" : "text", + text: chunk.data, + }) as const, + ) + for await (const chunk of stream) { const delta = chunk.choices[0]?.delta if (delta?.content) { - yield { - type: "text", - text: delta.content, + for (const processedChunk of matcher.update(delta.content)) { + yield processedChunk } } @@ -127,6 +136,11 @@ export abstract class BaseOpenAiCompatibleProvider } } } + + // Process any remaining content + for (const processedChunk of matcher.final()) { + yield processedChunk + } } async completePrompt(prompt: string): Promise { diff --git a/src/api/providers/minimax.ts b/src/api/providers/minimax.ts index 23722f59764..8a8e8c14e5b 100644 --- a/src/api/providers/minimax.ts +++ b/src/api/providers/minimax.ts @@ -1,10 +1,6 @@ -import { Anthropic } from "@anthropic-ai/sdk" import { type MinimaxModelId, minimaxDefaultModelId, minimaxModels } from "@roo-code/types" import type { ApiHandlerOptions } from "../../shared/api" -import { XmlMatcher } from "../../utils/xml-matcher" -import { ApiStream } from "../transform/stream" -import type { ApiHandlerCreateMessageMetadata } from "../index" import { BaseOpenAiCompatibleProvider } from "./base-openai-compatible-provider" @@ -20,43 +16,4 @@ export class MiniMaxHandler extends BaseOpenAiCompatibleProvider defaultTemperature: 1.0, }) } - - override async *createMessage( - systemPrompt: string, - messages: Anthropic.Messages.MessageParam[], - metadata?: ApiHandlerCreateMessageMetadata, - ): ApiStream { - const stream = await this.createStream(systemPrompt, messages, metadata) - - const matcher = new XmlMatcher( - "think", - (chunk) => - ({ - type: chunk.matched ? "reasoning" : "text", - text: chunk.data, - }) as const, - ) - - for await (const chunk of stream) { - const delta = chunk.choices[0]?.delta - - if (delta?.content) { - for (const matcherChunk of matcher.update(delta.content)) { - yield matcherChunk - } - } - - if (chunk.usage) { - yield { - type: "usage", - inputTokens: chunk.usage.prompt_tokens || 0, - outputTokens: chunk.usage.completion_tokens || 0, - } - } - } - - for (const chunk of matcher.final()) { - yield chunk - } - } }