Commit 667a79b

fix: improve GLM-4.5 model handling to prevent hallucination and enhance tool understanding
- Add GLM-specific system prompt enhancements to prevent file hallucination
- Include clear instructions for tool usage protocol and content management
- Implement message preprocessing for better GLM model understanding
- Add token limit adjustments and model-specific parameters for GLM-4.5
- Enhance completePrompt method with instruction prefix for GLM models
- Add comprehensive tests for GLM-specific functionality

Fixes #6942
1 parent 76e5a72 commit 667a79b

File tree

2 files changed: +336 -3 lines changed

src/api/providers/__tests__/zai.spec.ts

Lines changed: 147 additions & 3 deletions
```diff
@@ -193,7 +193,6 @@ describe("ZAiHandler", () => {
 
 	it("createMessage should pass correct parameters to Z AI client", async () => {
 		const modelId: InternationalZAiModelId = "glm-4.5"
-		const modelInfo = internationalZAiModels[modelId]
 		const handlerWithModel = new ZAiHandler({
 			apiModelId: modelId,
 			zaiApiKey: "test-zai-api-key",
@@ -216,14 +215,159 @@ describe("ZAiHandler", () => {
 		const messageGenerator = handlerWithModel.createMessage(systemPrompt, messages)
 		await messageGenerator.next()
 
+		// For GLM-4.5, expect enhanced system prompt and adjusted parameters
 		expect(mockCreate).toHaveBeenCalledWith(
 			expect.objectContaining({
 				model: modelId,
-				max_tokens: modelInfo.maxTokens,
+				max_tokens: 32768, // Adjusted for GLM models
 				temperature: ZAI_DEFAULT_TEMPERATURE,
-				messages: expect.arrayContaining([{ role: "system", content: systemPrompt }]),
+				messages: expect.arrayContaining([
+					{
+						role: "system",
+						content: expect.stringContaining(systemPrompt), // Contains original prompt plus enhancements
+					},
+				]),
 				stream: true,
 				stream_options: { include_usage: true },
+				top_p: 0.95,
+				frequency_penalty: 0.1,
+				presence_penalty: 0.1,
+			}),
+		)
+	})
+
+	it("should enhance system prompt for GLM-4.5 models", async () => {
+		const modelId: InternationalZAiModelId = "glm-4.5"
+		const handlerWithGLM = new ZAiHandler({
+			apiModelId: modelId,
+			zaiApiKey: "test-zai-api-key",
+			zaiApiLine: "international",
+		})
+
+		mockCreate.mockImplementationOnce(() => {
+			return {
+				[Symbol.asyncIterator]: () => ({
+					async next() {
+						return { done: true }
+					},
+				}),
+			}
+		})
+
+		const systemPrompt = "Test system prompt"
+		const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Test message" }]
+
+		const messageGenerator = handlerWithGLM.createMessage(systemPrompt, messages)
+		await messageGenerator.next()
+
+		// Check that the system prompt was enhanced with GLM-specific instructions
+		expect(mockCreate).toHaveBeenCalledWith(
+			expect.objectContaining({
+				messages: expect.arrayContaining([
+					{
+						role: "system",
+						content: expect.stringContaining("CRITICAL INSTRUCTIONS FOR GLM MODEL"),
+					},
+				]),
+			}),
+		)
+	})
+
+	it("should apply max token adjustment for GLM-4.5 models", async () => {
+		const modelId: InternationalZAiModelId = "glm-4.5"
+		const handlerWithGLM = new ZAiHandler({
+			apiModelId: modelId,
+			zaiApiKey: "test-zai-api-key",
+			zaiApiLine: "international",
+		})
+
+		mockCreate.mockImplementationOnce(() => {
+			return {
+				[Symbol.asyncIterator]: () => ({
+					async next() {
+						return { done: true }
+					},
+				}),
+			}
+		})
+
+		const messageGenerator = handlerWithGLM.createMessage("system", [])
+		await messageGenerator.next()
+
+		// Check that max_tokens is capped at 32768 for GLM models
+		expect(mockCreate).toHaveBeenCalledWith(
+			expect.objectContaining({
+				max_tokens: 32768,
+				top_p: 0.95,
+				frequency_penalty: 0.1,
+				presence_penalty: 0.1,
+			}),
+		)
+	})
+
+	it("should enhance prompt in completePrompt for GLM-4.5 models", async () => {
+		const modelId: InternationalZAiModelId = "glm-4.5"
+		const handlerWithGLM = new ZAiHandler({
+			apiModelId: modelId,
+			zaiApiKey: "test-zai-api-key",
+			zaiApiLine: "international",
+		})
+
+		const expectedResponse = "Test response"
+		mockCreate.mockResolvedValueOnce({ choices: [{ message: { content: expectedResponse } }] })
+
+		const testPrompt = "Test prompt"
+		await handlerWithGLM.completePrompt(testPrompt)
+
+		// Check that the prompt was enhanced with GLM-specific prefix
+		expect(mockCreate).toHaveBeenCalledWith(
+			expect.objectContaining({
+				messages: [
+					{
+						role: "user",
+						content: expect.stringContaining(
+							"[INSTRUCTION] Please provide a direct and accurate response",
+						),
+					},
+				],
+				temperature: ZAI_DEFAULT_TEMPERATURE,
+				max_tokens: 4096,
+			}),
+		)
+	})
+
+	it("should handle GLM-4.5-air model correctly", async () => {
+		const modelId: InternationalZAiModelId = "glm-4.5-air"
+		const handlerWithGLMAir = new ZAiHandler({
+			apiModelId: modelId,
+			zaiApiKey: "test-zai-api-key",
+			zaiApiLine: "international",
+		})
+
+		mockCreate.mockImplementationOnce(() => {
+			return {
+				[Symbol.asyncIterator]: () => ({
+					async next() {
+						return { done: true }
+					},
+				}),
+			}
+		})
+
+		const messageGenerator = handlerWithGLMAir.createMessage("system", [])
+		await messageGenerator.next()
+
+		// Should apply GLM enhancements for glm-4.5-air as well
+		expect(mockCreate).toHaveBeenCalledWith(
+			expect.objectContaining({
+				model: modelId,
+				max_tokens: 32768,
+				messages: expect.arrayContaining([
+					{
+						role: "system",
+						content: expect.stringContaining("CRITICAL INSTRUCTIONS FOR GLM MODEL"),
+					},
+				]),
 			}),
 		)
 	})
```
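
For readers skimming the diff, the behavior these tests lock in can also be exercised directly. The sketch below is illustrative only and not part of the commit: the import path and environment variable are assumptions, while the constructor options and the streamed chunk shape match the code under test.

```ts
import { Anthropic } from "@anthropic-ai/sdk"
import { ZAiHandler } from "../zai" // hypothetical relative path; adjust to your checkout

async function demoGlm45Handling() {
	// Same options the tests pass; reading the key from the environment is an assumption.
	const handler = new ZAiHandler({
		apiModelId: "glm-4.5",
		zaiApiKey: process.env.ZAI_API_KEY ?? "",
		zaiApiLine: "international",
	})

	const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "List the files in src/" }]

	// createMessage is an async generator; the GLM-specific system-prompt suffix,
	// the 32768 max_tokens cap, and the top_p/penalty parameters are applied internally.
	for await (const chunk of handler.createMessage("You are a coding assistant.", messages)) {
		if (chunk.type === "text") process.stdout.write(chunk.text)
	}
}
```

For non-GLM model IDs the override is effectively a pass-through: the system prompt, token limit, and message list are sent unchanged and no extra sampling parameters are added.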

src/api/providers/zai.ts

Lines changed: 189 additions & 0 deletions
```diff
@@ -7,12 +7,19 @@ import {
 	type MainlandZAiModelId,
 	ZAI_DEFAULT_TEMPERATURE,
 } from "@roo-code/types"
+import { Anthropic } from "@anthropic-ai/sdk"
+import OpenAI from "openai"
 
 import type { ApiHandlerOptions } from "../../shared/api"
+import { ApiStream } from "../transform/stream"
+import { convertToOpenAiMessages } from "../transform/openai-format"
+import type { ApiHandlerCreateMessageMetadata } from "../index"
 
 import { BaseOpenAiCompatibleProvider } from "./base-openai-compatible-provider"
 
 export class ZAiHandler extends BaseOpenAiCompatibleProvider<InternationalZAiModelId | MainlandZAiModelId> {
+	private readonly isGLM45: boolean
+
 	constructor(options: ApiHandlerOptions) {
 		const isChina = options.zaiApiLine === "china"
 		const models = isChina ? mainlandZAiModels : internationalZAiModels
@@ -27,5 +34,187 @@ export class ZAiHandler extends BaseOpenAiCompatibleProvider<InternationalZAiMod
 			providerModels: models,
 			defaultTemperature: ZAI_DEFAULT_TEMPERATURE,
 		})
+
+		// Check if the model is GLM-4.5 or GLM-4.5-Air
+		const modelId = options.apiModelId || defaultModelId
+		this.isGLM45 = modelId.includes("glm-4.5")
+	}
+
+	/**
+	 * Override createMessage to add GLM-specific handling
+	 */
+	override async *createMessage(
+		systemPrompt: string,
+		messages: Anthropic.Messages.MessageParam[],
+		metadata?: ApiHandlerCreateMessageMetadata,
+	): ApiStream {
+		// For GLM-4.5 models, enhance the system prompt with clearer instructions
+		let enhancedSystemPrompt = systemPrompt
+
+		if (this.isGLM45) {
+			// Add GLM-specific instructions to prevent hallucination and improve tool understanding
+			const glmInstructions = `
+
+# CRITICAL INSTRUCTIONS FOR GLM MODEL
+
+## File and Code Awareness
+- NEVER assume or hallucinate files that don't exist. Always verify file existence using the provided tools.
+- When exploring code, ALWAYS use the available tools (read_file, list_files, search_files) to examine actual files.
+- If you're unsure about a file's existence or location, use list_files to explore the directory structure first.
+- Base all code analysis and modifications on actual file contents retrieved through tools, not assumptions.
+
+## Tool Usage Protocol
+- Tools are invoked using XML-style tags as shown in the examples.
+- Each tool invocation must be properly formatted with the exact tool name as the XML tag.
+- Wait for tool execution results before proceeding to the next step.
+- Never simulate or imagine tool outputs - always use actual results.
+
+## Content Management
+- When working with large files or responses, focus on the specific sections relevant to the task.
+- Use partial reads when available to efficiently handle large files.
+- Condense and summarize appropriately while maintaining accuracy.
+- Keep responses concise and within token limits by focusing on essential information.
+
+## Code Indexing Integration
+- The code index provides semantic understanding of the codebase.
+- Use codebase_search for initial exploration when available.
+- Combine index results with actual file reading for complete understanding.
+- Trust the index for finding relevant code patterns and implementations.`
+
+			enhancedSystemPrompt = systemPrompt + glmInstructions
+		}
+
+		const {
+			id: model,
+			info: { maxTokens: max_tokens },
+		} = this.getModel()
+
+		const temperature = this.options.modelTemperature ?? this.defaultTemperature
+
+		// For GLM models, we may need to adjust the max_tokens to leave room for proper responses
+		// GLM models sometimes struggle with very high token limits
+		const adjustedMaxTokens = this.isGLM45 && max_tokens ? Math.min(max_tokens, 32768) : max_tokens
+
+		const params: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = {
+			model,
+			max_tokens: adjustedMaxTokens || 32768,
+			temperature,
+			messages: [
+				{ role: "system", content: enhancedSystemPrompt },
+				...this.preprocessMessages(convertToOpenAiMessages(messages)),
+			],
+			stream: true,
+			stream_options: { include_usage: true },
+		}
+
+		// Add additional parameters for GLM models to improve response quality
+		if (this.isGLM45) {
+			// GLM models benefit from explicit top_p and frequency_penalty settings
+			Object.assign(params, {
+				top_p: 0.95,
+				frequency_penalty: 0.1,
+				presence_penalty: 0.1,
+			})
+		}
+
+		const stream = await this.client.chat.completions.create(params)
+
+		for await (const chunk of stream) {
+			const delta = chunk.choices[0]?.delta
+
+			if (delta?.content) {
+				yield {
+					type: "text",
+					text: delta.content,
+				}
+			}
+
+			if (chunk.usage) {
+				yield {
+					type: "usage",
+					inputTokens: chunk.usage.prompt_tokens || 0,
+					outputTokens: chunk.usage.completion_tokens || 0,
+				}
+			}
+		}
+	}
+
+	/**
+	 * Preprocess messages for GLM models to ensure better understanding
+	 */
+	private preprocessMessages(
+		messages: OpenAI.Chat.ChatCompletionMessageParam[],
+	): OpenAI.Chat.ChatCompletionMessageParam[] {
+		if (!this.isGLM45) {
+			return messages
+		}
+
+		// For GLM models, ensure tool-related messages are clearly formatted
+		return messages.map((msg) => {
+			if (msg.role === "assistant" && typeof msg.content === "string") {
+				// Ensure XML tags in assistant messages are properly formatted
+				// GLM models sometimes struggle with complex XML structures
+				const content = msg.content
+					.replace(/(<\/?[^>]+>)/g, "\n$1\n") // Add newlines around XML tags
+					.replace(/\n\n+/g, "\n") // Remove excessive newlines
+					.trim()
+
+				return { ...msg, content }
+			}
+
+			if (msg.role === "user" && Array.isArray(msg.content)) {
+				// For user messages with multiple content blocks, ensure text is clear
+				const processedContent = msg.content.map((block: any) => {
+					if (block.type === "text") {
+						// Add clear markers for tool results to help GLM understand context
+						if (block.text.includes("[ERROR]") || block.text.includes("Error:")) {
+							return {
+								...block,
+								text: `[TOOL EXECUTION RESULT - ERROR]\n${block.text}\n[END TOOL RESULT]`,
+							}
+						} else if (block.text.includes("Success:") || block.text.includes("successfully")) {
+							return {
+								...block,
+								text: `[TOOL EXECUTION RESULT - SUCCESS]\n${block.text}\n[END TOOL RESULT]`,
+							}
+						}
+					}
+					return block
+				})
+
+				return { ...msg, content: processedContent }
+			}
+
+			return msg
+		})
+	}
+
+	/**
+	 * Override completePrompt for better GLM handling
+	 */
+	override async completePrompt(prompt: string): Promise<string> {
+		const { id: modelId } = this.getModel()
+
+		try {
+			// For GLM models, add a clear instruction prefix
+			const enhancedPrompt = this.isGLM45
+				? `[INSTRUCTION] Please provide a direct and accurate response based on facts. Do not hallucinate or make assumptions.\n\n${prompt}`
+				: prompt
+
+			const response = await this.client.chat.completions.create({
+				model: modelId,
+				messages: [{ role: "user", content: enhancedPrompt }],
+				temperature: this.defaultTemperature,
+				max_tokens: 4096,
+			})
+
+			return response.choices[0]?.message.content || ""
+		} catch (error) {
+			if (error instanceof Error) {
+				throw new Error(`${this.providerName} completion error: ${error.message}`)
+			}
+
+			throw error
+		}
 	}
 }
```
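
To make the preprocessMessages normalization concrete, here is a standalone sketch of the same two replace calls applied to a hypothetical assistant message containing XML-style tool tags; only the sample string is invented, the regexes are copied from the diff.

```ts
// Standalone sketch (not part of the commit): the XML-tag normalization from preprocessMessages.
const sample = `I'll read the file now. <read_file><path>src/api/providers/zai.ts</path></read_file>`

const normalized = sample
	.replace(/(<\/?[^>]+>)/g, "\n$1\n") // put every opening/closing tag on its own line
	.replace(/\n\n+/g, "\n") // collapse the duplicate newlines the first step introduces
	.trim()

console.log(normalized)
// I'll read the file now.
// <read_file>
// <path>
// src/api/providers/zai.ts
// </path>
// </read_file>
```

The same pass leaves user messages untouched unless their content is an array of blocks, in which case text blocks containing "[ERROR]"/"Error:" or "Success:"/"successfully" are wrapped in [TOOL EXECUTION RESULT - ...] markers so GLM can tell tool output apart from conversation.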
