1 change: 1 addition & 0 deletions packages/types/src/global-settings.ts
@@ -140,6 +140,7 @@ export const SECRET_STATE_KEYS = [
"geminiApiKey",
"openAiNativeApiKey",
"deepSeekApiKey",
"moonshotApiKey",
"mistralApiKey",
"unboundApiKey",
"requestyApiKey",
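Adding `moonshotApiKey` to `SECRET_STATE_KEYS` routes the new key through VS Code's encrypted secret storage rather than plain settings. A minimal sketch of how such a key list is typically consumed; the helper below is hypothetical and not part of this PR:

```ts
// Hypothetical consumer of SECRET_STATE_KEYS (assumed, not shown in this diff):
// keys on the list go to VS Code's SecretStorage, everything else to global state.
import type * as vscode from "vscode"
import { SECRET_STATE_KEYS } from "@roo-code/types"

async function storeProviderSetting(context: vscode.ExtensionContext, key: string, value: string) {
	if ((SECRET_STATE_KEYS as readonly string[]).includes(key)) {
		await context.secrets.store(key, value) // encrypted at rest, never written to settings JSON
	} else {
		await context.globalState.update(key, value)
	}
}
```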
10 changes: 10 additions & 0 deletions packages/types/src/provider-settings.ts
@@ -22,6 +22,7 @@ export const providerNames = [
"gemini-cli",
"openai-native",
"mistral",
"moonshot",
"deepseek",
"unbound",
"requesty",
@@ -186,6 +187,13 @@ const deepSeekSchema = apiModelIdProviderModelSchema.extend({
deepSeekApiKey: z.string().optional(),
})

const moonshotSchema = apiModelIdProviderModelSchema.extend({
moonshotBaseUrl: z
.union([z.literal("https://api.moonshot.ai/v1"), z.literal("https://api.moonshot.cn/v1")])
.optional(),
moonshotApiKey: z.string().optional(),
})

const unboundSchema = baseProviderSettingsSchema.extend({
unboundApiKey: z.string().optional(),
unboundModelId: z.string().optional(),
@@ -240,6 +248,7 @@ export const providerSettingsSchemaDiscriminated = z.discriminatedUnion("apiProvider", [
openAiNativeSchema.merge(z.object({ apiProvider: z.literal("openai-native") })),
mistralSchema.merge(z.object({ apiProvider: z.literal("mistral") })),
deepSeekSchema.merge(z.object({ apiProvider: z.literal("deepseek") })),
moonshotSchema.merge(z.object({ apiProvider: z.literal("moonshot") })),
unboundSchema.merge(z.object({ apiProvider: z.literal("unbound") })),
requestySchema.merge(z.object({ apiProvider: z.literal("requesty") })),
humanRelaySchema.merge(z.object({ apiProvider: z.literal("human-relay") })),
@@ -268,6 +277,7 @@ export const providerSettingsSchema = z.object({
...openAiNativeSchema.shape,
...mistralSchema.shape,
...deepSeekSchema.shape,
...moonshotSchema.shape,
...unboundSchema.shape,
...requestySchema.shape,
...humanRelaySchema.shape,
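Because `moonshotBaseUrl` is a union of two literals, any endpoint other than the two official Moonshot hosts fails validation. A standalone illustration mirroring the schema above:

```ts
import { z } from "zod"

// Standalone copy of moonshotSchema's new fields, for illustration only.
const moonshotSchema = z.object({
	moonshotBaseUrl: z
		.union([z.literal("https://api.moonshot.ai/v1"), z.literal("https://api.moonshot.cn/v1")])
		.optional(),
	moonshotApiKey: z.string().optional(),
})

moonshotSchema.parse({}) // ok: both fields are optional
moonshotSchema.parse({ moonshotBaseUrl: "https://api.moonshot.cn/v1" }) // ok
// moonshotSchema.parse({ moonshotBaseUrl: "https://example.com/v1" }) // throws ZodError
```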
1 change: 1 addition & 0 deletions packages/types/src/providers/index.ts
@@ -9,6 +9,7 @@ export * from "./groq.js"
export * from "./lite-llm.js"
export * from "./lm-studio.js"
export * from "./mistral.js"
export * from "./moonshot.js"
export * from "./ollama.js"
export * from "./openai.js"
export * from "./openrouter.js"
22 changes: 22 additions & 0 deletions packages/types/src/providers/moonshot.ts
@@ -0,0 +1,22 @@
import type { ModelInfo } from "../model.js"

// https://platform.moonshot.ai/
export type MoonshotModelId = keyof typeof moonshotModels

export const moonshotDefaultModelId: MoonshotModelId = "kimi-k2-0711-preview"

export const moonshotModels = {
"kimi-k2-0711-preview": {
maxTokens: 32_000,
contextWindow: 131_072,
supportsImages: false,
supportsPromptCache: true,
inputPrice: 0.6, // $0.60 per million tokens (cache miss)
outputPrice: 2.5, // $2.50 per million tokens
cacheWritesPrice: 0, // $0 per million tokens (cache write)
cacheReadsPrice: 0.15, // $0.15 per million tokens (cache hit)
description: `Kimi K2 is a state-of-the-art mixture-of-experts (MoE) language model with 32 billion activated parameters and 1 trillion total parameters.`,
},
} as const satisfies Record<string, ModelInfo>

export const MOONSHOT_DEFAULT_TEMPERATURE = 0.6
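All prices are USD per million tokens, with cache hits billed at `cacheReadsPrice` instead of `inputPrice`. A rough sketch of how these fields combine into a request cost estimate; the PR's real cost accounting lives elsewhere in the codebase:

```ts
import { moonshotModels } from "@roo-code/types"

// Illustrative only: estimate a request's cost in USD from the pricing fields above.
const info = moonshotModels["kimi-k2-0711-preview"]

function estimateCostUSD(inputTokens: number, outputTokens: number, cacheReadTokens: number): number {
	const uncachedInput = inputTokens - cacheReadTokens // billed at the cache-miss rate
	return (
		(uncachedInput * info.inputPrice +
			cacheReadTokens * info.cacheReadsPrice +
			outputTokens * info.outputPrice) /
		1_000_000
	)
}

// 100k prompt tokens with 80k cache hits and 2k completion tokens:
// (20_000 * 0.6 + 80_000 * 0.15 + 2_000 * 2.5) / 1e6 ≈ $0.029
estimateCostUSD(100_000, 2_000, 80_000)
```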
4 changes: 4 additions & 0 deletions src/api/index.ts
@@ -17,6 +17,7 @@ import {
GeminiHandler,
OpenAiNativeHandler,
DeepSeekHandler,
MoonshotHandler,
MistralHandler,
VsCodeLmHandler,
UnboundHandler,
@@ -89,6 +90,8 @@ export function buildApiHandler(configuration: ProviderSettings): ApiHandler {
return new OpenAiNativeHandler(options)
case "deepseek":
return new DeepSeekHandler(options)
case "moonshot":
return new MoonshotHandler(options)
case "vscode-lm":
return new VsCodeLmHandler(options)
case "mistral":
@@ -110,6 +113,7 @@ export function buildApiHandler(configuration: ProviderSettings): ApiHandler {
case "litellm":
return new LiteLLMHandler(options)
default:
apiProvider satisfies "gemini-cli" | undefined
return new AnthropicHandler(options)
}
}
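The `MoonshotHandler` implementation itself (src/api/providers/moonshot.ts) is not shown in this diff. A plausible shape inferred from the tests below, under the assumptions that the handler wraps the OpenAI SDK directly and maps Moonshot's top-level `cached_tokens` usage field:

```ts
import OpenAI from "openai"
import { moonshotDefaultModelId, moonshotModels, MOONSHOT_DEFAULT_TEMPERATURE } from "@roo-code/types"

// Sketch only; the real handler is not part of this section of the diff.
class MoonshotHandlerSketch {
	private client: OpenAI

	constructor(private options: { moonshotApiKey?: string; moonshotBaseUrl?: string; apiModelId?: string }) {
		this.client = new OpenAI({
			baseURL: options.moonshotBaseUrl ?? "https://api.moonshot.ai/v1", // default endpoint per the tests
			apiKey: options.moonshotApiKey,
		})
	}

	getModel() {
		const id = this.options.apiModelId ?? moonshotDefaultModelId
		// Unknown IDs keep the provided ID but fall back to the default model's info.
		const info = moonshotModels[id as keyof typeof moonshotModels] ?? moonshotModels[moonshotDefaultModelId]
		return { id, info, temperature: MOONSHOT_DEFAULT_TEMPERATURE, maxTokens: info.maxTokens }
	}

	protected processUsageMetrics(usage: {
		prompt_tokens: number
		completion_tokens: number
		cached_tokens?: number
	}) {
		return {
			type: "usage" as const,
			inputTokens: usage.prompt_tokens,
			outputTokens: usage.completion_tokens,
			cacheWriteTokens: 0, // no separate cache-write charge in the pricing table
			cacheReadTokens: usage.cached_tokens,
		}
	}
}
```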
297 changes: 297 additions & 0 deletions src/api/providers/__tests__/moonshot.spec.ts
@@ -0,0 +1,297 @@
// Mocks must come first, before imports
const mockCreate = vi.fn()
vi.mock("openai", () => {
return {
__esModule: true,
default: vi.fn().mockImplementation(() => ({
chat: {
completions: {
create: mockCreate.mockImplementation(async (options) => {
if (!options.stream) {
return {
id: "test-completion",
choices: [
{
message: { role: "assistant", content: "Test response", refusal: null },
finish_reason: "stop",
index: 0,
},
],
usage: {
prompt_tokens: 10,
completion_tokens: 5,
total_tokens: 15,
cached_tokens: 2,
},
}
}

// Return async iterator for streaming
return {
[Symbol.asyncIterator]: async function* () {
yield {
choices: [
{
delta: { content: "Test response" },
index: 0,
},
],
usage: null,
}
yield {
choices: [
{
delta: {},
index: 0,
},
],
usage: {
prompt_tokens: 10,
completion_tokens: 5,
total_tokens: 15,
cached_tokens: 2,
},
}
},
}
}),
},
},
})),
}
})

import OpenAI from "openai"
import type { Anthropic } from "@anthropic-ai/sdk"

import { moonshotDefaultModelId } from "@roo-code/types"

import type { ApiHandlerOptions } from "../../../shared/api"

import { MoonshotHandler } from "../moonshot"

describe("MoonshotHandler", () => {
let handler: MoonshotHandler
let mockOptions: ApiHandlerOptions

beforeEach(() => {
mockOptions = {
moonshotApiKey: "test-api-key",
apiModelId: "moonshot-chat",
moonshotBaseUrl: "https://api.moonshot.ai/v1",
}
handler = new MoonshotHandler(mockOptions)
vi.clearAllMocks()
})

describe("constructor", () => {
it("should initialize with provided options", () => {
expect(handler).toBeInstanceOf(MoonshotHandler)
expect(handler.getModel().id).toBe(mockOptions.apiModelId)
})

it.skip("should throw error if API key is missing", () => {
expect(() => {
new MoonshotHandler({
...mockOptions,
moonshotApiKey: undefined,
})
}).toThrow("Moonshot API key is required")
})

it("should use default model ID if not provided", () => {
const handlerWithoutModel = new MoonshotHandler({
...mockOptions,
apiModelId: undefined,
})
expect(handlerWithoutModel.getModel().id).toBe(moonshotDefaultModelId)
})

it("should use default base URL if not provided", () => {
const handlerWithoutBaseUrl = new MoonshotHandler({
...mockOptions,
moonshotBaseUrl: undefined,
})
expect(handlerWithoutBaseUrl).toBeInstanceOf(MoonshotHandler)
// The base URL is passed to OpenAI client internally
expect(OpenAI).toHaveBeenCalledWith(
expect.objectContaining({
baseURL: "https://api.moonshot.ai/v1",
}),
)
})

it("should use chinese base URL if provided", () => {
const customBaseUrl = "https://api.moonshot.cn/v1"
const handlerWithCustomUrl = new MoonshotHandler({
...mockOptions,
moonshotBaseUrl: customBaseUrl,
})
expect(handlerWithCustomUrl).toBeInstanceOf(MoonshotHandler)
// The custom base URL is passed to OpenAI client
expect(OpenAI).toHaveBeenCalledWith(
expect.objectContaining({
baseURL: customBaseUrl,
}),
)
})

it("should set includeMaxTokens to true", () => {
// includeMaxTokens is set internally by the handler; from the outside we can only verify the OpenAI client was constructed with our options
const _handler = new MoonshotHandler(mockOptions)
expect(OpenAI).toHaveBeenCalledWith(expect.objectContaining({ apiKey: mockOptions.moonshotApiKey }))
})
})

describe("getModel", () => {
it("should return model info for valid model ID", () => {
const model = handler.getModel()
expect(model.id).toBe(mockOptions.apiModelId)
expect(model.info).toBeDefined()
expect(model.info.maxTokens).toBe(32_000)
expect(model.info.contextWindow).toBe(131_072)
expect(model.info.supportsImages).toBe(false)
expect(model.info.supportsPromptCache).toBe(true) // Should be true now
})

it("should return provided model ID with default model info if model does not exist", () => {
const handlerWithInvalidModel = new MoonshotHandler({
...mockOptions,
apiModelId: "invalid-model",
})
const model = handlerWithInvalidModel.getModel()
expect(model.id).toBe("invalid-model") // Returns provided ID
expect(model.info).toBeDefined()
// With the current implementation, it's the same object reference when using default model info
expect(model.info).toBe(handler.getModel().info)
// Should have the same base properties
expect(model.info.contextWindow).toBe(handler.getModel().info.contextWindow)
// And should have supportsPromptCache set to true
expect(model.info.supportsPromptCache).toBe(true)
})

it("should return default model if no model ID is provided", () => {
const handlerWithoutModel = new MoonshotHandler({
...mockOptions,
apiModelId: undefined,
})
const model = handlerWithoutModel.getModel()
expect(model.id).toBe(moonshotDefaultModelId)
expect(model.info).toBeDefined()
expect(model.info.supportsPromptCache).toBe(true)
})

it("should include model parameters from getModelParams", () => {
const model = handler.getModel()
expect(model).toHaveProperty("temperature")
expect(model).toHaveProperty("maxTokens")
})
})

describe("createMessage", () => {
const systemPrompt = "You are a helpful assistant."
const messages: Anthropic.Messages.MessageParam[] = [
{
role: "user",
content: [
{
type: "text" as const,
text: "Hello!",
},
],
},
]

it("should handle streaming responses", async () => {
const stream = handler.createMessage(systemPrompt, messages)
const chunks: any[] = []
for await (const chunk of stream) {
chunks.push(chunk)
}

expect(chunks.length).toBeGreaterThan(0)
const textChunks = chunks.filter((chunk) => chunk.type === "text")
expect(textChunks).toHaveLength(1)
expect(textChunks[0].text).toBe("Test response")
})

it("should include usage information", async () => {
const stream = handler.createMessage(systemPrompt, messages)
const chunks: any[] = []
for await (const chunk of stream) {
chunks.push(chunk)
}

const usageChunks = chunks.filter((chunk) => chunk.type === "usage")
expect(usageChunks.length).toBeGreaterThan(0)
expect(usageChunks[0].inputTokens).toBe(10)
expect(usageChunks[0].outputTokens).toBe(5)
})

it("should include cache metrics in usage information", async () => {
const stream = handler.createMessage(systemPrompt, messages)
const chunks: any[] = []
for await (const chunk of stream) {
chunks.push(chunk)
}

const usageChunks = chunks.filter((chunk) => chunk.type === "usage")
expect(usageChunks.length).toBeGreaterThan(0)
expect(usageChunks[0].cacheWriteTokens).toBe(0)
expect(usageChunks[0].cacheReadTokens).toBe(2)
})
})

describe("processUsageMetrics", () => {
it("should correctly process usage metrics including cache information", () => {
// We need to access the protected method, so we'll create a test subclass
class TestMoonshotHandler extends MoonshotHandler {
public testProcessUsageMetrics(usage: any) {
return this.processUsageMetrics(usage)
}
}

const testHandler = new TestMoonshotHandler(mockOptions)

const usage = {
prompt_tokens: 100,
completion_tokens: 50,
total_tokens: 150,
cached_tokens: 20,
}

const result = testHandler.testProcessUsageMetrics(usage)

expect(result.type).toBe("usage")
expect(result.inputTokens).toBe(100)
expect(result.outputTokens).toBe(50)
expect(result.cacheWriteTokens).toBe(0)
expect(result.cacheReadTokens).toBe(20)
})

it("should handle missing cache metrics gracefully", () => {
class TestMoonshotHandler extends MoonshotHandler {
public testProcessUsageMetrics(usage: any) {
return this.processUsageMetrics(usage)
}
}

const testHandler = new TestMoonshotHandler(mockOptions)

const usage = {
prompt_tokens: 100,
completion_tokens: 50,
total_tokens: 150,
// No cached_tokens
}

const result = testHandler.testProcessUsageMetrics(usage)

expect(result.type).toBe("usage")
expect(result.inputTokens).toBe(100)
expect(result.outputTokens).toBe(50)
expect(result.cacheWriteTokens).toBe(0)
expect(result.cacheReadTokens).toBeUndefined()
})
})
})