diff --git a/packages/cloud/src/CloudService.ts b/packages/cloud/src/CloudService.ts index 32ea443cd6f..9a32a16fcb1 100644 --- a/packages/cloud/src/CloudService.ts +++ b/packages/cloud/src/CloudService.ts @@ -4,6 +4,7 @@ import type { CloudUserInfo, TelemetryEvent, OrganizationAllowList, + OrganizationSettings, ClineMessage, ShareVisibility, } from "@roo-code/types" @@ -174,6 +175,11 @@ export class CloudService { return this.settingsService!.getAllowList() } + public getOrganizationSettings(): OrganizationSettings | undefined { + this.ensureInitialized() + return this.settingsService!.getSettings() + } + // TelemetryClient public captureEvent(event: TelemetryEvent): void { diff --git a/packages/types/src/__tests__/cloud-schema.test.ts b/packages/types/src/__tests__/cloud-schema.test.ts new file mode 100644 index 00000000000..7c75b247bf9 --- /dev/null +++ b/packages/types/src/__tests__/cloud-schema.test.ts @@ -0,0 +1,75 @@ +import { describe, it, expect } from "vitest" +import { organizationSettingsSchema } from "../cloud.js" + +describe("organizationSettingsSchema", () => { + it("should accept valid organization settings with defaultProviderSettings", () => { + const validSettings = { + version: 1, + defaultSettings: {}, + allowList: { + allowAll: false, + providers: { + anthropic: { + allowAll: true, + models: [], + }, + }, + }, + defaultProviderSettings: { + anthropic: { + apiProvider: "anthropic" as const, + apiKey: "test-key", + apiModelId: "claude-3-5-sonnet-20241022", + }, + openai: { + apiProvider: "openai" as const, + openAiApiKey: "test-key", + openAiModelId: "gpt-4", + }, + }, + } + + const result = organizationSettingsSchema.safeParse(validSettings) + expect(result.success).toBe(true) + if (result.success) { + expect(result.data.defaultProviderSettings).toEqual(validSettings.defaultProviderSettings) + } + }) + + it("should accept organization settings without defaultProviderSettings", () => { + const validSettings = { + version: 1, + defaultSettings: {}, + allowList: { + allowAll: true, + providers: {}, + }, + } + + const result = organizationSettingsSchema.safeParse(validSettings) + expect(result.success).toBe(true) + if (result.success) { + expect(result.data.defaultProviderSettings).toBeUndefined() + } + }) + + it("should reject invalid provider names in defaultProviderSettings", () => { + const invalidSettings = { + version: 1, + defaultSettings: {}, + allowList: { + allowAll: true, + providers: {}, + }, + defaultProviderSettings: { + "invalid-provider": { + apiProvider: "invalid-provider", + apiKey: "test-key", + }, + }, + } + + const result = organizationSettingsSchema.safeParse(invalidSettings) + expect(result.success).toBe(false) + }) +}) diff --git a/packages/types/src/cloud.ts b/packages/types/src/cloud.ts index 6df7292dd59..76052059a8f 100644 --- a/packages/types/src/cloud.ts +++ b/packages/types/src/cloud.ts @@ -1,6 +1,7 @@ import { z } from "zod" import { globalSettingsSchema } from "./global-settings.js" +import { providerNamesSchema, providerSettingsSchemaDiscriminated } from "./provider-settings.js" /** * CloudUserInfo @@ -110,6 +111,7 @@ export const organizationSettingsSchema = z.object({ cloudSettings: organizationCloudSettingsSchema.optional(), defaultSettings: organizationDefaultSettingsSchema, allowList: organizationAllowListSchema, + defaultProviderSettings: z.record(providerNamesSchema, providerSettingsSchemaDiscriminated).optional(), }) export type OrganizationSettings = z.infer @@ -133,6 +135,7 @@ export const ORGANIZATION_DEFAULT: OrganizationSettings = { }, defaultSettings: {}, allowList: ORGANIZATION_ALLOW_ALL, + defaultProviderSettings: {}, } as const /** diff --git a/src/core/webview/ClineProvider.ts b/src/core/webview/ClineProvider.ts index 905e657b37e..e1b8ca7607d 100644 --- a/src/core/webview/ClineProvider.ts +++ b/src/core/webview/ClineProvider.ts @@ -1453,6 +1453,17 @@ export class ClineProvider const currentMode = mode ?? defaultModeSlug const hasSystemPromptOverride = await this.hasFileBasedSystemPromptOverride(currentMode) + // Get organization settings including default provider settings + let organizationDefaultProviderSettings: Record = {} + try { + const orgSettings = await CloudService.instance.getOrganizationSettings() + organizationDefaultProviderSettings = orgSettings?.defaultProviderSettings || {} + } catch (error) { + console.error( + `[getStateToPostToWebview] failed to get organization settings: ${error instanceof Error ? error.message : String(error)}`, + ) + } + return { version: this.context.extension?.packageJSON?.version ?? "", apiConfiguration, @@ -1541,6 +1552,7 @@ export class ClineProvider cloudIsAuthenticated: cloudIsAuthenticated ?? false, sharingEnabled: sharingEnabled ?? false, organizationAllowList, + organizationDefaultProviderSettings, condensingApiConfigId, customCondensingPrompt, codebaseIndexModels: codebaseIndexModels ?? EMBEDDING_MODEL_PROFILES, diff --git a/src/core/webview/__tests__/webviewMessageHandler-orgDefaults.spec.ts b/src/core/webview/__tests__/webviewMessageHandler-orgDefaults.spec.ts new file mode 100644 index 00000000000..9c2af5d6a28 --- /dev/null +++ b/src/core/webview/__tests__/webviewMessageHandler-orgDefaults.spec.ts @@ -0,0 +1,141 @@ +import { describe, it, expect, vi, beforeEach } from "vitest" +import * as vscode from "vscode" +import { CloudService } from "@roo-code/cloud" +import { webviewMessageHandler } from "../webviewMessageHandler" +import { ClineProvider } from "../ClineProvider" +import { ProviderSettings } from "@roo-code/types" + +// Mock CloudService +vi.mock("@roo-code/cloud", () => ({ + CloudService: { + instance: { + getOrganizationSettings: vi.fn(), + }, + }, +})) + +describe("webviewMessageHandler - Organization Defaults", () => { + let mockProvider: any + let mockMarketplaceManager: any + + beforeEach(() => { + // Reset mocks + vi.clearAllMocks() + + // Create mock provider + mockProvider = { + log: vi.fn(), + upsertProviderProfile: vi.fn(), + postMessageToWebview: vi.fn(), + getState: vi.fn().mockResolvedValue({ + apiConfiguration: {}, + currentApiConfigName: "test-config", + }), + } + + // Create mock marketplace manager + mockMarketplaceManager = {} + }) + + it("should apply organization default settings when creating a new profile", async () => { + // Mock organization settings with defaults + const orgDefaults = { + anthropic: { + apiProvider: "anthropic" as const, + anthropicApiKey: "org-default-key", + apiModelId: "claude-3-opus-20240229", + temperature: 0.7, + }, + } + + vi.mocked(CloudService.instance.getOrganizationSettings).mockResolvedValue({ + version: 1, + defaultSettings: {}, + allowList: { allowAll: true, providers: {} }, + defaultProviderSettings: orgDefaults, + }) + + // Send upsertApiConfiguration message + const message = { + type: "upsertApiConfiguration" as const, + text: "new-profile", + apiConfiguration: { + apiProvider: "anthropic", + anthropicApiKey: "user-key", // User-provided key should take precedence + // temperature is not provided, so org default should be used + } as ProviderSettings, + } + + await webviewMessageHandler(mockProvider, message, mockMarketplaceManager) + + // Verify that upsertProviderProfile was called with merged settings + expect(mockProvider.upsertProviderProfile).toHaveBeenCalledWith("new-profile", { + apiProvider: "anthropic", + anthropicApiKey: "user-key", // User value takes precedence + apiModelId: "claude-3-opus-20240229", // From org defaults + temperature: 0.7, // From org defaults + }) + }) + + it("should handle missing organization settings gracefully", async () => { + // Mock CloudService to throw an error + vi.mocked(CloudService.instance.getOrganizationSettings).mockRejectedValue(new Error("Not authenticated")) + + // Send upsertApiConfiguration message + const message = { + type: "upsertApiConfiguration" as const, + text: "new-profile", + apiConfiguration: { + apiProvider: "anthropic", + anthropicApiKey: "user-key", + } as ProviderSettings, + } + + await webviewMessageHandler(mockProvider, message, mockMarketplaceManager) + + // Verify that error was logged + expect(mockProvider.log).toHaveBeenCalledWith(expect.stringContaining("Failed to get organization defaults")) + + // Verify that upsertProviderProfile was still called with original settings + expect(mockProvider.upsertProviderProfile).toHaveBeenCalledWith("new-profile", { + apiProvider: "anthropic", + anthropicApiKey: "user-key", + }) + }) + + it("should not apply defaults for a different provider", async () => { + // Mock organization settings with defaults for anthropic + const orgDefaults = { + anthropic: { + apiProvider: "anthropic" as const, + anthropicApiKey: "org-default-key", + apiModelId: "claude-3-opus-20240229", + }, + } + + vi.mocked(CloudService.instance.getOrganizationSettings).mockResolvedValue({ + version: 1, + defaultSettings: {}, + allowList: { allowAll: true, providers: {} }, + defaultProviderSettings: orgDefaults, + }) + + // Send upsertApiConfiguration message for openai provider + const message = { + type: "upsertApiConfiguration" as const, + text: "new-profile", + apiConfiguration: { + apiProvider: "openai", + openAiApiKey: "user-key", + } as ProviderSettings, + } + + await webviewMessageHandler(mockProvider, message, mockMarketplaceManager) + + // Verify that only the user-provided settings were used + expect(mockProvider.upsertProviderProfile).toHaveBeenCalledWith("new-profile", { + apiProvider: "openai", + openAiApiKey: "user-key", + }) + }) +}) diff --git a/src/core/webview/webviewMessageHandler.ts b/src/core/webview/webviewMessageHandler.ts index c739c2ade8d..174d5d8cb26 100644 --- a/src/core/webview/webviewMessageHandler.ts +++ b/src/core/webview/webviewMessageHandler.ts @@ -1474,7 +1474,28 @@ export const webviewMessageHandler = async ( break case "upsertApiConfiguration": if (message.text && message.apiConfiguration) { - await provider.upsertProviderProfile(message.text, message.apiConfiguration) + // Get organization default settings + let organizationDefaults: Partial = {} + try { + const orgSettings = await CloudService.instance.getOrganizationSettings() + const selectedProvider = message.apiConfiguration.apiProvider + if (orgSettings?.defaultProviderSettings && selectedProvider) { + organizationDefaults = orgSettings.defaultProviderSettings[selectedProvider] || {} + } + } catch (error) { + provider.log( + `[upsertApiConfiguration] Failed to get organization defaults: ${error instanceof Error ? error.message : String(error)}`, + ) + } + + // Merge organization defaults with the provided configuration + // User-provided values take precedence over organization defaults + const mergedConfiguration: ProviderSettings = { + ...organizationDefaults, + ...message.apiConfiguration, + } + + await provider.upsertProviderProfile(message.text, mergedConfiguration) } break case "renameApiConfiguration": diff --git a/src/shared/ExtensionMessage.ts b/src/shared/ExtensionMessage.ts index 000762e317a..89fc3dc89dd 100644 --- a/src/shared/ExtensionMessage.ts +++ b/src/shared/ExtensionMessage.ts @@ -10,6 +10,7 @@ import type { OrganizationAllowList, CloudUserInfo, ShareVisibility, + ProviderName, } from "@roo-code/types" import { GitCommit } from "../utils/git" @@ -302,6 +303,7 @@ export type ExtensionState = Pick< cloudApiUrl?: string sharingEnabled: boolean organizationAllowList: OrganizationAllowList + organizationDefaultProviderSettings?: Partial> autoCondenseContext: boolean autoCondenseContextPercent: number diff --git a/webview-ui/src/components/settings/ApiOptions.tsx b/webview-ui/src/components/settings/ApiOptions.tsx index 38d2ceebd37..308e8fe31d3 100644 --- a/webview-ui/src/components/settings/ApiOptions.tsx +++ b/webview-ui/src/components/settings/ApiOptions.tsx @@ -106,7 +106,7 @@ const ApiOptions = ({ setErrorMessage, }: ApiOptionsProps) => { const { t } = useAppTranslation() - const { organizationAllowList } = useExtensionState() + const { organizationAllowList, organizationDefaultProviderSettings } = useExtensionState() const [customHeaders, setCustomHeaders] = useState<[string, string][]>(() => { const headers = apiConfiguration?.openAiHeaders || {} @@ -246,6 +246,23 @@ const ApiOptions = ({ (value: ProviderName) => { setApiConfigurationField("apiProvider", value) + // Apply organization default settings if available + const orgDefaults = organizationDefaultProviderSettings?.[value] + + if (orgDefaults) { + // Apply each default setting from the organization + Object.entries(orgDefaults).forEach(([key, defaultValue]) => { + // Skip apiProvider as we've already set it + if (key === "apiProvider") return + + // Only apply defaults if the current value is undefined or empty + const currentValue = apiConfiguration[key as keyof ProviderSettings] + if (!currentValue || (typeof currentValue === "string" && currentValue.trim() === "")) { + setApiConfigurationField(key as keyof ProviderSettings, defaultValue) + } + }) + } + // It would be much easier to have a single attribute that stores // the modelId, but we have a separate attribute for each of // OpenRouter, Glama, Unbound, and Requesty. @@ -311,7 +328,7 @@ const ApiOptions = ({ ) } }, - [setApiConfigurationField, apiConfiguration], + [setApiConfigurationField, apiConfiguration, organizationDefaultProviderSettings], ) const modelValidationError = useMemo(() => { diff --git a/webview-ui/src/context/ExtensionStateContext.tsx b/webview-ui/src/context/ExtensionStateContext.tsx index ff1ce31c53c..31a3c71e4c3 100644 --- a/webview-ui/src/context/ExtensionStateContext.tsx +++ b/webview-ui/src/context/ExtensionStateContext.tsx @@ -8,6 +8,7 @@ import { type ExperimentId, type OrganizationAllowList, ORGANIZATION_ALLOW_ALL, + type ProviderName, } from "@roo-code/types" import { ExtensionMessage, ExtensionState, MarketplaceInstalledMetadata } from "@roo/ExtensionMessage" @@ -34,6 +35,7 @@ export interface ExtensionStateContextType extends ExtensionState { filePaths: string[] openedTabs: Array<{ label: string; isActive: boolean; path?: string }> organizationAllowList: OrganizationAllowList + organizationDefaultProviderSettings: Partial> cloudIsAuthenticated: boolean sharingEnabled: boolean maxConcurrentFileReads?: number @@ -219,6 +221,7 @@ export const ExtensionStateContextProvider: React.FC<{ children: React.ReactNode cloudIsAuthenticated: false, sharingEnabled: false, organizationAllowList: ORGANIZATION_ALLOW_ALL, + organizationDefaultProviderSettings: {}, autoCondenseContext: true, autoCondenseContextPercent: 100, profileThresholds: {}, @@ -379,6 +382,7 @@ export const ExtensionStateContextProvider: React.FC<{ children: React.ReactNode screenshotQuality: state.screenshotQuality, routerModels: extensionRouterModels, cloudIsAuthenticated: state.cloudIsAuthenticated ?? false, + organizationDefaultProviderSettings: state.organizationDefaultProviderSettings ?? {}, marketplaceItems, marketplaceInstalledMetadata, profileThresholds: state.profileThresholds ?? {},