Skip to content

Commit 361532e

Browse files
mrubens and PrasangAPrajapati
authored and committed
Handle <think> tags in the base OpenAI-compatible provider (RooCodeInc#8989)
1 parent 88f4710 commit 361532e

File tree

4 files changed

+303
-83
lines changed

4 files changed

+303
-83
lines changed
Lines changed: 286 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,286 @@
1+
// npx vitest run api/providers/__tests__/base-openai-compatible-provider.spec.ts
2+
3+
import { Anthropic } from "@anthropic-ai/sdk"
4+
import OpenAI from "openai"
5+
6+
import type { ModelInfo } from "@roo-code/types"
7+
8+
import { BaseOpenAiCompatibleProvider } from "../base-openai-compatible-provider"
9+
10+
// Create mock functions
11+
const mockCreate = vi.fn()
12+
13+
// Mock OpenAI module
14+
vi.mock("openai", () => ({
15+
default: vi.fn(() => ({
16+
chat: {
17+
completions: {
18+
create: mockCreate,
19+
},
20+
},
21+
})),
22+
}))
23+
24+
// Create a concrete test implementation of the abstract base class
25+
class TestOpenAiCompatibleProvider extends BaseOpenAiCompatibleProvider<"test-model"> {
26+
constructor(apiKey: string) {
27+
const testModels: Record<"test-model", ModelInfo> = {
28+
"test-model": {
29+
maxTokens: 4096,
30+
contextWindow: 128000,
31+
supportsImages: false,
32+
supportsPromptCache: false,
33+
inputPrice: 0.5,
34+
outputPrice: 1.5,
35+
},
36+
}
37+
38+
super({
39+
providerName: "TestProvider",
40+
baseURL: "https://test.example.com/v1",
41+
defaultProviderModelId: "test-model",
42+
providerModels: testModels,
43+
apiKey,
44+
})
45+
}
46+
}
47+
48+
describe("BaseOpenAiCompatibleProvider", () => {
49+
let handler: TestOpenAiCompatibleProvider
50+
51+
beforeEach(() => {
52+
vi.clearAllMocks()
53+
handler = new TestOpenAiCompatibleProvider("test-api-key")
54+
})
55+
56+
afterEach(() => {
57+
vi.restoreAllMocks()
58+
})
59+
60+
describe("XmlMatcher reasoning tags", () => {
61+
it("should handle reasoning tags (<think>) from stream", async () => {
62+
mockCreate.mockImplementationOnce(() => {
63+
return {
64+
[Symbol.asyncIterator]: () => ({
65+
next: vi
66+
.fn()
67+
.mockResolvedValueOnce({
68+
done: false,
69+
value: { choices: [{ delta: { content: "<think>Let me think" } }] },
70+
})
71+
.mockResolvedValueOnce({
72+
done: false,
73+
value: { choices: [{ delta: { content: " about this</think>" } }] },
74+
})
75+
.mockResolvedValueOnce({
76+
done: false,
77+
value: { choices: [{ delta: { content: "The answer is 42" } }] },
78+
})
79+
.mockResolvedValueOnce({ done: true }),
80+
}),
81+
}
82+
})
83+
84+
const stream = handler.createMessage("system prompt", [])
85+
const chunks = []
86+
for await (const chunk of stream) {
87+
chunks.push(chunk)
88+
}
89+
90+
// XmlMatcher yields chunks as they're processed
91+
expect(chunks).toEqual([
92+
{ type: "reasoning", text: "Let me think" },
93+
{ type: "reasoning", text: " about this" },
94+
{ type: "text", text: "The answer is 42" },
95+
])
96+
})
97+
98+
it("should handle complete <think> tag in a single chunk", async () => {
99+
mockCreate.mockImplementationOnce(() => {
100+
return {
101+
[Symbol.asyncIterator]: () => ({
102+
next: vi
103+
.fn()
104+
.mockResolvedValueOnce({
105+
done: false,
106+
value: { choices: [{ delta: { content: "Regular text before " } }] },
107+
})
108+
.mockResolvedValueOnce({
109+
done: false,
110+
value: { choices: [{ delta: { content: "<think>Complete thought</think>" } }] },
111+
})
112+
.mockResolvedValueOnce({
113+
done: false,
114+
value: { choices: [{ delta: { content: " regular text after" } }] },
115+
})
116+
.mockResolvedValueOnce({ done: true }),
117+
}),
118+
}
119+
})
120+
121+
const stream = handler.createMessage("system prompt", [])
122+
const chunks = []
123+
for await (const chunk of stream) {
124+
chunks.push(chunk)
125+
}
126+
127+
// When a complete tag arrives in one chunk, XmlMatcher may not parse it
128+
// This test documents the actual behavior
129+
expect(chunks.length).toBeGreaterThan(0)
130+
expect(chunks[0]).toEqual({ type: "text", text: "Regular text before " })
131+
})
132+
133+
it("should handle incomplete <think> tag at end of stream", async () => {
134+
mockCreate.mockImplementationOnce(() => {
135+
return {
136+
[Symbol.asyncIterator]: () => ({
137+
next: vi
138+
.fn()
139+
.mockResolvedValueOnce({
140+
done: false,
141+
value: { choices: [{ delta: { content: "<think>Incomplete thought" } }] },
142+
})
143+
.mockResolvedValueOnce({ done: true }),
144+
}),
145+
}
146+
})
147+
148+
const stream = handler.createMessage("system prompt", [])
149+
const chunks = []
150+
for await (const chunk of stream) {
151+
chunks.push(chunk)
152+
}
153+
154+
// XmlMatcher should handle incomplete tags and flush remaining content
155+
expect(chunks.length).toBeGreaterThan(0)
156+
expect(
157+
chunks.some(
158+
(c) => (c.type === "text" || c.type === "reasoning") && c.text.includes("Incomplete thought"),
159+
),
160+
).toBe(true)
161+
})
162+
163+
it("should handle text without any <think> tags", async () => {
164+
mockCreate.mockImplementationOnce(() => {
165+
return {
166+
[Symbol.asyncIterator]: () => ({
167+
next: vi
168+
.fn()
169+
.mockResolvedValueOnce({
170+
done: false,
171+
value: { choices: [{ delta: { content: "Just regular text" } }] },
172+
})
173+
.mockResolvedValueOnce({
174+
done: false,
175+
value: { choices: [{ delta: { content: " without reasoning" } }] },
176+
})
177+
.mockResolvedValueOnce({ done: true }),
178+
}),
179+
}
180+
})
181+
182+
const stream = handler.createMessage("system prompt", [])
183+
const chunks = []
184+
for await (const chunk of stream) {
185+
chunks.push(chunk)
186+
}
187+
188+
expect(chunks).toEqual([
189+
{ type: "text", text: "Just regular text" },
190+
{ type: "text", text: " without reasoning" },
191+
])
192+
})
193+
194+
it("should handle <think> tags that start at beginning of stream", async () => {
195+
mockCreate.mockImplementationOnce(() => {
196+
return {
197+
[Symbol.asyncIterator]: () => ({
198+
next: vi
199+
.fn()
200+
.mockResolvedValueOnce({
201+
done: false,
202+
value: { choices: [{ delta: { content: "<think>reasoning" } }] },
203+
})
204+
.mockResolvedValueOnce({
205+
done: false,
206+
value: { choices: [{ delta: { content: " content</think>" } }] },
207+
})
208+
.mockResolvedValueOnce({
209+
done: false,
210+
value: { choices: [{ delta: { content: " normal text" } }] },
211+
})
212+
.mockResolvedValueOnce({ done: true }),
213+
}),
214+
}
215+
})
216+
217+
const stream = handler.createMessage("system prompt", [])
218+
const chunks = []
219+
for await (const chunk of stream) {
220+
chunks.push(chunk)
221+
}
222+
223+
expect(chunks).toEqual([
224+
{ type: "reasoning", text: "reasoning" },
225+
{ type: "reasoning", text: " content" },
226+
{ type: "text", text: " normal text" },
227+
])
228+
})
229+
})
230+
231+
describe("Basic functionality", () => {
232+
it("should create stream with correct parameters", async () => {
233+
mockCreate.mockImplementationOnce(() => {
234+
return {
235+
[Symbol.asyncIterator]: () => ({
236+
async next() {
237+
return { done: true }
238+
},
239+
}),
240+
}
241+
})
242+
243+
const systemPrompt = "Test system prompt"
244+
const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Test message" }]
245+
246+
const messageGenerator = handler.createMessage(systemPrompt, messages)
247+
await messageGenerator.next()
248+
249+
expect(mockCreate).toHaveBeenCalledWith(
250+
expect.objectContaining({
251+
model: "test-model",
252+
temperature: 0,
253+
messages: expect.arrayContaining([{ role: "system", content: systemPrompt }]),
254+
stream: true,
255+
stream_options: { include_usage: true },
256+
}),
257+
undefined,
258+
)
259+
})
260+
261+
it("should yield usage data from stream", async () => {
262+
mockCreate.mockImplementationOnce(() => {
263+
return {
264+
[Symbol.asyncIterator]: () => ({
265+
next: vi
266+
.fn()
267+
.mockResolvedValueOnce({
268+
done: false,
269+
value: {
270+
choices: [{ delta: {} }],
271+
usage: { prompt_tokens: 100, completion_tokens: 50 },
272+
},
273+
})
274+
.mockResolvedValueOnce({ done: true }),
275+
}),
276+
}
277+
})
278+
279+
const stream = handler.createMessage("system prompt", [])
280+
const firstChunk = await stream.next()
281+
282+
expect(firstChunk.done).toBe(false)
283+
expect(firstChunk.value).toEqual({ type: "usage", inputTokens: 100, outputTokens: 50 })
284+
})
285+
})
286+
})

src/api/providers/__tests__/minimax.spec.ts

Lines changed: 0 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -178,43 +178,6 @@ describe("MiniMaxHandler", () => {
178178
expect(firstChunk.value).toEqual({ type: "text", text: testContent })
179179
})
180180

181-
it("should handle reasoning tags (<think>) from stream", async () => {
182-
mockCreate.mockImplementationOnce(() => {
183-
return {
184-
[Symbol.asyncIterator]: () => ({
185-
next: vitest
186-
.fn()
187-
.mockResolvedValueOnce({
188-
done: false,
189-
value: { choices: [{ delta: { content: "<think>Let me think" } }] },
190-
})
191-
.mockResolvedValueOnce({
192-
done: false,
193-
value: { choices: [{ delta: { content: " about this</think>" } }] },
194-
})
195-
.mockResolvedValueOnce({
196-
done: false,
197-
value: { choices: [{ delta: { content: "The answer is 42" } }] },
198-
})
199-
.mockResolvedValueOnce({ done: true }),
200-
}),
201-
}
202-
})
203-
204-
const stream = handler.createMessage("system prompt", [])
205-
const chunks = []
206-
for await (const chunk of stream) {
207-
chunks.push(chunk)
208-
}
209-
210-
// XmlMatcher yields chunks as they're processed
211-
expect(chunks).toEqual([
212-
{ type: "reasoning", text: "Let me think" },
213-
{ type: "reasoning", text: " about this" },
214-
{ type: "text", text: "The answer is 42" },
215-
])
216-
})
217-
218181
it("createMessage should yield usage data from stream", async () => {
219182
mockCreate.mockImplementationOnce(() => {
220183
return {

src/api/providers/base-openai-compatible-provider.ts

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import OpenAI from "openai"
44
import type { ModelInfo } from "@roo-code/types"
55

66
import { type ApiHandlerOptions, getModelMaxOutputTokens } from "../../shared/api"
7+
import { XmlMatcher } from "../../utils/xml-matcher"
78
import { ApiStream } from "../transform/stream"
89
import { convertToOpenAiMessages } from "../transform/openai-format"
910

@@ -105,13 +106,21 @@ export abstract class BaseOpenAiCompatibleProvider<ModelName extends string>
105106
): ApiStream {
106107
const stream = await this.createStream(systemPrompt, messages, metadata)
107108

109+
const matcher = new XmlMatcher(
110+
"think",
111+
(chunk) =>
112+
({
113+
type: chunk.matched ? "reasoning" : "text",
114+
text: chunk.data,
115+
}) as const,
116+
)
117+
108118
for await (const chunk of stream) {
109119
const delta = chunk.choices[0]?.delta
110120

111121
if (delta?.content) {
112-
yield {
113-
type: "text",
114-
text: delta.content,
122+
for (const processedChunk of matcher.update(delta.content)) {
123+
yield processedChunk
115124
}
116125
}
117126

@@ -127,6 +136,11 @@ export abstract class BaseOpenAiCompatibleProvider<ModelName extends string>
127136
}
128137
}
129138
}
139+
140+
// Process any remaining content
141+
for (const processedChunk of matcher.final()) {
142+
yield processedChunk
143+
}
130144
}
131145

132146
async completePrompt(prompt: string): Promise<string> {

0 commit comments

Comments
 (0)