From 667a79bb370410969410423463c541e9caca692b Mon Sep 17 00:00:00 2001
From: Roo Code
Date: Mon, 11 Aug 2025 17:32:00 +0000
Subject: [PATCH] fix: improve GLM-4.5 model handling to prevent hallucination
 and enhance tool understanding

- Add GLM-specific system prompt enhancements to prevent file hallucination
- Include clear instructions for tool usage protocol and content management
- Implement message preprocessing for better GLM model understanding
- Add token limit adjustments and model-specific parameters for GLM-4.5
- Enhance completePrompt method with instruction prefix for GLM models
- Add comprehensive tests for GLM-specific functionality

Fixes #6942
---
 src/api/providers/__tests__/zai.spec.ts | 150 ++++++++++++++++++-
 src/api/providers/zai.ts                | 189 ++++++++++++++++++++++++
 2 files changed, 336 insertions(+), 3 deletions(-)

diff --git a/src/api/providers/__tests__/zai.spec.ts b/src/api/providers/__tests__/zai.spec.ts
index 6b93aaa43b..b1c2a9f213 100644
--- a/src/api/providers/__tests__/zai.spec.ts
+++ b/src/api/providers/__tests__/zai.spec.ts
@@ -193,7 +193,6 @@ describe("ZAiHandler", () => {
 
 	it("createMessage should pass correct parameters to Z AI client", async () => {
 		const modelId: InternationalZAiModelId = "glm-4.5"
-		const modelInfo = internationalZAiModels[modelId]
 		const handlerWithModel = new ZAiHandler({
 			apiModelId: modelId,
 			zaiApiKey: "test-zai-api-key",
@@ -216,14 +215,159 @@ describe("ZAiHandler", () => {
 		const messageGenerator = handlerWithModel.createMessage(systemPrompt, messages)
 		await messageGenerator.next()
 
+		// For GLM-4.5, expect enhanced system prompt and adjusted parameters
 		expect(mockCreate).toHaveBeenCalledWith(
 			expect.objectContaining({
 				model: modelId,
-				max_tokens: modelInfo.maxTokens,
+				max_tokens: 32768, // Adjusted for GLM models
 				temperature: ZAI_DEFAULT_TEMPERATURE,
-				messages: expect.arrayContaining([{ role: "system", content: systemPrompt }]),
+				messages: expect.arrayContaining([
+					{
+						role: "system",
+						content: expect.stringContaining(systemPrompt), // Contains original prompt plus enhancements
+					},
+				]),
 				stream: true,
 				stream_options: { include_usage: true },
+				top_p: 0.95,
+				frequency_penalty: 0.1,
+				presence_penalty: 0.1,
+			}),
+		)
+	})
+
+	it("should enhance system prompt for GLM-4.5 models", async () => {
+		const modelId: InternationalZAiModelId = "glm-4.5"
+		const handlerWithGLM = new ZAiHandler({
+			apiModelId: modelId,
+			zaiApiKey: "test-zai-api-key",
+			zaiApiLine: "international",
+		})
+
+		mockCreate.mockImplementationOnce(() => {
+			return {
+				[Symbol.asyncIterator]: () => ({
+					async next() {
+						return { done: true }
+					},
+				}),
+			}
+		})
+
+		const systemPrompt = "Test system prompt"
+		const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Test message" }]
+
+		const messageGenerator = handlerWithGLM.createMessage(systemPrompt, messages)
+		await messageGenerator.next()
+
+		// Check that the system prompt was enhanced with GLM-specific instructions
+		expect(mockCreate).toHaveBeenCalledWith(
+			expect.objectContaining({
+				messages: expect.arrayContaining([
+					{
+						role: "system",
+						content: expect.stringContaining("CRITICAL INSTRUCTIONS FOR GLM MODEL"),
+					},
+				]),
+			}),
+		)
+	})
+
+	it("should apply max token adjustment for GLM-4.5 models", async () => {
+		const modelId: InternationalZAiModelId = "glm-4.5"
+		const handlerWithGLM = new ZAiHandler({
+			apiModelId: modelId,
+			zaiApiKey: "test-zai-api-key",
+			zaiApiLine: "international",
+		})
+
+		mockCreate.mockImplementationOnce(() => {
+			return {
+				[Symbol.asyncIterator]: () => ({
+					async next() {
+						return { done: true }
+					},
+				}),
+			}
+		})
+
+		const messageGenerator = handlerWithGLM.createMessage("system", [])
+		await messageGenerator.next()
+
+		// Check that max_tokens is capped at 32768 for GLM models
+		expect(mockCreate).toHaveBeenCalledWith(
+			expect.objectContaining({
+				max_tokens: 32768,
+				top_p: 0.95,
+				frequency_penalty: 0.1,
+				presence_penalty: 0.1,
+			}),
+		)
+	})
+
+	it("should enhance prompt in completePrompt for GLM-4.5 models", async () => {
+		const modelId: InternationalZAiModelId = "glm-4.5"
+		const handlerWithGLM = new ZAiHandler({
+			apiModelId: modelId,
+			zaiApiKey: "test-zai-api-key",
+			zaiApiLine: "international",
+		})
+
+		const expectedResponse = "Test response"
+		mockCreate.mockResolvedValueOnce({ choices: [{ message: { content: expectedResponse } }] })
+
+		const testPrompt = "Test prompt"
+		await handlerWithGLM.completePrompt(testPrompt)
+
+		// Check that the prompt was enhanced with GLM-specific prefix
+		expect(mockCreate).toHaveBeenCalledWith(
+			expect.objectContaining({
+				messages: [
+					{
+						role: "user",
+						content: expect.stringContaining(
+							"[INSTRUCTION] Please provide a direct and accurate response",
+						),
+					},
+				],
+				temperature: ZAI_DEFAULT_TEMPERATURE,
+				max_tokens: 4096,
+			}),
+		)
+	})
+
+	it("should handle GLM-4.5-air model correctly", async () => {
+		const modelId: InternationalZAiModelId = "glm-4.5-air"
+		const handlerWithGLMAir = new ZAiHandler({
+			apiModelId: modelId,
+			zaiApiKey: "test-zai-api-key",
+			zaiApiLine: "international",
+		})
+
+		mockCreate.mockImplementationOnce(() => {
+			return {
+				[Symbol.asyncIterator]: () => ({
+					async next() {
+						return { done: true }
+					},
+				}),
+			}
+		})
+
+		const messageGenerator = handlerWithGLMAir.createMessage("system", [])
+		await messageGenerator.next()
+
+		// Should apply GLM enhancements for glm-4.5-air as well
+		expect(mockCreate).toHaveBeenCalledWith(
+			expect.objectContaining({
+				model: modelId,
+				max_tokens: 32768,
+				messages: expect.arrayContaining([
+					{
+						role: "system",
+						content: expect.stringContaining("CRITICAL INSTRUCTIONS FOR GLM MODEL"),
+					},
+				]),
 			}),
 		)
 	})
diff --git a/src/api/providers/zai.ts b/src/api/providers/zai.ts
index e37e37f01b..c0d836f0cf 100644
--- a/src/api/providers/zai.ts
+++ b/src/api/providers/zai.ts
@@ -7,12 +7,19 @@ import {
 	type MainlandZAiModelId,
 	ZAI_DEFAULT_TEMPERATURE,
 } from "@roo-code/types"
+import { Anthropic } from "@anthropic-ai/sdk"
+import OpenAI from "openai"
 
 import type { ApiHandlerOptions } from "../../shared/api"
+import { ApiStream } from "../transform/stream"
+import { convertToOpenAiMessages } from "../transform/openai-format"
+import type { ApiHandlerCreateMessageMetadata } from "../index"
 
 import { BaseOpenAiCompatibleProvider } from "./base-openai-compatible-provider"
 
 export class ZAiHandler extends BaseOpenAiCompatibleProvider {
+	private readonly isGLM45: boolean
+
 	constructor(options: ApiHandlerOptions) {
 		const isChina = options.zaiApiLine === "china"
 		const models = isChina ? mainlandZAiModels : internationalZAiModels
@@ -27,5 +34,187 @@ export class ZAiHandler extends BaseOpenAiCompatibleProvider {
+			if (msg.role === "assistant" && typeof msg.content === "string") {
+				// Ensure XML tags in assistant messages are properly formatted
+				// GLM models sometimes struggle with complex XML structures
+				const content = msg.content
+					.replace(/(<\/?[^>]+>)/g, "\n$1\n") // Add newlines around XML tags
+					.replace(/\n\n+/g, "\n") // Remove excessive newlines
+					.trim()
+
+				return { ...msg, content }
+			}
+
+			if (msg.role === "user" && Array.isArray(msg.content)) {
+				// For user messages with multiple content blocks, ensure text is clear
+				const processedContent = msg.content.map((block: any) => {
+					if (block.type === "text") {
+						// Add clear markers for tool results to help GLM understand context
+						if (block.text.includes("[ERROR]") || block.text.includes("Error:")) {
+							return {
+								...block,
+								text: `[TOOL EXECUTION RESULT - ERROR]\n${block.text}\n[END TOOL RESULT]`,
+							}
+						} else if (block.text.includes("Success:") || block.text.includes("successfully")) {
+							return {
+								...block,
+								text: `[TOOL EXECUTION RESULT - SUCCESS]\n${block.text}\n[END TOOL RESULT]`,
+							}
+						}
+					}
+					return block
+				})
+
+				return { ...msg, content: processedContent }
+			}
+
+			return msg
+		})
+	}
+
+	/**
+	 * Override completePrompt for better GLM handling
+	 */
+	override async completePrompt(prompt: string): Promise<string> {
+		const { id: modelId } = this.getModel()
+
+		try {
+			// For GLM models, add a clear instruction prefix
+			const enhancedPrompt = this.isGLM45
+				? `[INSTRUCTION] Please provide a direct and accurate response based on facts. Do not hallucinate or make assumptions.\n\n${prompt}`
+				: prompt
+
+			const response = await this.client.chat.completions.create({
+				model: modelId,
+				messages: [{ role: "user", content: enhancedPrompt }],
+				temperature: this.defaultTemperature,
+				max_tokens: 4096,
+			})
+
+			return response.choices[0]?.message.content || ""
+		} catch (error) {
+			if (error instanceof Error) {
+				throw new Error(`${this.providerName} completion error: ${error.message}`)
+			}
+
+			throw error
+		}
 	}
 }