diff --git a/src/api/providers/__tests__/minimax.spec.ts b/src/api/providers/__tests__/minimax.spec.ts index d1e25358fab0..597ee1797be1 100644 --- a/src/api/providers/__tests__/minimax.spec.ts +++ b/src/api/providers/__tests__/minimax.spec.ts @@ -8,27 +8,35 @@ vitest.mock("vscode", () => ({ }, })) -import OpenAI from "openai" import { Anthropic } from "@anthropic-ai/sdk" import { type MinimaxModelId, minimaxDefaultModelId, minimaxModels } from "@roo-code/types" import { MiniMaxHandler } from "../minimax" -vitest.mock("openai", () => { +vitest.mock("@anthropic-ai/sdk", () => { const createMock = vitest.fn() + const countTokensMock = vitest.fn() return { - default: vitest.fn(() => ({ chat: { completions: { create: createMock } } })), + Anthropic: vitest.fn(() => ({ + messages: { + create: createMock, + countTokens: countTokensMock, + }, + })), } }) describe("MiniMaxHandler", () => { let handler: MiniMaxHandler let mockCreate: any + let mockCountTokens: any beforeEach(() => { vitest.clearAllMocks() - mockCreate = (OpenAI as unknown as any)().chat.completions.create + const mockClient = (Anthropic as unknown as any)() + mockCreate = mockClient.messages.create + mockCountTokens = mockClient.messages.countTokens }) describe("International MiniMax (default)", () => { @@ -41,7 +49,7 @@ describe("MiniMaxHandler", () => { it("should use the correct international MiniMax base URL by default", () => { new MiniMaxHandler({ minimaxApiKey: "test-minimax-api-key" }) - expect(OpenAI).toHaveBeenCalledWith( + expect(Anthropic).toHaveBeenCalledWith( expect.objectContaining({ baseURL: "https://api.minimax.io/v1", }), @@ -51,7 +59,7 @@ describe("MiniMaxHandler", () => { it("should use the provided API key", () => { const minimaxApiKey = "test-minimax-api-key" new MiniMaxHandler({ minimaxApiKey }) - expect(OpenAI).toHaveBeenCalledWith(expect.objectContaining({ apiKey: minimaxApiKey })) + expect(Anthropic).toHaveBeenCalledWith(expect.objectContaining({ apiKey: minimaxApiKey })) }) it("should return default model when no model is specified", () => { @@ -117,13 +125,13 @@ describe("MiniMaxHandler", () => { minimaxApiKey: "test-minimax-api-key", minimaxBaseUrl: "https://api.minimaxi.com/v1", }) - expect(OpenAI).toHaveBeenCalledWith(expect.objectContaining({ baseURL: "https://api.minimaxi.com/v1" })) + expect(Anthropic).toHaveBeenCalledWith(expect.objectContaining({ baseURL: "https://api.minimaxi.com/v1" })) }) it("should use the provided API key for China", () => { const minimaxApiKey = "test-minimax-api-key" new MiniMaxHandler({ minimaxApiKey, minimaxBaseUrl: "https://api.minimaxi.com/v1" }) - expect(OpenAI).toHaveBeenCalledWith(expect.objectContaining({ apiKey: minimaxApiKey })) + expect(Anthropic).toHaveBeenCalledWith(expect.objectContaining({ apiKey: minimaxApiKey })) }) it("should return default model when no model is specified", () => { @@ -136,7 +144,7 @@ describe("MiniMaxHandler", () => { describe("Default behavior", () => { it("should default to international base URL when none is specified", () => { const handlerDefault = new MiniMaxHandler({ minimaxApiKey: "test-minimax-api-key" }) - expect(OpenAI).toHaveBeenCalledWith( + expect(Anthropic).toHaveBeenCalledWith( expect.objectContaining({ baseURL: "https://api.minimax.io/v1", }), @@ -152,6 +160,10 @@ describe("MiniMaxHandler", () => { const model = handlerDefault.getModel() expect(model.id).toBe("MiniMax-M2") }) + + it("should throw error when API key is not provided", () => { + expect(() => new MiniMaxHandler({} as any)).toThrow("MiniMax API key is 
required") + }) }) describe("API Methods", () => { @@ -161,15 +173,24 @@ describe("MiniMaxHandler", () => { it("completePrompt method should return text from MiniMax API", async () => { const expectedResponse = "This is a test response from MiniMax" - mockCreate.mockResolvedValueOnce({ choices: [{ message: { content: expectedResponse } }] }) + mockCreate.mockResolvedValueOnce({ + content: [{ type: "text", text: expectedResponse }], + }) const result = await handler.completePrompt("test prompt") expect(result).toBe(expectedResponse) + expect(mockCreate).toHaveBeenCalledWith({ + model: "MiniMax-M2", + max_tokens: 16384, + temperature: 1.0, + messages: [{ role: "user", content: "test prompt" }], + stream: false, + }) }) it("should handle errors in completePrompt", async () => { const errorMessage = "MiniMax API error" mockCreate.mockRejectedValueOnce(new Error(errorMessage)) - await expect(handler.completePrompt("test prompt")).rejects.toThrow() + await expect(handler.completePrompt("test prompt")).rejects.toThrow(errorMessage) }) it("createMessage should yield text content from stream", async () => { @@ -182,7 +203,38 @@ describe("MiniMaxHandler", () => { .fn() .mockResolvedValueOnce({ done: false, - value: { choices: [{ delta: { content: testContent } }] }, + value: { + type: "content_block_start", + index: 0, + content_block: { type: "text", text: testContent }, + }, + }) + .mockResolvedValueOnce({ done: true }), + }), + } + }) + + const stream = handler.createMessage("system prompt", []) + const firstChunk = await stream.next() + + expect(firstChunk.done).toBe(false) + expect(firstChunk.value).toEqual({ type: "text", text: testContent }) + }) + + it("createMessage should handle text delta chunks", async () => { + const testContent = "streaming text" + + mockCreate.mockImplementationOnce(() => { + return { + [Symbol.asyncIterator]: () => ({ + next: vitest + .fn() + .mockResolvedValueOnce({ + done: false, + value: { + type: "content_block_delta", + delta: { type: "text_delta", text: testContent }, + }, }) .mockResolvedValueOnce({ done: true }), }), @@ -205,8 +257,15 @@ describe("MiniMaxHandler", () => { .mockResolvedValueOnce({ done: false, value: { - choices: [{ delta: {} }], - usage: { prompt_tokens: 10, completion_tokens: 20 }, + type: "message_start", + message: { + usage: { + input_tokens: 10, + output_tokens: 20, + cache_creation_input_tokens: 5, + cache_read_input_tokens: 3, + }, + }, }, }) .mockResolvedValueOnce({ done: true }), @@ -218,12 +277,17 @@ describe("MiniMaxHandler", () => { const firstChunk = await stream.next() expect(firstChunk.done).toBe(false) - expect(firstChunk.value).toEqual({ type: "usage", inputTokens: 10, outputTokens: 20 }) + expect(firstChunk.value).toEqual({ + type: "usage", + inputTokens: 10, + outputTokens: 20, + cacheWriteTokens: 5, + cacheReadTokens: 3, + }) }) - it("createMessage should pass correct parameters to MiniMax client", async () => { + it("createMessage should pass correct parameters to MiniMax client with prompt caching", async () => { const modelId: MinimaxModelId = "MiniMax-M2" - const modelInfo = minimaxModels[modelId] const handlerWithModel = new MiniMaxHandler({ apiModelId: modelId, minimaxApiKey: "test-minimax-api-key", @@ -248,17 +312,47 @@ describe("MiniMaxHandler", () => { expect(mockCreate).toHaveBeenCalledWith( expect.objectContaining({ model: modelId, - max_tokens: Math.min(modelInfo.maxTokens, Math.ceil(modelInfo.contextWindow * 0.2)), - temperature: 1, - messages: expect.arrayContaining([{ role: "system", content: systemPrompt 
}]), + max_tokens: expect.any(Number), + temperature: 1.0, + system: [{ text: systemPrompt, type: "text", cache_control: { type: "ephemeral" } }], + messages: expect.any(Array), stream: true, - stream_options: { include_usage: true }, }), - undefined, + expect.objectContaining({ + headers: { "anthropic-beta": "prompt-caching-2024-07-31" }, + }), ) }) - it("should use temperature 1 by default", async () => { + it("createMessage should handle reasoning/thinking blocks", async () => { + const testThinking = "Let me think about this..." + + mockCreate.mockImplementationOnce(() => { + return { + [Symbol.asyncIterator]: () => ({ + next: vitest + .fn() + .mockResolvedValueOnce({ + done: false, + value: { + type: "content_block_start", + index: 0, + content_block: { type: "thinking", thinking: testThinking }, + }, + }) + .mockResolvedValueOnce({ done: true }), + }), + } + }) + + const stream = handler.createMessage("system prompt", []) + const firstChunk = await stream.next() + + expect(firstChunk.done).toBe(false) + expect(firstChunk.value).toEqual({ type: "reasoning", text: testThinking }) + }) + + it("should use temperature 1.0 by default", async () => { mockCreate.mockImplementationOnce(() => { return { [Symbol.asyncIterator]: () => ({ @@ -274,11 +368,59 @@ describe("MiniMaxHandler", () => { expect(mockCreate).toHaveBeenCalledWith( expect.objectContaining({ - temperature: 1, + temperature: 1.0, + }), + expect.any(Object), + ) + }) + + it("should use custom temperature when provided", async () => { + const customTemperature = 0.7 + const handlerWithTemp = new MiniMaxHandler({ + minimaxApiKey: "test-minimax-api-key", + modelTemperature: customTemperature, + }) + + mockCreate.mockImplementationOnce(() => { + return { + [Symbol.asyncIterator]: () => ({ + async next() { + return { done: true } + }, + }), + } + }) + + const messageGenerator = handlerWithTemp.createMessage("test", []) + await messageGenerator.next() + + expect(mockCreate).toHaveBeenCalledWith( + expect.objectContaining({ + temperature: customTemperature, }), - undefined, + expect.any(Object), ) }) + + it("countTokens should try API first then fallback to tiktoken", async () => { + const content = [{ type: "text", text: "test content" }] as Anthropic.Messages.ContentBlockParam[] + + // First test successful API response + mockCountTokens.mockResolvedValueOnce({ input_tokens: 42 }) + let result = await handler.countTokens(content) + expect(result).toBe(42) + expect(mockCountTokens).toHaveBeenCalledWith({ + model: "MiniMax-M2", + messages: [{ role: "user", content }], + }) + + // Then test API failure with fallback + mockCountTokens.mockRejectedValueOnce(new Error("Not supported")) + result = await handler.countTokens(content) + // Should return a number (tiktoken estimate), exact value depends on tokenizer + expect(typeof result).toBe("number") + expect(result).toBeGreaterThan(0) + }) }) describe("Model Configuration", () => { diff --git a/src/api/providers/minimax.ts b/src/api/providers/minimax.ts index 8a8e8c14e5b4..4056a3966709 100644 --- a/src/api/providers/minimax.ts +++ b/src/api/providers/minimax.ts @@ -1,19 +1,260 @@ -import { type MinimaxModelId, minimaxDefaultModelId, minimaxModels } from "@roo-code/types" +import { Anthropic } from "@anthropic-ai/sdk" +import { Stream as AnthropicStream } from "@anthropic-ai/sdk/streaming" +import { CacheControlEphemeral } from "@anthropic-ai/sdk/resources" + +import { + type ModelInfo, + type MinimaxModelId, + minimaxDefaultModelId, + minimaxModels, + MINIMAX_DEFAULT_TEMPERATURE, +} from 
"@roo-code/types" import type { ApiHandlerOptions } from "../../shared/api" -import { BaseOpenAiCompatibleProvider } from "./base-openai-compatible-provider" +import { ApiStream } from "../transform/stream" +import { getModelParams } from "../transform/model-params" + +import { BaseProvider } from "./base-provider" +import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index" + +export class MiniMaxHandler extends BaseProvider implements SingleCompletionHandler { + private options: ApiHandlerOptions + private client: Anthropic -export class MiniMaxHandler extends BaseOpenAiCompatibleProvider { constructor(options: ApiHandlerOptions) { - super({ - ...options, - providerName: "MiniMax", - baseURL: options.minimaxBaseUrl ?? "https://api.minimax.io/v1", - apiKey: options.minimaxApiKey, - defaultProviderModelId: minimaxDefaultModelId, - providerModels: minimaxModels, - defaultTemperature: 1.0, + super() + this.options = options + + if (!this.options.minimaxApiKey) { + throw new Error("MiniMax API key is required") + } + + // MiniMax supports Anthropic-compatible API + // https://platform.minimax.io/docs/api-reference/text-anthropic-api + this.client = new Anthropic({ + baseURL: this.options.minimaxBaseUrl || "https://api.minimax.io/v1", + apiKey: this.options.minimaxApiKey, + }) + } + + async *createMessage( + systemPrompt: string, + messages: Anthropic.Messages.MessageParam[], + metadata?: ApiHandlerCreateMessageMetadata, + ): ApiStream { + const { id: modelId, info, maxTokens, temperature } = this.getModel() + const cacheControl: CacheControlEphemeral = { type: "ephemeral" } + + // Check if the model supports prompt caching + const supportsPromptCache = info.supportsPromptCache ?? false + const betas: string[] = [] + + if (supportsPromptCache) { + betas.push("prompt-caching-2024-07-31") + } + + let stream: AnthropicStream + + if (supportsPromptCache) { + // With prompt caching support, handle cache control for user messages + const userMsgIndices = messages.reduce( + (acc, msg, index) => (msg.role === "user" ? [...acc, index] : acc), + [] as number[], + ) + + const lastUserMsgIndex = userMsgIndices[userMsgIndices.length - 1] ?? -1 + const secondLastMsgUserIndex = userMsgIndices[userMsgIndices.length - 2] ?? -1 + + stream = await this.client.messages.create( + { + model: modelId, + max_tokens: maxTokens ?? 16384, + temperature, + // Setting cache breakpoint for system prompt so new tasks can reuse it + system: [{ text: systemPrompt, type: "text", cache_control: cacheControl }], + messages: messages.map((message, index) => { + if (index === lastUserMsgIndex || index === secondLastMsgUserIndex) { + return { + ...message, + content: + typeof message.content === "string" + ? [{ type: "text", text: message.content, cache_control: cacheControl }] + : message.content.map((content, contentIndex) => + contentIndex === message.content.length - 1 + ? { ...content, cache_control: cacheControl } + : content, + ), + } + } + return message + }), + stream: true, + }, + betas.length > 0 ? { headers: { "anthropic-beta": betas.join(",") } } : undefined, + ) + } else { + // Without prompt caching + stream = await this.client.messages.create({ + model: modelId, + max_tokens: maxTokens ?? 
16384, + temperature, + system: [{ text: systemPrompt, type: "text" }], + messages, + stream: true, + }) + } + + let inputTokens = 0 + let outputTokens = 0 + let cacheWriteTokens = 0 + let cacheReadTokens = 0 + + for await (const chunk of stream) { + switch (chunk.type) { + case "message_start": { + // Tells us cache reads/writes/input/output + const { + input_tokens = 0, + output_tokens = 0, + cache_creation_input_tokens, + cache_read_input_tokens, + } = chunk.message.usage + + yield { + type: "usage", + inputTokens: input_tokens, + outputTokens: output_tokens, + cacheWriteTokens: cache_creation_input_tokens || undefined, + cacheReadTokens: cache_read_input_tokens || undefined, + } + + inputTokens += input_tokens + outputTokens += output_tokens + cacheWriteTokens += cache_creation_input_tokens || 0 + cacheReadTokens += cache_read_input_tokens || 0 + + break + } + case "message_delta": + // Tells us output tokens along the way and at the end of the message + yield { + type: "usage", + inputTokens: 0, + outputTokens: chunk.usage.output_tokens || 0, + } + break + case "message_stop": + // No usage data, just an indicator that the message is done + break + case "content_block_start": + switch (chunk.content_block.type) { + case "thinking": + // Handle reasoning/thinking blocks if supported + if (chunk.index > 0) { + yield { type: "reasoning", text: "\n" } + } + yield { type: "reasoning", text: chunk.content_block.thinking } + break + case "text": + // We may receive multiple text blocks + if (chunk.index > 0) { + yield { type: "text", text: "\n" } + } + yield { type: "text", text: chunk.content_block.text } + break + } + break + case "content_block_delta": + switch (chunk.delta.type) { + case "thinking_delta": + yield { type: "reasoning", text: chunk.delta.thinking } + break + case "text_delta": + yield { type: "text", text: chunk.delta.text } + break + } + break + case "content_block_stop": + break + } + } + + // Calculate total cost if we have usage data + if (inputTokens > 0 || outputTokens > 0 || cacheWriteTokens > 0 || cacheReadTokens > 0) { + // MiniMax pricing (per million tokens): + // Input: $0.3, Output: $1.2, Cache writes: $0.375, Cache reads: $0.03 + const inputCost = (inputTokens / 1_000_000) * (info.inputPrice || 0) + const outputCost = (outputTokens / 1_000_000) * (info.outputPrice || 0) + const cacheWriteCost = (cacheWriteTokens / 1_000_000) * (info.cacheWritesPrice || 0) + const cacheReadCost = (cacheReadTokens / 1_000_000) * (info.cacheReadsPrice || 0) + const totalCost = inputCost + outputCost + cacheWriteCost + cacheReadCost + + yield { + type: "usage", + inputTokens: 0, + outputTokens: 0, + totalCost, + } + } + } + + getModel() { + const modelId = this.options.apiModelId + const id = modelId && modelId in minimaxModels ? (modelId as MinimaxModelId) : minimaxDefaultModelId + const info: ModelInfo = minimaxModels[id] + + const params = getModelParams({ + format: "anthropic", + modelId: id, + model: info, + settings: this.options, }) + + return { + id, + info, + ...params, + temperature: this.options.modelTemperature ?? MINIMAX_DEFAULT_TEMPERATURE, + } + } + + async completePrompt(prompt: string) { + const { id: model, temperature } = this.getModel() + + const message = await this.client.messages.create({ + model, + max_tokens: 16384, + temperature, + messages: [{ role: "user", content: prompt }], + stream: false, + }) + + const content = message.content.find(({ type }) => type === "text") + return content?.type === "text" ? 
content.text : "" + } + + /** + * Counts tokens for the given content using MiniMax's Anthropic-compatible API + * Falls back to tiktoken estimation if the API doesn't support token counting + */ + override async countTokens(content: Array): Promise { + try { + // Try to use the API's token counting if available + // Note: This might not be supported by MiniMax yet + const { id: model } = this.getModel() + + // MiniMax might not have token counting endpoint yet + // If they add it, it would follow Anthropic's pattern + const response = await this.client.messages.countTokens({ + model, + messages: [{ role: "user", content: content }], + }) + + return response.input_tokens + } catch (error) { + // Fallback to tiktoken estimation + return super.countTokens(content) + } } }