286 changes: 286 additions & 0 deletions src/api/providers/__tests__/base-openai-compatible-provider.spec.ts
@@ -0,0 +1,286 @@
// npx vitest run api/providers/__tests__/base-openai-compatible-provider.spec.ts

import { Anthropic } from "@anthropic-ai/sdk"
import OpenAI from "openai"

import type { ModelInfo } from "@roo-code/types"

import { BaseOpenAiCompatibleProvider } from "../base-openai-compatible-provider"

// Create a mock for the chat.completions.create method
const mockCreate = vi.fn()

// Mock OpenAI module
vi.mock("openai", () => ({
default: vi.fn(() => ({
chat: {
completions: {
create: mockCreate,
},
},
})),
}))

// Create a concrete test implementation of the abstract base class
class TestOpenAiCompatibleProvider extends BaseOpenAiCompatibleProvider<"test-model"> {
constructor(apiKey: string) {
const testModels: Record<"test-model", ModelInfo> = {
"test-model": {
maxTokens: 4096,
contextWindow: 128000,
supportsImages: false,
supportsPromptCache: false,
inputPrice: 0.5,
outputPrice: 1.5,
},
}

super({
providerName: "TestProvider",
baseURL: "https://test.example.com/v1",
defaultProviderModelId: "test-model",
providerModels: testModels,
apiKey,
})
}
}

describe("BaseOpenAiCompatibleProvider", () => {
let handler: TestOpenAiCompatibleProvider

beforeEach(() => {
vi.clearAllMocks()
handler = new TestOpenAiCompatibleProvider("test-api-key")
})

afterEach(() => {
vi.restoreAllMocks()
})

describe("XmlMatcher reasoning tags", () => {
it("should handle reasoning tags (<think>) from stream", async () => {
mockCreate.mockImplementationOnce(() => {
return {
[Symbol.asyncIterator]: () => ({
next: vi
.fn()
.mockResolvedValueOnce({
done: false,
value: { choices: [{ delta: { content: "<think>Let me think" } }] },
})
.mockResolvedValueOnce({
done: false,
value: { choices: [{ delta: { content: " about this</think>" } }] },
})
.mockResolvedValueOnce({
done: false,
value: { choices: [{ delta: { content: "The answer is 42" } }] },
})
.mockResolvedValueOnce({ done: true }),
}),
}
})

const stream = handler.createMessage("system prompt", [])
const chunks = []
for await (const chunk of stream) {
chunks.push(chunk)
}

// XmlMatcher yields chunks as they're processed
expect(chunks).toEqual([
{ type: "reasoning", text: "Let me think" },
{ type: "reasoning", text: " about this" },
{ type: "text", text: "The answer is 42" },
])
})

it("should handle complete <think> tag in a single chunk", async () => {
mockCreate.mockImplementationOnce(() => {
return {
[Symbol.asyncIterator]: () => ({
next: vi
.fn()
.mockResolvedValueOnce({
done: false,
value: { choices: [{ delta: { content: "Regular text before " } }] },
})
.mockResolvedValueOnce({
done: false,
value: { choices: [{ delta: { content: "<think>Complete thought</think>" } }] },
})
.mockResolvedValueOnce({
done: false,
value: { choices: [{ delta: { content: " regular text after" } }] },
})
.mockResolvedValueOnce({ done: true }),
}),
}
})

const stream = handler.createMessage("system prompt", [])
const chunks = []
for await (const chunk of stream) {
chunks.push(chunk)
}

// A complete tag that arrives mid-stream (after regular text) may be passed
// through as plain text rather than parsed as reasoning; this test documents
// the actual behavior instead of asserting an exact reasoning/text split
expect(chunks.length).toBeGreaterThan(0)
expect(chunks[0]).toEqual({ type: "text", text: "Regular text before " })
})

it("should handle incomplete <think> tag at end of stream", async () => {
mockCreate.mockImplementationOnce(() => {
return {
[Symbol.asyncIterator]: () => ({
next: vi
.fn()
.mockResolvedValueOnce({
done: false,
value: { choices: [{ delta: { content: "<think>Incomplete thought" } }] },
})
.mockResolvedValueOnce({ done: true }),
}),
}
})

const stream = handler.createMessage("system prompt", [])
const chunks = []
for await (const chunk of stream) {
chunks.push(chunk)
}

// XmlMatcher should handle incomplete tags and flush remaining content
expect(chunks.length).toBeGreaterThan(0)
expect(
chunks.some(
(c) => (c.type === "text" || c.type === "reasoning") && c.text.includes("Incomplete thought"),
),
).toBe(true)
})

it("should handle text without any <think> tags", async () => {
mockCreate.mockImplementationOnce(() => {
return {
[Symbol.asyncIterator]: () => ({
next: vi
.fn()
.mockResolvedValueOnce({
done: false,
value: { choices: [{ delta: { content: "Just regular text" } }] },
})
.mockResolvedValueOnce({
done: false,
value: { choices: [{ delta: { content: " without reasoning" } }] },
})
.mockResolvedValueOnce({ done: true }),
}),
}
})

const stream = handler.createMessage("system prompt", [])
const chunks = []
for await (const chunk of stream) {
chunks.push(chunk)
}

expect(chunks).toEqual([
{ type: "text", text: "Just regular text" },
{ type: "text", text: " without reasoning" },
])
})

it("should handle <think> tags that start at beginning of stream", async () => {
mockCreate.mockImplementationOnce(() => {
return {
[Symbol.asyncIterator]: () => ({
next: vi
.fn()
.mockResolvedValueOnce({
done: false,
value: { choices: [{ delta: { content: "<think>reasoning" } }] },
})
.mockResolvedValueOnce({
done: false,
value: { choices: [{ delta: { content: " content</think>" } }] },
})
.mockResolvedValueOnce({
done: false,
value: { choices: [{ delta: { content: " normal text" } }] },
})
.mockResolvedValueOnce({ done: true }),
}),
}
})

const stream = handler.createMessage("system prompt", [])
const chunks = []
for await (const chunk of stream) {
chunks.push(chunk)
}

expect(chunks).toEqual([
{ type: "reasoning", text: "reasoning" },
{ type: "reasoning", text: " content" },
{ type: "text", text: " normal text" },
])
})
})

describe("Basic functionality", () => {
it("should create stream with correct parameters", async () => {
mockCreate.mockImplementationOnce(() => {
return {
[Symbol.asyncIterator]: () => ({
async next() {
return { done: true }
},
}),
}
})

const systemPrompt = "Test system prompt"
const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Test message" }]

const messageGenerator = handler.createMessage(systemPrompt, messages)
await messageGenerator.next()

expect(mockCreate).toHaveBeenCalledWith(
expect.objectContaining({
model: "test-model",
temperature: 0,
messages: expect.arrayContaining([{ role: "system", content: systemPrompt }]),
stream: true,
stream_options: { include_usage: true },
}),
undefined,
)
})

it("should yield usage data from stream", async () => {
mockCreate.mockImplementationOnce(() => {
return {
[Symbol.asyncIterator]: () => ({
next: vi
.fn()
.mockResolvedValueOnce({
done: false,
value: {
choices: [{ delta: {} }],
usage: { prompt_tokens: 100, completion_tokens: 50 },
},
})
.mockResolvedValueOnce({ done: true }),
}),
}
})

const stream = handler.createMessage("system prompt", [])
const firstChunk = await stream.next()

expect(firstChunk.done).toBe(false)
expect(firstChunk.value).toEqual({ type: "usage", inputTokens: 100, outputTokens: 50 })
})
})
})
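
Note: for readers unfamiliar with the utility under test, here is a minimal sketch of the XmlMatcher contract these specs exercise, inferred from the provider change below. The constructor arguments and the update()/final() methods are taken directly from the diff; the sample input strings are illustrative only.

import { XmlMatcher } from "../../utils/xml-matcher"

// Route <think>…</think> spans to "reasoning" chunks and everything else to "text" chunks.
const matcher = new XmlMatcher(
	"think",
	(chunk) =>
		({
			type: chunk.matched ? "reasoning" : "text",
			text: chunk.data,
		}) as const,
)

// Feed streamed deltas; update() buffers partial tags that span chunk boundaries.
for (const piece of ["<think>Let me think", " about this</think>", "The answer is 42"]) {
	for (const out of matcher.update(piece)) {
		console.log(out) // e.g. { type: "reasoning", text: "Let me think" }
	}
}

// Flush whatever is still buffered, e.g. an unterminated <think> tag.
for (const out of matcher.final()) {
	console.log(out)
}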
37 changes: 0 additions & 37 deletions src/api/providers/__tests__/minimax.spec.ts
@@ -178,43 +178,6 @@ describe("MiniMaxHandler", () => {
expect(firstChunk.value).toEqual({ type: "text", text: testContent })
})

it("should handle reasoning tags (<think>) from stream", async () => {
mockCreate.mockImplementationOnce(() => {
return {
[Symbol.asyncIterator]: () => ({
next: vitest
.fn()
.mockResolvedValueOnce({
done: false,
value: { choices: [{ delta: { content: "<think>Let me think" } }] },
})
.mockResolvedValueOnce({
done: false,
value: { choices: [{ delta: { content: " about this</think>" } }] },
})
.mockResolvedValueOnce({
done: false,
value: { choices: [{ delta: { content: "The answer is 42" } }] },
})
.mockResolvedValueOnce({ done: true }),
}),
}
})

const stream = handler.createMessage("system prompt", [])
const chunks = []
for await (const chunk of stream) {
chunks.push(chunk)
}

// XmlMatcher yields chunks as they're processed
expect(chunks).toEqual([
{ type: "reasoning", text: "Let me think" },
{ type: "reasoning", text: " about this" },
{ type: "text", text: "The answer is 42" },
])
})

it("createMessage should yield usage data from stream", async () => {
mockCreate.mockImplementationOnce(() => {
return {
20 changes: 17 additions & 3 deletions src/api/providers/base-openai-compatible-provider.ts
@@ -4,6 +4,7 @@ import OpenAI from "openai"
import type { ModelInfo } from "@roo-code/types"

import { type ApiHandlerOptions, getModelMaxOutputTokens } from "../../shared/api"
import { XmlMatcher } from "../../utils/xml-matcher"
import { ApiStream } from "../transform/stream"
import { convertToOpenAiMessages } from "../transform/openai-format"

@@ -105,13 +106,21 @@ export abstract class BaseOpenAiCompatibleProvider<ModelName extends string>
): ApiStream {
const stream = await this.createStream(systemPrompt, messages, metadata)

const matcher = new XmlMatcher(
"think",
(chunk) =>
({
type: chunk.matched ? "reasoning" : "text",
text: chunk.data,
}) as const,
)

for await (const chunk of stream) {
const delta = chunk.choices[0]?.delta

if (delta?.content) {
yield {
type: "text",
text: delta.content,
for (const processedChunk of matcher.update(delta.content)) {
yield processedChunk
}
}

@@ -127,6 +136,11 @@ export abstract class BaseOpenAiCompatibleProvider<ModelName extends string>
}
}
}

// Process any remaining content
for (const processedChunk of matcher.final()) {
yield processedChunk
}
}

async completePrompt(prompt: string): Promise<string> {
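
For context, the only consumer-visible change is that createMessage() can now yield "reasoning" chunks alongside the existing "text" and "usage" chunks. A hedged sketch of a caller, reusing the TestOpenAiCompatibleProvider fixture from the spec above (the fixture, prompt, and logging are illustrative, not part of this PR):

const handler = new TestOpenAiCompatibleProvider("api-key")

for await (const chunk of handler.createMessage("system prompt", [])) {
	switch (chunk.type) {
		case "reasoning":
			// Content that arrived inside <think>…</think> tags.
			console.log("[reasoning]", chunk.text)
			break
		case "text":
			console.log(chunk.text)
			break
		case "usage":
			console.log(`tokens: ${chunk.inputTokens} in / ${chunk.outputTokens} out`)
			break
	}
}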