150 changes: 147 additions & 3 deletions src/api/providers/__tests__/zai.spec.ts
@@ -193,7 +193,6 @@ describe("ZAiHandler", () => {

it("createMessage should pass correct parameters to Z AI client", async () => {
const modelId: InternationalZAiModelId = "glm-4.5"
const modelInfo = internationalZAiModels[modelId]
const handlerWithModel = new ZAiHandler({
apiModelId: modelId,
zaiApiKey: "test-zai-api-key",
@@ -216,14 +215,159 @@ describe("ZAiHandler", () => {
const messageGenerator = handlerWithModel.createMessage(systemPrompt, messages)
await messageGenerator.next()

// For GLM-4.5, expect enhanced system prompt and adjusted parameters
expect(mockCreate).toHaveBeenCalledWith(
expect.objectContaining({
model: modelId,
max_tokens: modelInfo.maxTokens,
max_tokens: 32768, // Adjusted for GLM models
temperature: ZAI_DEFAULT_TEMPERATURE,
messages: expect.arrayContaining([{ role: "system", content: systemPrompt }]),
messages: expect.arrayContaining([
{
role: "system",
content: expect.stringContaining(systemPrompt), // Contains original prompt plus enhancements
},
]),
stream: true,
stream_options: { include_usage: true },
top_p: 0.95,
frequency_penalty: 0.1,
presence_penalty: 0.1,
}),
)
})

it("should enhance system prompt for GLM-4.5 models", async () => {
Review comment from the PR author: Good test coverage for the GLM-specific enhancements! Consider adding an edge case test for when modelId is undefined to ensure the code handles it gracefully. A sketch of such a test appears after this test block.

const modelId: InternationalZAiModelId = "glm-4.5"
const handlerWithGLM = new ZAiHandler({
apiModelId: modelId,
zaiApiKey: "test-zai-api-key",
zaiApiLine: "international",
})

mockCreate.mockImplementationOnce(() => {
return {
[Symbol.asyncIterator]: () => ({
async next() {
return { done: true }
},
}),
}
})

const systemPrompt = "Test system prompt"
const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Test message" }]

const messageGenerator = handlerWithGLM.createMessage(systemPrompt, messages)
await messageGenerator.next()

// Check that the system prompt was enhanced with GLM-specific instructions
expect(mockCreate).toHaveBeenCalledWith(
expect.objectContaining({
messages: expect.arrayContaining([
{
role: "system",
content: expect.stringContaining("CRITICAL INSTRUCTIONS FOR GLM MODEL"),
},
]),
}),
)
})
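Picking up the reviewer's suggestion above, a minimal sketch of what an undefined-model edge case test could look like, assuming the constructor falls back to the provider's default model when apiModelId is omitted; the test name and assertions are illustrative and not part of this PR:

```ts
it("should not throw when apiModelId is undefined", async () => {
	// Hypothetical edge case: no apiModelId supplied, so the handler should
	// fall back to its default model rather than failing on modelId.includes(...)
	const handlerWithoutModel = new ZAiHandler({
		zaiApiKey: "test-zai-api-key",
		zaiApiLine: "international",
	})

	mockCreate.mockImplementationOnce(() => {
		return {
			[Symbol.asyncIterator]: () => ({
				async next() {
					return { done: true }
				},
			}),
		}
	})

	const messageGenerator = handlerWithoutModel.createMessage("system", [])
	await messageGenerator.next()

	// Expect a well-formed request with some concrete model id, whatever the default is
	expect(mockCreate).toHaveBeenCalledWith(
		expect.objectContaining({
			model: expect.any(String),
			stream: true,
		}),
	)
})
```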

it("should apply max token adjustment for GLM-4.5 models", async () => {
const modelId: InternationalZAiModelId = "glm-4.5"
const handlerWithGLM = new ZAiHandler({
apiModelId: modelId,
zaiApiKey: "test-zai-api-key",
zaiApiLine: "international",
})

mockCreate.mockImplementationOnce(() => {
return {
[Symbol.asyncIterator]: () => ({
async next() {
return { done: true }
},
}),
}
})

const messageGenerator = handlerWithGLM.createMessage("system", [])
await messageGenerator.next()

// Check that max_tokens is capped at 32768 for GLM models
expect(mockCreate).toHaveBeenCalledWith(
expect.objectContaining({
max_tokens: 32768,
top_p: 0.95,
frequency_penalty: 0.1,
presence_penalty: 0.1,
}),
)
})

it("should enhance prompt in completePrompt for GLM-4.5 models", async () => {
const modelId: InternationalZAiModelId = "glm-4.5"
const handlerWithGLM = new ZAiHandler({
apiModelId: modelId,
zaiApiKey: "test-zai-api-key",
zaiApiLine: "international",
})

const expectedResponse = "Test response"
mockCreate.mockResolvedValueOnce({ choices: [{ message: { content: expectedResponse } }] })

const testPrompt = "Test prompt"
await handlerWithGLM.completePrompt(testPrompt)

// Check that the prompt was enhanced with GLM-specific prefix
expect(mockCreate).toHaveBeenCalledWith(
expect.objectContaining({
messages: [
{
role: "user",
content: expect.stringContaining(
"[INSTRUCTION] Please provide a direct and accurate response",
),
},
],
temperature: ZAI_DEFAULT_TEMPERATURE,
max_tokens: 4096,
}),
)
})

it("should handle GLM-4.5-air model correctly", async () => {
const modelId: InternationalZAiModelId = "glm-4.5-air"
const handlerWithGLMAir = new ZAiHandler({
apiModelId: modelId,
zaiApiKey: "test-zai-api-key",
zaiApiLine: "international",
})

mockCreate.mockImplementationOnce(() => {
return {
[Symbol.asyncIterator]: () => ({
async next() {
return { done: true }
},
}),
}
})

const messageGenerator = handlerWithGLMAir.createMessage("system", [])
await messageGenerator.next()

// Should apply GLM enhancements for glm-4.5-air as well
expect(mockCreate).toHaveBeenCalledWith(
expect.objectContaining({
model: modelId,
max_tokens: 32768,
messages: expect.arrayContaining([
{
role: "system",
content: expect.stringContaining("CRITICAL INSTRUCTIONS FOR GLM MODEL"),
},
]),
}),
)
})
189 changes: 189 additions & 0 deletions src/api/providers/zai.ts
@@ -7,12 +7,19 @@ import {
type MainlandZAiModelId,
ZAI_DEFAULT_TEMPERATURE,
} from "@roo-code/types"
import { Anthropic } from "@anthropic-ai/sdk"
Review comment from the PR author: The Anthropic import is included but not directly used in the implementation. Is this intentional? The ApiStream return type annotation also seems to be missing from the override declaration. Consider cleaning up unused imports or adding the proper type annotation:

Suggested change
import { Anthropic } from "@anthropic-ai/sdk"
override async *createMessage(
systemPrompt: string,
messages: Anthropic.Messages.MessageParam[],
metadata?: ApiHandlerCreateMessageMetadata,
): AsyncGenerator<ApiStream>

import OpenAI from "openai"

import type { ApiHandlerOptions } from "../../shared/api"
import { ApiStream } from "../transform/stream"
import { convertToOpenAiMessages } from "../transform/openai-format"
import type { ApiHandlerCreateMessageMetadata } from "../index"

import { BaseOpenAiCompatibleProvider } from "./base-openai-compatible-provider"

export class ZAiHandler extends BaseOpenAiCompatibleProvider<InternationalZAiModelId | MainlandZAiModelId> {
private readonly isGLM45: boolean

constructor(options: ApiHandlerOptions) {
const isChina = options.zaiApiLine === "china"
const models = isChina ? mainlandZAiModels : internationalZAiModels
@@ -27,5 +34,187 @@ export class ZAiHandler extends BaseOpenAiCompatibleProvider<InternationalZAiMod
providerModels: models,
defaultTemperature: ZAI_DEFAULT_TEMPERATURE,
})

// Check if the model is GLM-4.5 or GLM-4.5-Air
const modelId = options.apiModelId || defaultModelId
this.isGLM45 = modelId.includes("glm-4.5")
Review comment from the PR author: Could this cause a runtime error if both options.apiModelId and defaultModelId are undefined? Consider adding a null check:

Suggested change
this.isGLM45 = modelId.includes("glm-4.5")
this.isGLM45 = modelId?.includes("glm-4.5") ?? false

}

/**
* Override createMessage to add GLM-specific handling
*/
override async *createMessage(
systemPrompt: string,
messages: Anthropic.Messages.MessageParam[],
metadata?: ApiHandlerCreateMessageMetadata,
): ApiStream {
// For GLM-4.5 models, enhance the system prompt with clearer instructions
let enhancedSystemPrompt = systemPrompt

if (this.isGLM45) {
// Add GLM-specific instructions to prevent hallucination and improve tool understanding
const glmInstructions = `

# CRITICAL INSTRUCTIONS FOR GLM MODEL

## File and Code Awareness
- NEVER assume or hallucinate files that don't exist. Always verify file existence using the provided tools.
- When exploring code, ALWAYS use the available tools (read_file, list_files, search_files) to examine actual files.
- If you're unsure about a file's existence or location, use list_files to explore the directory structure first.
- Base all code analysis and modifications on actual file contents retrieved through tools, not assumptions.

## Tool Usage Protocol
- Tools are invoked using XML-style tags as shown in the examples.
- Each tool invocation must be properly formatted with the exact tool name as the XML tag.
- Wait for tool execution results before proceeding to the next step.
- Never simulate or imagine tool outputs - always use actual results.

## Content Management
- When working with large files or responses, focus on the specific sections relevant to the task.
- Use partial reads when available to efficiently handle large files.
- Condense and summarize appropriately while maintaining accuracy.
- Keep responses concise and within token limits by focusing on essential information.

## Code Indexing Integration
- The code index provides semantic understanding of the codebase.
- Use codebase_search for initial exploration when available.
- Combine index results with actual file reading for complete understanding.
- Trust the index for finding relevant code patterns and implementations.`

enhancedSystemPrompt = systemPrompt + glmInstructions
}

const {
id: model,
info: { maxTokens: max_tokens },
} = this.getModel()

const temperature = this.options.modelTemperature ?? this.defaultTemperature

// For GLM models, we may need to adjust the max_tokens to leave room for proper responses
// GLM models sometimes struggle with very high token limits
const adjustedMaxTokens = this.isGLM45 && max_tokens ? Math.min(max_tokens, 32768) : max_tokens
Review comment from the PR author: The 32768 token limit is hard-coded here and on line 100. Would it make sense to extract this as a constant like GLM_MAX_TOKENS = 32768 for better maintainability? A sketch of the extraction appears after the createMessage method below.


const params: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = {
model,
max_tokens: adjustedMaxTokens || 32768,
temperature,
messages: [
{ role: "system", content: enhancedSystemPrompt },
...this.preprocessMessages(convertToOpenAiMessages(messages)),
],
stream: true,
stream_options: { include_usage: true },
}

// Add additional parameters for GLM models to improve response quality
if (this.isGLM45) {
// GLM models benefit from explicit top_p and frequency_penalty settings
Object.assign(params, {
top_p: 0.95,
frequency_penalty: 0.1,
presence_penalty: 0.1,
})
}

const stream = await this.client.chat.completions.create(params)

for await (const chunk of stream) {
const delta = chunk.choices[0]?.delta

if (delta?.content) {
yield {
type: "text",
text: delta.content,
}
}

if (chunk.usage) {
yield {
type: "usage",
inputTokens: chunk.usage.prompt_tokens || 0,
outputTokens: chunk.usage.completion_tokens || 0,
}
}
}
}
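Following the reviewer's note above about the hard-coded cap, a minimal sketch of the suggested constant extraction; GLM_MAX_TOKENS and the helper name capGlmMaxTokens are illustrative choices, not part of this diff:

```ts
// Hypothetical refactor: keep the GLM token cap in one named constant so that
// createMessage and completePrompt cannot drift apart.
const GLM_MAX_TOKENS = 32768

function capGlmMaxTokens(isGLM45: boolean, maxTokens: number | undefined): number {
	// GLM models are capped at GLM_MAX_TOKENS; other models keep their own limit,
	// falling back to the cap when no limit is declared.
	if (isGLM45 && maxTokens) {
		return Math.min(maxTokens, GLM_MAX_TOKENS)
	}
	return maxTokens || GLM_MAX_TOKENS
}
```

With something like this in place, the max_tokens assignment above could become `max_tokens: capGlmMaxTokens(this.isGLM45, max_tokens)`.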

/**
* Preprocess messages for GLM models to ensure better understanding
*/
private preprocessMessages(
messages: OpenAI.Chat.ChatCompletionMessageParam[],
): OpenAI.Chat.ChatCompletionMessageParam[] {
if (!this.isGLM45) {
return messages
}

// For GLM models, ensure tool-related messages are clearly formatted
return messages.map((msg) => {
if (msg.role === "assistant" && typeof msg.content === "string") {
// Ensure XML tags in assistant messages are properly formatted
// GLM models sometimes struggle with complex XML structures
const content = msg.content
.replace(/(<\/?[^>]+>)/g, "\n$1\n") // Add newlines around XML tags
.replace(/\n\n+/g, "\n") // Remove excessive newlines
.trim()

return { ...msg, content }
}

if (msg.role === "user" && Array.isArray(msg.content)) {
// For user messages with multiple content blocks, ensure text is clear
const processedContent = msg.content.map((block: any) => {
if (block.type === "text") {
// Add clear markers for tool results to help GLM understand context
if (block.text.includes("[ERROR]") || block.text.includes("Error:")) {
Review comment from the PR author: This string matching logic for detecting errors/success might miss edge cases. What happens if a message contains both "Error:" and "successfully"? Consider using more robust detection or documenting the precedence rules. A sketch of one approach appears after the preprocessMessages method below.

return {
...block,
text: `[TOOL EXECUTION RESULT - ERROR]\n${block.text}\n[END TOOL RESULT]`,
}
} else if (block.text.includes("Success:") || block.text.includes("successfully")) {
return {
...block,
text: `[TOOL EXECUTION RESULT - SUCCESS]\n${block.text}\n[END TOOL RESULT]`,
}
}
}
return block
})

return { ...msg, content: processedContent }
}

return msg
})
}
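On the reviewer's point about mixed error/success markers, a sketch of one way to make the precedence explicit, assuming error markers should win when both appear; classifyToolResult is an illustrative helper name, not part of this diff:

```ts
type ToolResultKind = "error" | "success" | "unknown"

// Hypothetical helper: classify a tool-result text block with a documented precedence.
// Error markers are checked first, so text containing both "Error:" and "successfully"
// is labeled as an error.
function classifyToolResult(text: string): ToolResultKind {
	if (text.includes("[ERROR]") || text.includes("Error:")) {
		return "error"
	}
	if (text.includes("Success:") || text.includes("successfully")) {
		return "success"
	}
	return "unknown"
}
```

preprocessMessages could then branch on the returned kind instead of repeating the includes checks inline.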

/**
* Override completePrompt for better GLM handling
*/
override async completePrompt(prompt: string): Promise<string> {
const { id: modelId } = this.getModel()

try {
// For GLM models, add a clear instruction prefix
const enhancedPrompt = this.isGLM45
? `[INSTRUCTION] Please provide a direct and accurate response based on facts. Do not hallucinate or make assumptions.\n\n${prompt}`
: prompt

const response = await this.client.chat.completions.create({
model: modelId,
messages: [{ role: "user", content: enhancedPrompt }],
temperature: this.defaultTemperature,
max_tokens: 4096,
})

return response.choices[0]?.message.content || ""
} catch (error) {
if (error instanceof Error) {
throw new Error(`${this.providerName} completion error: ${error.message}`)
}

throw error
}
}
}