Commit aada7cc

feat: add GLM-4.6 thinking token support for OpenAI-compatible endpoints
- Add detection for GLM-4.6 model variants
- Include thinking parameter { type: "enabled" } in requests for GLM-4.6
- Parse thinking tokens using XmlMatcher for <think> tags
- Handle reasoning_content in streaming responses
- Add comprehensive tests for GLM-4.6 functionality

Fixes #8547
1 parent 5a3f911 commit aada7cc
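For orientation, a minimal sketch of the request shape this change produces when the configured model is GLM-4.6. The endpoint URL and API key below are placeholders, and the thinking field is a GLM-specific extension outside the OpenAI SDK's types, so it is cast past them here the same way the diff uses @ts-ignore:

import OpenAI from "openai"

// Placeholder endpoint and key; any OpenAI-compatible endpoint serving GLM-4.6 would do.
const client = new OpenAI({ baseURL: "https://api.example.com/v1", apiKey: process.env.API_KEY ?? "" })

async function demo() {
	// Roughly what createStream() sends for GLM-4.6: a normal streaming
	// chat.completions request plus the GLM-specific thinking flag.
	const stream = await client.chat.completions.create({
		model: "glm-4.6",
		messages: [
			{ role: "system", content: "You are a helpful assistant" },
			{ role: "user", content: "Hello" },
		],
		stream: true,
		stream_options: { include_usage: true },
		thinking: { type: "enabled" }, // GLM-specific extension, not in the SDK types
	} as any)

	for await (const chunk of stream as any) {
		process.stdout.write(chunk.choices?.[0]?.delta?.content ?? "")
	}
}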

2 files changed: +330 -3 lines changed

New test file for BaseOpenAiCompatibleProvider

Lines changed: 275 additions & 0 deletions
@@ -0,0 +1,275 @@
import { describe, it, expect, vi, beforeEach, type Mock } from "vitest"
import OpenAI from "openai"
import { Anthropic } from "@anthropic-ai/sdk"

import type { ModelInfo } from "@roo-code/types"
import type { ApiHandlerOptions } from "../../../shared/api"

import { BaseOpenAiCompatibleProvider } from "../base-openai-compatible-provider"

// Mock OpenAI module
vi.mock("openai", () => {
	const mockCreate = vi.fn()
	const MockOpenAI = vi.fn().mockImplementation(() => ({
		chat: {
			completions: {
				create: mockCreate,
			},
		},
	}))
	return { default: MockOpenAI }
})

// Create a concrete implementation for testing
class TestOpenAiCompatibleProvider extends BaseOpenAiCompatibleProvider<"test-model" | "glm-4.6"> {
	constructor(options: ApiHandlerOptions) {
		super({
			...options,
			providerName: "TestProvider",
			baseURL: options.openAiBaseUrl || "https://api.test.com/v1",
			defaultProviderModelId: "test-model",
			providerModels: {
				"test-model": {
					maxTokens: 4096,
					contextWindow: 8192,
					supportsImages: false,
					supportsPromptCache: false,
					inputPrice: 0.01,
					outputPrice: 0.02,
				},
				"glm-4.6": {
					maxTokens: 8192,
					contextWindow: 128000,
					supportsImages: true,
					supportsPromptCache: false,
					inputPrice: 0.015,
					outputPrice: 0.03,
				},
			},
		})
	}
}

describe("BaseOpenAiCompatibleProvider", () => {
	let provider: TestOpenAiCompatibleProvider
	let mockOpenAIInstance: any
	let mockCreate: Mock

	beforeEach(() => {
		vi.clearAllMocks()
		mockOpenAIInstance = new (OpenAI as any)()
		mockCreate = mockOpenAIInstance.chat.completions.create
	})

	describe("GLM-4.6 thinking token support", () => {
		it("should detect GLM-4.6 model correctly", () => {
			provider = new TestOpenAiCompatibleProvider({
				apiKey: "test-key",
				apiModelId: "glm-4.6",
			})

			// Test the isGLM46Model method
			expect((provider as any).isGLM46Model("glm-4.6")).toBe(true)
			expect((provider as any).isGLM46Model("GLM-4.6")).toBe(true)
			expect((provider as any).isGLM46Model("glm-4-6")).toBe(true)
			expect((provider as any).isGLM46Model("GLM-4-6")).toBe(true)
			expect((provider as any).isGLM46Model("test-model")).toBe(false)
			expect((provider as any).isGLM46Model("gpt-4")).toBe(false)
		})

		it("should add thinking parameter for GLM-4.6 model", async () => {
			provider = new TestOpenAiCompatibleProvider({
				apiKey: "test-key",
				apiModelId: "glm-4.6",
			})

			// Mock the stream response
			const mockStream = {
				async *[Symbol.asyncIterator]() {
					yield {
						choices: [{ delta: { content: "Test response" } }],
						usage: { prompt_tokens: 10, completion_tokens: 5 },
					}
				},
			}
			mockCreate.mockResolvedValue(mockStream)

			// Create a message
			const systemPrompt = "You are a helpful assistant"
			const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Hello" }]

			const stream = provider.createMessage(systemPrompt, messages)
			const results = []
			for await (const chunk of stream) {
				results.push(chunk)
			}

			// Verify that the create method was called with thinking parameter
			expect(mockCreate).toHaveBeenCalledWith(
				expect.objectContaining({
					model: "glm-4.6",
					thinking: { type: "enabled" },
					stream: true,
				}),
				undefined,
			)
		})

		it("should not add thinking parameter for non-GLM-4.6 models", async () => {
			provider = new TestOpenAiCompatibleProvider({
				apiKey: "test-key",
				apiModelId: "test-model",
			})

			// Mock the stream response
			const mockStream = {
				async *[Symbol.asyncIterator]() {
					yield {
						choices: [{ delta: { content: "Test response" } }],
						usage: { prompt_tokens: 10, completion_tokens: 5 },
					}
				},
			}
			mockCreate.mockResolvedValue(mockStream)

			// Create a message
			const systemPrompt = "You are a helpful assistant"
			const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Hello" }]

			const stream = provider.createMessage(systemPrompt, messages)
			const results = []
			for await (const chunk of stream) {
				results.push(chunk)
			}

			// Verify that the create method was called without thinking parameter
			expect(mockCreate).toHaveBeenCalledWith(
				expect.not.objectContaining({
					thinking: expect.anything(),
				}),
				undefined,
			)
		})

		it("should parse thinking tokens from GLM-4.6 response", async () => {
			provider = new TestOpenAiCompatibleProvider({
				apiKey: "test-key",
				apiModelId: "glm-4.6",
			})

			// Mock the stream response with thinking tokens
			const mockStream = {
				async *[Symbol.asyncIterator]() {
					yield { choices: [{ delta: { content: "<think>" } }], usage: null }
					yield { choices: [{ delta: { content: "Let me analyze this problem..." } }], usage: null }
					yield { choices: [{ delta: { content: "</think>" } }], usage: null }
					yield { choices: [{ delta: { content: "The answer is 42." } }], usage: null }
					yield { choices: [], usage: { prompt_tokens: 10, completion_tokens: 20 } }
				},
			}
			mockCreate.mockResolvedValue(mockStream)

			// Create a message
			const systemPrompt = "You are a helpful assistant"
			const messages: Anthropic.Messages.MessageParam[] = [
				{ role: "user", content: "What is the meaning of life?" },
			]

			const stream = provider.createMessage(systemPrompt, messages)
			const results = []
			for await (const chunk of stream) {
				results.push(chunk)
			}

			// Verify that thinking tokens were parsed correctly
			const reasoningChunks = results.filter((r) => r.type === "reasoning")
			const textChunks = results.filter((r) => r.type === "text")

			expect(reasoningChunks.length).toBeGreaterThan(0)
			expect(reasoningChunks.some((c) => c.text?.includes("Let me analyze this problem"))).toBe(true)
			expect(textChunks.some((c) => c.text === "The answer is 42.")).toBe(true)
		})

		it("should handle reasoning_content in delta for models that support it", async () => {
			provider = new TestOpenAiCompatibleProvider({
				apiKey: "test-key",
				apiModelId: "glm-4.6",
			})

			// Mock the stream response with reasoning_content
			const mockStream = {
				async *[Symbol.asyncIterator]() {
					yield { choices: [{ delta: { reasoning_content: "Thinking about the problem..." } }], usage: null }
					yield { choices: [{ delta: { content: "The solution is simple." } }], usage: null }
					yield { choices: [], usage: { prompt_tokens: 10, completion_tokens: 15 } }
				},
			}
			mockCreate.mockResolvedValue(mockStream)

			// Create a message
			const systemPrompt = "You are a helpful assistant"
			const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Solve this problem" }]

			const stream = provider.createMessage(systemPrompt, messages)
			const results = []
			for await (const chunk of stream) {
				results.push(chunk)
			}

			// Verify that reasoning_content was handled correctly
			const reasoningChunks = results.filter((r) => r.type === "reasoning")
			const textChunks = results.filter((r) => r.type === "text")

			expect(reasoningChunks.some((c) => c.text === "Thinking about the problem...")).toBe(true)
			expect(textChunks.some((c) => c.text === "The solution is simple.")).toBe(true)
		})
	})

	describe("completePrompt", () => {
		it("should complete prompt successfully", async () => {
			provider = new TestOpenAiCompatibleProvider({
				apiKey: "test-key",
				apiModelId: "test-model",
			})

			const mockResponse = {
				choices: [{ message: { content: "Completed response" } }],
			}
			mockCreate.mockResolvedValue(mockResponse)

			const result = await provider.completePrompt("Test prompt")

			expect(result).toBe("Completed response")
			expect(mockCreate).toHaveBeenCalledWith({
				model: "test-model",
				messages: [{ role: "user", content: "Test prompt" }],
			})
		})
	})

	describe("getModel", () => {
		it("should return correct model info", () => {
			provider = new TestOpenAiCompatibleProvider({
				apiKey: "test-key",
				apiModelId: "glm-4.6",
			})

			const model = provider.getModel()

			expect(model.id).toBe("glm-4.6")
			expect(model.info.maxTokens).toBe(8192)
			expect(model.info.contextWindow).toBe(128000)
		})

		it("should use default model when apiModelId is not provided", () => {
			provider = new TestOpenAiCompatibleProvider({
				apiKey: "test-key",
			})

			const model = provider.getModel()

			expect(model.id).toBe("test-model")
			expect(model.info.maxTokens).toBe(4096)
		})
	})
})

src/api/providers/base-openai-compatible-provider.ts

Lines changed: 55 additions & 3 deletions
@@ -6,6 +6,7 @@ import type { ModelInfo } from "@roo-code/types"
 import type { ApiHandlerOptions } from "../../shared/api"
 import { ApiStream } from "../transform/stream"
 import { convertToOpenAiMessages } from "../transform/openai-format"
+import { XmlMatcher } from "../../utils/xml-matcher"
 
 import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index"
 import { DEFAULT_HEADERS } from "./constants"
@@ -85,6 +86,12 @@ export abstract class BaseOpenAiCompatibleProvider<ModelName extends string>
 			stream_options: { include_usage: true },
 		}
 
+		// Add thinking parameter for GLM-4.6 model
+		if (this.isGLM46Model(model)) {
+			// @ts-ignore - GLM-4.6 specific parameter
+			params.thinking = { type: "enabled" }
+		}
+
 		try {
 			return this.client.chat.completions.create(params, requestOptions)
 		} catch (error) {
@@ -98,14 +105,43 @@ export abstract class BaseOpenAiCompatibleProvider<ModelName extends string>
 		metadata?: ApiHandlerCreateMessageMetadata,
 	): ApiStream {
 		const stream = await this.createStream(systemPrompt, messages, metadata)
+		const { id: model } = this.getModel()
+		const isGLM46 = this.isGLM46Model(model)
+
+		// Use XmlMatcher for GLM-4.6 to parse thinking tokens
+		const matcher = isGLM46
+			? new XmlMatcher(
+					"think",
+					(chunk) =>
+						({
+							type: chunk.matched ? "reasoning" : "text",
+							text: chunk.data,
+						}) as const,
+				)
+			: null
 
 		for await (const chunk of stream) {
-			const delta = chunk.choices[0]?.delta
+			const delta = chunk.choices?.[0]?.delta
 
 			if (delta?.content) {
+				if (isGLM46 && matcher) {
+					// Parse thinking tokens for GLM-4.6
+					for (const parsedChunk of matcher.update(delta.content)) {
+						yield parsedChunk
+					}
+				} else {
+					yield {
+						type: "text",
+						text: delta.content,
+					}
+				}
+			}
+
+			// Handle reasoning_content if present (for models that support it directly)
+			if (delta && "reasoning_content" in delta && delta.reasoning_content) {
 				yield {
-					type: "text",
-					text: delta.content,
+					type: "reasoning",
+					text: (delta.reasoning_content as string | undefined) || "",
 				}
 			}
 
@@ -117,6 +153,13 @@ export abstract class BaseOpenAiCompatibleProvider<ModelName extends string>
 				}
 			}
 		}
+
+		// Finalize any remaining content from the matcher
+		if (isGLM46 && matcher) {
+			for (const parsedChunk of matcher.final()) {
+				yield parsedChunk
+			}
+		}
 	}
 
 	async completePrompt(prompt: string): Promise<string> {
@@ -142,4 +185,13 @@ export abstract class BaseOpenAiCompatibleProvider<ModelName extends string>
 
 		return { id, info: this.providerModels[id] }
 	}
+
+	/**
+	 * Check if the model is GLM-4.6 which requires thinking parameter
+	 */
+	protected isGLM46Model(modelId: string): boolean {
+		// Check for various GLM-4.6 model naming patterns
+		const lowerModel = modelId.toLowerCase()
+		return lowerModel.includes("glm-4.6") || lowerModel.includes("glm-4-6") || lowerModel === "glm-4.6"
+	}
 }
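As a usage sketch, here is how a caller might consume the stream after this change, reusing the hypothetical TestOpenAiCompatibleProvider defined in the test file above. createMessage yields chunks tagged "reasoning" (recovered from <think>…</think> content or delta.reasoning_content) and "text" (the visible answer):

const provider = new TestOpenAiCompatibleProvider({
	apiKey: "test-key",
	apiModelId: "glm-4.6",
})

async function run() {
	let reasoning = ""
	let answer = ""

	for await (const chunk of provider.createMessage("You are a helpful assistant", [
		{ role: "user", content: "What is the meaning of life?" },
	])) {
		// Usage chunks are ignored here; only reasoning and answer text are collected.
		if (chunk.type === "reasoning") reasoning += chunk.text
		else if (chunk.type === "text") answer += chunk.text
	}

	console.log({ reasoning, answer })
}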
