From a71884a1408454008a44f314707b43aec4720c75 Mon Sep 17 00:00:00 2001 From: John Richmond <5629+jr@users.noreply.github.com> Date: Wed, 7 May 2025 16:36:44 -0700 Subject: [PATCH 1/2] Stop leaking other provider settings --- src/core/config/ProviderSettingsManager.ts | 8 +- .../__tests__/ProviderSettingsManager.test.ts | 61 ++++- src/exports/roo-code.d.ts | 6 +- src/exports/types.ts | 6 +- src/schemas/index.ts | 254 +++++++++++++++--- 5 files changed, 286 insertions(+), 49 deletions(-) diff --git a/src/core/config/ProviderSettingsManager.ts b/src/core/config/ProviderSettingsManager.ts index 53ae585b3c..897ba2386e 100644 --- a/src/core/config/ProviderSettingsManager.ts +++ b/src/core/config/ProviderSettingsManager.ts @@ -1,7 +1,7 @@ import { ExtensionContext } from "vscode" import { z, ZodError } from "zod" -import { providerSettingsSchema, ApiConfigMeta } from "../../schemas" +import { providerSettingsSchema, ApiConfigMeta, providerSettingsSchemaDiscriminated } from "../../schemas" import { Mode, modes } from "../../shared/modes" import { telemetryService } from "../../services/telemetry/TelemetryService" @@ -250,7 +250,11 @@ export class ProviderSettingsManager { const providerProfiles = await this.load() // Preserve the existing ID if this is an update to an existing config. const existingId = providerProfiles.apiConfigs[name]?.id - providerProfiles.apiConfigs[name] = { ...config, id: config.id || existingId || this.generateId() } + const id = config.id || existingId || this.generateId() + + // Filter out settings from other providers. + const filteredConfig = providerSettingsSchemaDiscriminated.parse(config) + providerProfiles.apiConfigs[name] = { ...filteredConfig, id } await this.store(providerProfiles) }) } catch (error) { diff --git a/src/core/config/__tests__/ProviderSettingsManager.test.ts b/src/core/config/__tests__/ProviderSettingsManager.test.ts index 064157da02..9399fc45c7 100644 --- a/src/core/config/__tests__/ProviderSettingsManager.test.ts +++ b/src/core/config/__tests__/ProviderSettingsManager.test.ts @@ -247,10 +247,58 @@ describe("ProviderSettingsManager", () => { }, } - expect(mockSecrets.store).toHaveBeenCalledWith( - "roo_cline_config_api_config", - JSON.stringify(expectedConfig, null, 2), + expect(mockSecrets.store.mock.calls[0][0]).toEqual("roo_cline_config_api_config") + expect(storedConfig).toEqual(expectedConfig) + }) + + it("should only save provider relevant settings", async () => { + mockSecrets.get.mockResolvedValue( + JSON.stringify({ + currentApiConfigName: "default", + apiConfigs: { + default: {}, + }, + modeApiConfigs: { + code: "default", + architect: "default", + ask: "default", + }, + }), ) + + const newConfig: ProviderSettings = { + apiProvider: "anthropic", + apiKey: "test-key", + } + const newConfigWithExtra: ProviderSettings = { + ...newConfig, + openRouterApiKey: "another-key", + } + + await providerSettingsManager.saveConfig("test", newConfigWithExtra) + + // Get the actual stored config to check the generated ID + const storedConfig = JSON.parse(mockSecrets.store.mock.lastCall[1]) + const testConfigId = storedConfig.apiConfigs.test.id + + const expectedConfig = { + currentApiConfigName: "default", + apiConfigs: { + default: {}, + test: { + ...newConfig, + id: testConfigId, + }, + }, + modeApiConfigs: { + code: "default", + architect: "default", + ask: "default", + }, + } + + expect(mockSecrets.store.mock.calls[0][0]).toEqual("roo_cline_config_api_config") + expect(storedConfig).toEqual(expectedConfig) }) it("should update existing config", async () => { @@ -291,10 +339,9 @@ describe("ProviderSettingsManager", () => { }, } - expect(mockSecrets.store).toHaveBeenCalledWith( - "roo_cline_config_api_config", - JSON.stringify(expectedConfig, null, 2), - ) + const storedConfig = JSON.parse(mockSecrets.store.mock.lastCall[1]) + expect(mockSecrets.store.mock.lastCall[0]).toEqual("roo_cline_config_api_config") + expect(storedConfig).toEqual(expectedConfig) }) it("should throw error if secrets storage fails", async () => { diff --git a/src/exports/roo-code.d.ts b/src/exports/roo-code.d.ts index d0d82fa39e..c9782fce2c 100644 --- a/src/exports/roo-code.d.ts +++ b/src/exports/roo-code.d.ts @@ -121,14 +121,13 @@ type ProviderSettings = { unboundModelId?: string | undefined requestyApiKey?: string | undefined requestyModelId?: string | undefined + fakeAi?: unknown | undefined xaiApiKey?: string | undefined groqApiKey?: string | undefined chutesApiKey?: string | undefined litellmBaseUrl?: string | undefined litellmApiKey?: string | undefined litellmModelId?: string | undefined - modelMaxTokens?: number | undefined - modelMaxThinkingTokens?: number | undefined includeMaxTokens?: boolean | undefined reasoningEffort?: ("low" | "medium" | "high") | undefined promptCachingDisabled?: boolean | undefined @@ -136,7 +135,8 @@ type ProviderSettings = { fuzzyMatchThreshold?: number | undefined modelTemperature?: (number | null) | undefined rateLimitSeconds?: number | undefined - fakeAi?: unknown | undefined + modelMaxTokens?: number | undefined + modelMaxThinkingTokens?: number | undefined } type GlobalSettings = { diff --git a/src/exports/types.ts b/src/exports/types.ts index 05d492b8cc..9b89f584f6 100644 --- a/src/exports/types.ts +++ b/src/exports/types.ts @@ -122,14 +122,13 @@ type ProviderSettings = { unboundModelId?: string | undefined requestyApiKey?: string | undefined requestyModelId?: string | undefined + fakeAi?: unknown | undefined xaiApiKey?: string | undefined groqApiKey?: string | undefined chutesApiKey?: string | undefined litellmBaseUrl?: string | undefined litellmApiKey?: string | undefined litellmModelId?: string | undefined - modelMaxTokens?: number | undefined - modelMaxThinkingTokens?: number | undefined includeMaxTokens?: boolean | undefined reasoningEffort?: ("low" | "medium" | "high") | undefined promptCachingDisabled?: boolean | undefined @@ -137,7 +136,8 @@ type ProviderSettings = { fuzzyMatchThreshold?: number | undefined modelTemperature?: (number | null) | undefined rateLimitSeconds?: number | undefined - fakeAi?: unknown | undefined + modelMaxTokens?: number | undefined + modelMaxThinkingTokens?: number | undefined } export type { ProviderSettings } diff --git a/src/schemas/index.ts b/src/schemas/index.ts index 094b11b10d..5b7adcfb91 100644 --- a/src/schemas/index.ts +++ b/src/schemas/index.ts @@ -345,23 +345,42 @@ type _AssertExperiments = AssertEqual>> * ProviderSettings */ -export const providerSettingsSchema = z.object({ - apiProvider: providerNamesSchema.optional(), - // Anthropic +// Generic settings that apply to all providers +const genericProviderSettingsSchema = z.object({ + includeMaxTokens: z.boolean().optional(), + reasoningEffort: reasoningEffortsSchema.optional(), + promptCachingDisabled: z.boolean().optional(), + diffEnabled: z.boolean().optional(), + fuzzyMatchThreshold: z.number().optional(), + modelTemperature: z.number().nullish(), + rateLimitSeconds: z.number().optional(), + // Claude 3.7 Sonnet Thinking + modelMaxTokens: z.number().optional(), + modelMaxThinkingTokens: z.number().optional(), +}) + +// Provider-specific schemas +const anthropicSchema = z.object({ apiModelId: z.string().optional(), apiKey: z.string().optional(), anthropicBaseUrl: z.string().optional(), anthropicUseAuthToken: z.boolean().optional(), - // Glama +}) + +const glamaSchema = z.object({ glamaModelId: z.string().optional(), glamaApiKey: z.string().optional(), - // OpenRouter +}) + +const openRouterSchema = z.object({ openRouterApiKey: z.string().optional(), openRouterModelId: z.string().optional(), openRouterBaseUrl: z.string().optional(), openRouterSpecificProvider: z.string().optional(), openRouterUseMiddleOutTransform: z.boolean().optional(), - // Amazon Bedrock +}) + +const bedrockSchema = z.object({ awsAccessKey: z.string().optional(), awsSecretKey: z.string().optional(), awsSessionToken: z.string().optional(), @@ -371,12 +390,16 @@ export const providerSettingsSchema = z.object({ awsProfile: z.string().optional(), awsUseProfile: z.boolean().optional(), awsCustomArn: z.string().optional(), - // Google Vertex +}) + +const vertexSchema = z.object({ vertexKeyFile: z.string().optional(), vertexJsonCredentials: z.string().optional(), vertexProjectId: z.string().optional(), vertexRegion: z.string().optional(), - // OpenAI +}) + +const openAiSchema = z.object({ openAiBaseUrl: z.string().optional(), openAiApiKey: z.string().optional(), openAiLegacyFormat: z.boolean().optional(), @@ -389,10 +412,14 @@ export const providerSettingsSchema = z.object({ enableReasoningEffort: z.boolean().optional(), openAiHostHeader: z.string().optional(), // Keep temporarily for backward compatibility during migration openAiHeaders: z.record(z.string(), z.string()).optional(), - // Ollama +}) + +const ollamaSchema = z.object({ ollamaModelId: z.string().optional(), ollamaBaseUrl: z.string().optional(), - // VS Code LM +}) + +const vsCodeLmSchema = z.object({ vsCodeLmModelSelector: z .object({ vendor: z.string().optional(), @@ -401,54 +428,213 @@ export const providerSettingsSchema = z.object({ id: z.string().optional(), }) .optional(), - // LM Studio +}) + +const lmStudioSchema = z.object({ lmStudioModelId: z.string().optional(), lmStudioBaseUrl: z.string().optional(), lmStudioDraftModelId: z.string().optional(), lmStudioSpeculativeDecodingEnabled: z.boolean().optional(), - // Gemini +}) + +const geminiSchema = z.object({ geminiApiKey: z.string().optional(), googleGeminiBaseUrl: z.string().optional(), - // OpenAI Native +}) + +const openAiNativeSchema = z.object({ openAiNativeApiKey: z.string().optional(), openAiNativeBaseUrl: z.string().optional(), - // Mistral +}) + +const mistralSchema = z.object({ mistralApiKey: z.string().optional(), mistralCodestralUrl: z.string().optional(), - // DeepSeek +}) + +const deepSeekSchema = z.object({ deepSeekBaseUrl: z.string().optional(), deepSeekApiKey: z.string().optional(), - // Unbound +}) + +const unboundSchema = z.object({ unboundApiKey: z.string().optional(), unboundModelId: z.string().optional(), - // Requesty +}) + +const requestySchema = z.object({ requestyApiKey: z.string().optional(), requestyModelId: z.string().optional(), - // X.AI (Grok) +}) + +const humanRelaySchema = z.object({}) + +const fakeAiSchema = z.object({ + fakeAi: z.unknown().optional(), +}) + +const xaiSchema = z.object({ xaiApiKey: z.string().optional(), - // Groq +}) + +const groqSchema = z.object({ groqApiKey: z.string().optional(), - // Chutes AI +}) + +const chutesSchema = z.object({ chutesApiKey: z.string().optional(), - // LiteLLM +}) + +const litellmSchema = z.object({ litellmBaseUrl: z.string().optional(), litellmApiKey: z.string().optional(), litellmModelId: z.string().optional(), - // Claude 3.7 Sonnet Thinking - modelMaxTokens: z.number().optional(), - modelMaxThinkingTokens: z.number().optional(), - // Generic - includeMaxTokens: z.boolean().optional(), - reasoningEffort: reasoningEffortsSchema.optional(), - promptCachingDisabled: z.boolean().optional(), - diffEnabled: z.boolean().optional(), - fuzzyMatchThreshold: z.number().optional(), - modelTemperature: z.number().nullish(), - rateLimitSeconds: z.number().optional(), - // Fake AI - fakeAi: z.unknown().optional(), }) +// Default schema for when apiProvider is not specified +const defaultSchema = z.object({ + apiProvider: z.undefined(), +}) + +// Create the discriminated union +export const providerSettingsSchemaDiscriminated = z + .discriminatedUnion("apiProvider", [ + anthropicSchema.merge( + z.object({ + apiProvider: z.literal("anthropic"), + }), + ), + glamaSchema.merge( + z.object({ + apiProvider: z.literal("glama"), + }), + ), + openRouterSchema.merge( + z.object({ + apiProvider: z.literal("openrouter"), + }), + ), + bedrockSchema.merge( + z.object({ + apiProvider: z.literal("bedrock"), + }), + ), + vertexSchema.merge( + z.object({ + apiProvider: z.literal("vertex"), + }), + ), + openAiSchema.merge( + z.object({ + apiProvider: z.literal("openai"), + }), + ), + ollamaSchema.merge( + z.object({ + apiProvider: z.literal("ollama"), + }), + ), + vsCodeLmSchema.merge( + z.object({ + apiProvider: z.literal("vscode-lm"), + }), + ), + lmStudioSchema.merge( + z.object({ + apiProvider: z.literal("lmstudio"), + }), + ), + geminiSchema.merge( + z.object({ + apiProvider: z.literal("gemini"), + }), + ), + openAiNativeSchema.merge( + z.object({ + apiProvider: z.literal("openai-native"), + }), + ), + mistralSchema.merge( + z.object({ + apiProvider: z.literal("mistral"), + }), + ), + deepSeekSchema.merge( + z.object({ + apiProvider: z.literal("deepseek"), + }), + ), + unboundSchema.merge( + z.object({ + apiProvider: z.literal("unbound"), + }), + ), + requestySchema.merge( + z.object({ + apiProvider: z.literal("requesty"), + }), + ), + humanRelaySchema.merge( + z.object({ + apiProvider: z.literal("human-relay"), + }), + ), + fakeAiSchema.merge( + z.object({ + apiProvider: z.literal("fake-ai"), + }), + ), + xaiSchema.merge( + z.object({ + apiProvider: z.literal("xai"), + }), + ), + groqSchema.merge( + z.object({ + apiProvider: z.literal("groq"), + }), + ), + chutesSchema.merge( + z.object({ + apiProvider: z.literal("chutes"), + }), + ), + litellmSchema.merge( + z.object({ + apiProvider: z.literal("litellm"), + }), + ), + defaultSchema, + ]) + .and(genericProviderSettingsSchema) + +export const providerSettingsSchema = z + .object({ + apiProvider: providerNamesSchema.optional(), + }) + .merge(anthropicSchema) + .merge(glamaSchema) + .merge(openRouterSchema) + .merge(bedrockSchema) + .merge(vertexSchema) + .merge(openAiSchema) + .merge(ollamaSchema) + .merge(vsCodeLmSchema) + .merge(lmStudioSchema) + .merge(geminiSchema) + .merge(openAiNativeSchema) + .merge(mistralSchema) + .merge(deepSeekSchema) + .merge(unboundSchema) + .merge(requestySchema) + .merge(humanRelaySchema) + .merge(fakeAiSchema) + .merge(xaiSchema) + .merge(groqSchema) + .merge(chutesSchema) + .merge(litellmSchema) + .merge(genericProviderSettingsSchema) + export type ProviderSettings = z.infer type ProviderSettingsRecord = Record, undefined> From 9a6f783dc7eb20d70a514112c96b6f96c8d9565d Mon Sep 17 00:00:00 2001 From: John Richmond <5629+jr@users.noreply.github.com> Date: Thu, 8 May 2025 14:42:52 -0700 Subject: [PATCH 2/2] Also filter out leaked properties on export --- src/core/config/ProviderSettingsManager.ts | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/src/core/config/ProviderSettingsManager.ts b/src/core/config/ProviderSettingsManager.ts index 897ba2386e..d0cc1a90da 100644 --- a/src/core/config/ProviderSettingsManager.ts +++ b/src/core/config/ProviderSettingsManager.ts @@ -6,6 +6,9 @@ import { Mode, modes } from "../../shared/modes" import { telemetryService } from "../../services/telemetry/TelemetryService" const providerSettingsWithIdSchema = providerSettingsSchema.extend({ id: z.string().optional() }) +const discriminatedProviderSettingsWithIdSchema = providerSettingsSchemaDiscriminated.and( + z.object({ id: z.string().optional() }), +) type ProviderSettingsWithId = z.infer @@ -385,7 +388,15 @@ export class ProviderSettingsManager { public async export() { try { - return await this.lock(async () => providerProfilesSchema.parse(await this.load())) + return await this.lock(async () => { + const profiles = providerProfilesSchema.parse(await this.load()) + const configs = profiles.apiConfigs + for (const name in configs) { + // Avoid leaking properties from other providers. + configs[name] = discriminatedProviderSettingsWithIdSchema.parse(configs[name]) + } + return profiles + }) } catch (error) { throw new Error(`Failed to export provider profiles: ${error}`) }