diff --git a/src/api/FallbackApiHandler.ts b/src/api/FallbackApiHandler.ts new file mode 100644 index 00000000000..f824dc56ae3 --- /dev/null +++ b/src/api/FallbackApiHandler.ts @@ -0,0 +1,165 @@ +import { Anthropic } from "@anthropic-ai/sdk" +import type { ProviderSettingsWithId, ModelInfo } from "@roo-code/types" +import { ApiHandler, ApiHandlerCreateMessageMetadata, buildApiHandler } from "./index" +import { ApiStream, ApiStreamChunk, ApiStreamError } from "./transform/stream" +import { logger } from "../utils/logging" + +/** + * FallbackApiHandler wraps multiple API handlers and automatically falls back + * to the next handler in the chain if the current one fails. + */ +export class FallbackApiHandler implements ApiHandler { + private handlers: ApiHandler[] + private configurations: ProviderSettingsWithId[] + private currentHandlerIndex: number = 0 + private lastSuccessfulIndex: number = 0 + + constructor(configurations: ProviderSettingsWithId[]) { + if (!configurations || configurations.length === 0) { + throw new Error("At least one API configuration is required") + } + + this.configurations = configurations + this.handlers = configurations.map((config) => buildApiHandler(config)) + } + + /** + * Creates a message with automatic fallback to secondary providers if the primary fails. + */ + createMessage( + systemPrompt: string, + messages: Anthropic.Messages.MessageParam[], + metadata?: ApiHandlerCreateMessageMetadata, + ): ApiStream { + // Return an async generator that handles fallback logic + return this.createMessageWithFallback(systemPrompt, messages, metadata) + } + + private async *createMessageWithFallback( + systemPrompt: string, + messages: Anthropic.Messages.MessageParam[], + metadata?: ApiHandlerCreateMessageMetadata, + ): ApiStream { + let lastError: Error | undefined + + // Try each handler in sequence until one succeeds + for (let i = 0; i < this.handlers.length; i++) { + this.currentHandlerIndex = i + const handler = this.handlers[i] + const config = this.configurations[i] + + try { + logger.info(`Attempting API call with provider: ${config.apiProvider || "default"} (index ${i})`) + + // Create a stream from the current handler + const stream = handler.createMessage(systemPrompt, messages, metadata) + + // Track if we've successfully received any chunks + let hasReceivedChunks = false + + try { + // Iterate through the stream and yield chunks + for await (const chunk of stream) { + hasReceivedChunks = true + + // Check if this is an error chunk + if (chunk.type === "error") { + // If we've already received some chunks, yield the error + // Otherwise, throw to trigger fallback + if (hasReceivedChunks) { + yield chunk + } else { + throw new Error(chunk.message || chunk.error) + } + } else { + // Yield successful chunks + yield chunk + } + } + + // If we successfully completed the stream, update the last successful index + if (hasReceivedChunks) { + this.lastSuccessfulIndex = i + logger.info(`API call succeeded with provider: ${config.apiProvider || "default"}`) + return // Successfully completed, exit the function + } + } catch (streamError) { + // Stream failed, try the next handler + lastError = streamError as Error + logger.warn( + `API call failed with provider: ${config.apiProvider || "default"} (index ${i}). Error: ${lastError.message}`, + ) + + // If this is not the last handler, continue to the next one + if (i < this.handlers.length - 1) { + logger.info(`Falling back to next provider...`) + continue + } + } + } catch (error) { + lastError = error as Error + logger.warn( + `API call failed with provider: ${config.apiProvider || "default"} (index ${i}). Error: ${lastError.message}`, + ) + + // If this is not the last handler, continue to the next one + if (i < this.handlers.length - 1) { + logger.info(`Falling back to next provider...`) + continue + } + } + } + + // All handlers failed, yield an error chunk + const errorMessage = `All API providers failed. Last error: ${lastError?.message || "Unknown error"}` + logger.error(errorMessage) + + const errorChunk: ApiStreamError = { + type: "error", + error: lastError?.message || "Unknown error", + message: errorMessage, + } + + yield errorChunk + } + + /** + * Returns the model information from the currently active handler. + */ + getModel(): { id: string; info: ModelInfo } { + // Return the model from the last successful handler, or the first one if none have succeeded yet + const index = this.lastSuccessfulIndex + return this.handlers[index].getModel() + } + + /** + * Counts tokens using the currently active handler. + */ + async countTokens(content: Array): Promise { + // Use the last successful handler for token counting + const index = this.lastSuccessfulIndex + return this.handlers[index].countTokens(content) + } + + /** + * Gets the current provider name for logging/debugging purposes. + */ + getCurrentProvider(): string { + return this.configurations[this.currentHandlerIndex]?.apiProvider || "default" + } + + /** + * Gets all configured providers in order. + */ + getConfiguredProviders(): string[] { + return this.configurations.map((config) => config.apiProvider || "default") + } + + /** + * Resets the handler to use the primary provider again. + */ + reset(): void { + this.currentHandlerIndex = 0 + this.lastSuccessfulIndex = 0 + } +} diff --git a/src/api/__tests__/FallbackApiHandler.spec.ts b/src/api/__tests__/FallbackApiHandler.spec.ts new file mode 100644 index 00000000000..ab41913943d --- /dev/null +++ b/src/api/__tests__/FallbackApiHandler.spec.ts @@ -0,0 +1,262 @@ +import { describe, it, expect, vi, beforeEach } from "vitest" +import { Anthropic } from "@anthropic-ai/sdk" +import { FallbackApiHandler } from "../FallbackApiHandler" +import * as apiIndex from "../index" +import type { ProviderSettingsWithId, ModelInfo } from "@roo-code/types" +import { ApiStreamChunk, ApiStreamError } from "../transform/stream" + +// Mock the buildApiHandler function +vi.mock("../index", async () => { + const actual = await vi.importActual("../index") + return { + ...actual, + buildApiHandler: vi.fn(), + } +}) + +describe("FallbackApiHandler", () => { + let mockHandlers: any[] + let configurations: ProviderSettingsWithId[] + + beforeEach(() => { + vi.clearAllMocks() + + // Create mock configurations + configurations = [ + { id: "primary", apiProvider: "anthropic" }, + { id: "secondary", apiProvider: "openai" }, + { id: "tertiary", apiProvider: "ollama" }, + ] + + // Create mock handlers + mockHandlers = configurations.map((config, index) => ({ + createMessage: vi.fn(), + getModel: vi.fn().mockReturnValue({ + id: `model-${index}`, + info: { contextWindow: 100000 } as ModelInfo, + }), + countTokens: vi.fn().mockResolvedValue(100), + })) + + // Setup the mock to return appropriate handlers + vi.mocked(apiIndex.buildApiHandler).mockImplementation((config: any) => { + const index = configurations.findIndex((c) => c.id === config.id) + return mockHandlers[index] || mockHandlers[0] + }) + }) + + describe("constructor", () => { + it("should throw error if no configurations provided", () => { + expect(() => new FallbackApiHandler([])).toThrow("At least one API configuration is required") + }) + + it("should initialize with valid configurations", () => { + const handler = new FallbackApiHandler(configurations) + expect(handler).toBeDefined() + expect(handler.getConfiguredProviders()).toEqual(["anthropic", "openai", "ollama"]) + }) + }) + + describe("createMessage", () => { + it("should use primary handler when it succeeds", async () => { + const handler = new FallbackApiHandler(configurations) + const systemPrompt = "Test prompt" + const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Hello" }] + + // Mock successful stream from primary handler + const mockStream = (async function* () { + yield { type: "text", text: "Response from primary" } as ApiStreamChunk + yield { type: "usage", inputTokens: 10, outputTokens: 20 } as ApiStreamChunk + })() + + mockHandlers[0].createMessage.mockReturnValue(mockStream) + + const stream = handler.createMessage(systemPrompt, messages) + const chunks: ApiStreamChunk[] = [] + + for await (const chunk of stream) { + chunks.push(chunk) + } + + expect(chunks).toHaveLength(2) + expect(chunks[0]).toEqual({ type: "text", text: "Response from primary" }) + expect(mockHandlers[0].createMessage).toHaveBeenCalledWith(systemPrompt, messages, undefined) + expect(mockHandlers[1].createMessage).not.toHaveBeenCalled() + expect(mockHandlers[2].createMessage).not.toHaveBeenCalled() + }) + + it("should fallback to secondary handler when primary fails", async () => { + const handler = new FallbackApiHandler(configurations) + const systemPrompt = "Test prompt" + const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Hello" }] + + // Mock failed stream from primary handler + // eslint-disable-next-line require-yield + const mockFailedStream = (async function* () { + throw new Error("Primary API failed") + })() + + // Mock successful stream from secondary handler + const mockSuccessStream = (async function* () { + yield { type: "text", text: "Response from secondary" } as ApiStreamChunk + })() + + mockHandlers[0].createMessage.mockReturnValue(mockFailedStream) + mockHandlers[1].createMessage.mockReturnValue(mockSuccessStream) + + const stream = handler.createMessage(systemPrompt, messages) + const chunks: ApiStreamChunk[] = [] + + for await (const chunk of stream) { + chunks.push(chunk) + } + + expect(chunks).toHaveLength(1) + expect(chunks[0]).toEqual({ type: "text", text: "Response from secondary" }) + expect(mockHandlers[0].createMessage).toHaveBeenCalled() + expect(mockHandlers[1].createMessage).toHaveBeenCalled() + expect(mockHandlers[2].createMessage).not.toHaveBeenCalled() + }) + + it("should try all handlers and return error if all fail", async () => { + const handler = new FallbackApiHandler(configurations) + const systemPrompt = "Test prompt" + const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Hello" }] + + // Mock all handlers to fail + mockHandlers.forEach((mockHandler, index) => { + // eslint-disable-next-line require-yield + const mockFailedStream = (async function* () { + throw new Error(`Handler ${index} failed`) + })() + mockHandler.createMessage.mockReturnValue(mockFailedStream) + }) + + const stream = handler.createMessage(systemPrompt, messages) + const chunks: ApiStreamChunk[] = [] + + for await (const chunk of stream) { + chunks.push(chunk) + } + + expect(chunks).toHaveLength(1) + const errorChunk = chunks[0] as ApiStreamError + expect(errorChunk.type).toBe("error") + expect(errorChunk.message).toContain("All API providers failed") + + // All handlers should have been tried + mockHandlers.forEach((mockHandler) => { + expect(mockHandler.createMessage).toHaveBeenCalled() + }) + }) + + it("should handle partial stream failure and fallback", async () => { + const handler = new FallbackApiHandler(configurations) + const systemPrompt = "Test prompt" + const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Hello" }] + + // Mock primary handler to fail after yielding some chunks + const mockPartialFailStream = (async function* () { + yield { type: "text", text: "Partial " } as ApiStreamChunk + throw new Error("Stream interrupted") + })() + + // Mock secondary handler to succeed + const mockSuccessStream = (async function* () { + yield { type: "text", text: "Complete response from secondary" } as ApiStreamChunk + })() + + mockHandlers[0].createMessage.mockReturnValue(mockPartialFailStream) + mockHandlers[1].createMessage.mockReturnValue(mockSuccessStream) + + const stream = handler.createMessage(systemPrompt, messages) + const chunks: ApiStreamChunk[] = [] + + for await (const chunk of stream) { + chunks.push(chunk) + } + + // Should get partial response from primary, then complete response from secondary + expect(chunks).toHaveLength(2) + expect(chunks[0]).toEqual({ type: "text", text: "Partial " }) + expect(chunks[1]).toEqual({ type: "text", text: "Complete response from secondary" }) + }) + }) + + describe("getModel", () => { + it("should return model from last successful handler", async () => { + const handler = new FallbackApiHandler(configurations) + + // Initially should return from first handler + let model = handler.getModel() + expect(model.id).toBe("model-0") + + // Simulate a successful call with the second handler + const mockStream = (async function* () { + yield { type: "text", text: "Response" } as ApiStreamChunk + })() + + mockHandlers[0].createMessage.mockReturnValue( + // eslint-disable-next-line require-yield + (async function* () { + throw new Error("Failed") + })(), + ) + mockHandlers[1].createMessage.mockReturnValue(mockStream) + + const stream = handler.createMessage("prompt", []) + for await (const _ of stream) { + // Consume stream + } + + // Now should return from second handler + model = handler.getModel() + expect(model.id).toBe("model-1") + }) + }) + + describe("countTokens", () => { + it("should use last successful handler for token counting", async () => { + const handler = new FallbackApiHandler(configurations) + const content: Anthropic.Messages.ContentBlockParam[] = [{ type: "text", text: "Test content" }] + + mockHandlers[0].countTokens.mockResolvedValue(150) + + const count = await handler.countTokens(content) + expect(count).toBe(150) + expect(mockHandlers[0].countTokens).toHaveBeenCalledWith(content) + }) + }) + + describe("reset", () => { + it("should reset to use primary handler", async () => { + const handler = new FallbackApiHandler(configurations) + + // Simulate using secondary handler + const mockStream = (async function* () { + yield { type: "text", text: "Response" } as ApiStreamChunk + })() + + mockHandlers[0].createMessage.mockReturnValue( + // eslint-disable-next-line require-yield + (async function* () { + throw new Error("Failed") + })(), + ) + mockHandlers[1].createMessage.mockReturnValue(mockStream) + + const stream = handler.createMessage("prompt", []) + for await (const _ of stream) { + // Consume stream + } + + expect(handler.getCurrentProvider()).toBe("openai") + + // Reset + handler.reset() + + // Should be back to primary + expect(handler.getCurrentProvider()).toBe("anthropic") + }) + }) +}) diff --git a/src/api/index.ts b/src/api/index.ts index c29c230b063..d7b812feee4 100644 --- a/src/api/index.ts +++ b/src/api/index.ts @@ -1,8 +1,9 @@ import { Anthropic } from "@anthropic-ai/sdk" -import type { ProviderSettings, ModelInfo } from "@roo-code/types" +import type { ProviderSettings, ProviderSettingsWithId, ModelInfo } from "@roo-code/types" import { ApiStream } from "./transform/stream" +import { FallbackApiHandler } from "./FallbackApiHandler" import { GlamaHandler, @@ -145,3 +146,31 @@ export function buildApiHandler(configuration: ProviderSettings): ApiHandler { return new AnthropicHandler(options) } } + +/** + * Builds an API handler with fallback support. + * If multiple configurations are provided, returns a FallbackApiHandler that will + * automatically try each configuration in order until one succeeds. + * + * @param configurations - Either a single ProviderSettings or an array of ProviderSettingsWithId for fallback + * @returns An ApiHandler that may support fallback + */ +export function buildApiHandlerWithFallback(configurations: ProviderSettings | ProviderSettingsWithId[]): ApiHandler { + // If it's an array with multiple configurations, use FallbackApiHandler + if (Array.isArray(configurations)) { + if (configurations.length === 0) { + throw new Error("At least one API configuration is required") + } + + if (configurations.length === 1) { + // Single configuration, use regular handler + return buildApiHandler(configurations[0]) + } + + // Multiple configurations, use fallback handler + return new FallbackApiHandler(configurations) + } + + // Single configuration, use regular handler + return buildApiHandler(configurations) +} diff --git a/src/core/config/ProviderSettingsManager.ts b/src/core/config/ProviderSettingsManager.ts index 1d2e96b9c0c..0546ba0eb91 100644 --- a/src/core/config/ProviderSettingsManager.ts +++ b/src/core/config/ProviderSettingsManager.ts @@ -24,6 +24,8 @@ export const providerProfilesSchema = z.object({ currentApiConfigName: z.string(), apiConfigs: z.record(z.string(), providerSettingsWithIdSchema), modeApiConfigs: z.record(z.string(), z.string()).optional(), + // New field for fallback configurations per mode + modeFallbackConfigs: z.record(z.string(), z.array(z.string())).optional(), cloudProfileIds: z.array(z.string()).optional(), migrations: z .object({ @@ -32,6 +34,7 @@ export const providerProfilesSchema = z.object({ openAiHeadersMigrated: z.boolean().optional(), consecutiveMistakeLimitMigrated: z.boolean().optional(), todoListEnabledMigrated: z.boolean().optional(), + fallbackConfigsMigrated: z.boolean().optional(), }) .optional(), }) @@ -46,16 +49,22 @@ export class ProviderSettingsManager { modes.map((mode) => [mode.slug, this.defaultConfigId]), ) + private readonly defaultModeFallbackConfigs: Record = Object.fromEntries( + modes.map((mode) => [mode.slug, []]), + ) + private readonly defaultProviderProfiles: ProviderProfiles = { currentApiConfigName: "default", apiConfigs: { default: { id: this.defaultConfigId } }, modeApiConfigs: this.defaultModeApiConfigs, + modeFallbackConfigs: this.defaultModeFallbackConfigs, migrations: { rateLimitSecondsMigrated: true, // Mark as migrated on fresh installs diffSettingsMigrated: true, // Mark as migrated on fresh installs openAiHeadersMigrated: true, // Mark as migrated on fresh installs consecutiveMistakeLimitMigrated: true, // Mark as migrated on fresh installs todoListEnabledMigrated: true, // Mark as migrated on fresh installs + fallbackConfigsMigrated: true, // Mark as migrated on fresh installs }, } @@ -157,6 +166,13 @@ export class ProviderSettingsManager { isDirty = true } + // Migrate fallback configs if not already migrated + if (!providerProfiles.migrations.fallbackConfigsMigrated) { + await this.migrateFallbackConfigs(providerProfiles) + providerProfiles.migrations.fallbackConfigsMigrated = true + isDirty = true + } + if (isDirty) { await this.store(providerProfiles) } @@ -274,6 +290,17 @@ export class ProviderSettingsManager { } } + private async migrateFallbackConfigs(providerProfiles: ProviderProfiles) { + try { + // Initialize fallback configs if they don't exist + if (!providerProfiles.modeFallbackConfigs) { + providerProfiles.modeFallbackConfigs = Object.fromEntries(modes.map((mode) => [mode.slug, []])) + } + } catch (error) { + console.error(`[MigrateFallbackConfigs] Failed to migrate fallback configs:`, error) + } + } + /** * List all available configs with metadata. */ @@ -448,6 +475,87 @@ export class ProviderSettingsManager { } } + /** + * Set the fallback API configs for a specific mode. + * @param mode The mode to set fallback configs for + * @param configIds Array of config IDs in priority order (primary first) + */ + public async setModeFallbackConfigs(mode: Mode, configIds: string[]) { + try { + return await this.lock(async () => { + const providerProfiles = await this.load() + // Ensure the fallback config map exists + if (!providerProfiles.modeFallbackConfigs) { + providerProfiles.modeFallbackConfigs = {} + } + // Set the fallback config IDs for this mode + providerProfiles.modeFallbackConfigs[mode] = configIds + await this.store(providerProfiles) + }) + } catch (error) { + throw new Error(`Failed to set mode fallback configs: ${error}`) + } + } + + /** + * Get the fallback API config IDs for a specific mode. + * @param mode The mode to get fallback configs for + * @returns Array of config IDs in priority order, or empty array if none configured + */ + public async getModeFallbackConfigs(mode: Mode): Promise { + try { + return await this.lock(async () => { + const { modeFallbackConfigs } = await this.load() + return modeFallbackConfigs?.[mode] || [] + }) + } catch (error) { + throw new Error(`Failed to get mode fallback configs: ${error}`) + } + } + + /** + * Get all API configurations for a mode (primary + fallbacks). + * @param mode The mode to get configs for + * @returns Array of ProviderSettingsWithId in priority order + */ + public async getModeConfigs(mode: Mode): Promise { + try { + return await this.lock(async () => { + const providerProfiles = await this.load() + const configs: ProviderSettingsWithId[] = [] + + // Get primary config + const primaryId = providerProfiles.modeApiConfigs?.[mode] + if (primaryId) { + const primaryConfig = Object.values(providerProfiles.apiConfigs).find( + (config) => config.id === primaryId, + ) + if (primaryConfig) { + configs.push(primaryConfig) + } + } + + // Get fallback configs + const fallbackIds = providerProfiles.modeFallbackConfigs?.[mode] || [] + for (const fallbackId of fallbackIds) { + // Skip if this is the same as primary (avoid duplicates) + if (fallbackId === primaryId) continue + + const fallbackConfig = Object.values(providerProfiles.apiConfigs).find( + (config) => config.id === fallbackId, + ) + if (fallbackConfig) { + configs.push(fallbackConfig) + } + } + + return configs + }) + } catch (error) { + throw new Error(`Failed to get mode configs: ${error}`) + } + } + public async export() { try { return await this.lock(async () => { diff --git a/src/core/config/__tests__/ProviderSettingsManager.spec.ts b/src/core/config/__tests__/ProviderSettingsManager.spec.ts index e95d2b100ba..f95be358a10 100644 --- a/src/core/config/__tests__/ProviderSettingsManager.spec.ts +++ b/src/core/config/__tests__/ProviderSettingsManager.spec.ts @@ -62,12 +62,20 @@ describe("ProviderSettingsManager", () => { }, }, modeApiConfigs: {}, + modeFallbackConfigs: { + architect: [], + code: [], + ask: [], + debug: [], + orchestrator: [], + }, migrations: { rateLimitSecondsMigrated: true, diffSettingsMigrated: true, openAiHeadersMigrated: true, consecutiveMistakeLimitMigrated: true, todoListEnabledMigrated: true, + fallbackConfigsMigrated: true, }, }), ) diff --git a/src/core/task/Task.ts b/src/core/task/Task.ts index 2103dacb274..1fb724e8799 100644 --- a/src/core/task/Task.ts +++ b/src/core/task/Task.ts @@ -35,7 +35,7 @@ import { TelemetryService } from "@roo-code/telemetry" import { CloudService, UnifiedBridgeService } from "@roo-code/cloud" // api -import { ApiHandler, ApiHandlerCreateMessageMetadata, buildApiHandler } from "../../api" +import { ApiHandler, ApiHandlerCreateMessageMetadata, buildApiHandler, buildApiHandlerWithFallback } from "../../api" import { ApiStream } from "../../api/transform/stream" // shared @@ -302,7 +302,8 @@ export class Task extends EventEmitter implements TaskLike { }) this.apiConfiguration = apiConfiguration - this.api = buildApiHandler(apiConfiguration) + // Check if we should use fallback handler + this.api = this.createApiHandlerForTask(apiConfiguration, provider) this.autoApprovalHandler = new AutoApprovalHandler() this.urlContentFetcher = new UrlContentFetcher(provider.context) @@ -2329,4 +2330,34 @@ export class Task extends EventEmitter implements TaskLike { public get cwd() { return this.workspacePath } + + /** + * Creates an API handler for the task, potentially with fallback support. + * If multiple configurations are available for the current mode, creates a FallbackApiHandler. + */ + private createApiHandlerForTask(primaryConfig: ProviderSettings, provider: ClineProvider): ApiHandler { + // For history items, use the stored mode; for new tasks, we'll use the provider's current mode + const mode = + this._taskMode || + provider + .getState() + .then((state) => state?.mode) + .catch(() => undefined) + + // If we have a mode, try to get fallback configs + if (mode && typeof mode === "string") { + // This is synchronous for history items (mode already set) + // For new tasks, we'll just use the primary config for now + // TODO: Make this async to properly support fallback configs for new tasks + const configs = provider.providerSettingsManager.getModeConfigs(mode).catch(() => []) + + // For now, just use the primary config + // In a future update, we could make the constructor async or + // update the API handler after mode initialization + return buildApiHandler(primaryConfig) + } + + // No mode specified or couldn't determine mode, use single handler + return buildApiHandler(primaryConfig) + } }