diff --git a/src/constants.context.ts b/src/constants.context.ts index db977c21aa6e1..d495f55b62b15 100644 --- a/src/constants.context.ts +++ b/src/constants.context.ts @@ -3,6 +3,7 @@ import type { AnnotationStatus, Keys } from './constants'; import type { SubscriptionState } from './constants.subscription'; import type { CustomEditorTypes, GroupableTreeViewTypes, WebviewTypes, WebviewViewTypes } from './constants.views'; import type { Features } from './features'; +import type { OrgAIProviders } from './plus/gk/models/organization'; import type { PromoKeys } from './plus/gk/models/promo'; import type { SubscriptionPlanIds } from './plus/gk/models/subscription'; import type { WalkthroughContextKeys } from './telemetry/walkthroughStateProvider'; @@ -14,6 +15,8 @@ export type ContextKeys = { 'gitlens:enabled': boolean; 'gitlens:gk:hasOrganizations': boolean; 'gitlens:gk:organization:ai:enabled': boolean; + 'gitlens:gk:organization:ai:enforceProviders': boolean; + 'gitlens:gk:organization:ai:providers': OrgAIProviders; 'gitlens:gk:organization:drafts:byob': boolean; 'gitlens:gk:organization:drafts:enabled': boolean; 'gitlens:hasVirtualFolders': boolean; diff --git a/src/constants.storage.ts b/src/constants.storage.ts index 24e86f77986ad..a682a5dfa34ea 100644 --- a/src/constants.storage.ts +++ b/src/constants.storage.ts @@ -7,6 +7,7 @@ import type { GroupableTreeViewTypes } from './constants.views'; import type { Environment } from './container'; import type { FeaturePreviews } from './features'; import type { GitRevisionRangeNotation } from './git/models/revision'; +import type { OrganizationSettings } from './plus/gk/models/organization'; import type { PaidSubscriptionPlanIds, Subscription } from './plus/gk/models/subscription'; import type { IntegrationConnectedKey } from './plus/integrations/models/integration'; import type { DeepLinkServiceState } from './uris/deepLinks/deepLink'; @@ -89,6 +90,10 @@ export type GlobalStorage = { 'views:scm:grouped:welcome:dismissed': boolean; 'integrations:configured': StoredIntegrationConfigurations; } & { [key in `plus:preview:${FeaturePreviews}:usages`]: StoredFeaturePreviewUsagePeriod[] } & { + [key in `plus:organization:${string}:settings`]: Stored< + (OrganizationSettings & { lastValidatedAt: number }) | undefined + >; +} & { [key in `provider:authentication:skip:${string}`]: boolean; } & { [key in `gk:${string}:checkin`]: Stored } & { [key in `gk:${string}:organizations`]: Stored; diff --git a/src/plus/ai/aiProviderService.ts b/src/plus/ai/aiProviderService.ts index 4fa9291268104..3af8a5951fe4c 100644 --- a/src/plus/ai/aiProviderService.ts +++ b/src/plus/ai/aiProviderService.ts @@ -66,7 +66,7 @@ import type { PromptTemplateType, } from './models/promptTemplates'; import type { AIChatMessage, AIProvider, AIRequestResult } from './models/provider'; -import { ensureAccess } from './utils/-webview/ai.utils'; +import { ensureAccess, getOrgAIConfig, isProviderEnabledByOrg } from './utils/-webview/ai.utils'; import { getLocalPromptTemplate, resolvePrompt } from './utils/-webview/prompt.utils'; export interface AIResult { @@ -325,12 +325,16 @@ export class AIProviderService implements Disposable { let chosenProviderId: AIProviders | undefined; let chosenModel: AIModel | undefined; + const orgAiConf = getOrgAIConfig(); if (!options?.force) { const vsCodeModels = await this.getModels('vscode'); - if (vsCodeModels.length !== 0) { + if (isProviderEnabledByOrg('vscode', orgAiConf) && vsCodeModels.length !== 0) { chosenProviderId = 'vscode'; - } else if ((await this.container.subscription.getSubscription()).account?.verified) { + } else if ( + isProviderEnabledByOrg('gitkraken', orgAiConf) && + (await this.container.subscription.getSubscription()).account?.verified + ) { chosenProviderId = 'gitkraken'; const gitkrakenModels = await this.getModels('gitkraken'); chosenModel = gitkrakenModels.find(m => m.default); @@ -379,9 +383,10 @@ export class AIProviderService implements Disposable { } async getProvidersConfiguration(): Promise> { + const orgAiConfig = getOrgAIConfig(); const promises = await Promise.allSettled( map( - supportedAIProviders.values(), + [...supportedAIProviders.values()].filter(p => isProviderEnabledByOrg(p.id, orgAiConfig)), async p => [ p.id, @@ -421,6 +426,12 @@ export class AIProviderService implements Disposable { providerId = model.provider.id; } + if (providerId && !isProviderEnabledByOrg(providerId)) { + this._provider = undefined; + this._model = undefined; + return undefined; + } + let changed = false; if (providerId !== this._provider?.id) { diff --git a/src/plus/ai/azureProvider.ts b/src/plus/ai/azureProvider.ts index 3cb575e66bd7f..09a372f820b0d 100644 --- a/src/plus/ai/azureProvider.ts +++ b/src/plus/ai/azureProvider.ts @@ -5,7 +5,7 @@ import { configuration } from '../../system/-webview/configuration'; import type { AIActionType, AIModel } from './models/model'; import { openAIModels } from './models/model'; import { OpenAICompatibleProviderBase } from './openAICompatibleProviderBase'; -import { isAzureUrl } from './utils/-webview/ai.utils'; +import { ensureOrgConfiguredUrl, getOrgAIProviderOfType, isAzureUrl } from './utils/-webview/ai.utils'; type AzureModel = AIModel; const models: AzureModel[] = openAIModels(provider); @@ -24,10 +24,14 @@ export class AzureProvider extends OpenAICompatibleProviderBase): string | undefined { - return configuration.get('ai.azure.url') ?? undefined; + return ensureOrgConfiguredUrl(this.id, configuration.get('ai.azure.url')); } private async getOrPromptBaseUrl(silent: boolean, hasApiKey: boolean): Promise { + const orgConf = getOrgAIProviderOfType(this.id); + if (!orgConf.enabled) return undefined; + if (orgConf.url) return orgConf.url; + let url: string | undefined = this.getUrl(); if (silent || (url != null && hasApiKey)) return url; diff --git a/src/plus/ai/ollamaProvider.ts b/src/plus/ai/ollamaProvider.ts index 7c02e35c44c51..d3ae6ca23d3ab 100644 --- a/src/plus/ai/ollamaProvider.ts +++ b/src/plus/ai/ollamaProvider.ts @@ -5,7 +5,7 @@ import { configuration } from '../../system/-webview/configuration'; import type { AIActionType, AIModel } from './models/model'; import type { AIChatMessage, AIRequestResult } from './models/provider'; import { OpenAICompatibleProviderBase } from './openAICompatibleProviderBase'; -import { ensureAccount } from './utils/-webview/ai.utils'; +import { ensureAccount, ensureOrgConfiguredUrl, getOrgAIProviderOfType } from './utils/-webview/ai.utils'; type OllamaModel = AIModel; @@ -20,8 +20,12 @@ export class OllamaProvider extends OpenAICompatibleProviderBase { + const url = await this.getOrPromptBaseUrl(silent); + if (url === undefined) { + return false; + } // Ollama doesn't require an API key, but we'll check if the base URL is reachable - return this.validateUrl(await this.getOrPromptBaseUrl(silent), silent); + return this.validateUrl(url, silent); } override async getApiKey(silent: boolean): Promise { @@ -77,7 +81,11 @@ export class OllamaProvider extends OpenAICompatibleProviderBase { + private async getOrPromptBaseUrl(silent: boolean): Promise { + const orgConf = getOrgAIProviderOfType(this.id); + if (!orgConf.enabled) return undefined; + if (orgConf.url) return orgConf.url; + let url = configuration.get('ai.ollama.url') ?? undefined; if (url) { if (silent) return url; @@ -169,13 +177,14 @@ export class OllamaProvider extends OpenAICompatibleProviderBase): string { - return `${this.getBaseUrl()}/api/chat`; + protected getUrl(_model: AIModel): string | undefined { + const url = this.getBaseUrl(); + return url ? `${url}/api/chat` : undefined; } protected override getHeaders( diff --git a/src/plus/ai/openAICompatibleProvider.ts b/src/plus/ai/openAICompatibleProvider.ts index b11bb4b0ee458..3b18fad85f358 100644 --- a/src/plus/ai/openAICompatibleProvider.ts +++ b/src/plus/ai/openAICompatibleProvider.ts @@ -5,7 +5,7 @@ import { configuration } from '../../system/-webview/configuration'; import type { AIModel } from './models/model'; import { openAIModels } from './models/model'; import { OpenAICompatibleProviderBase } from './openAICompatibleProviderBase'; -import { isAzureUrl } from './utils/-webview/ai.utils'; +import { ensureOrgConfiguredUrl, getOrgAIProviderOfType, isAzureUrl } from './utils/-webview/ai.utils'; type OpenAICompatibleModel = AIModel; const models: OpenAICompatibleModel[] = openAIModels(provider); @@ -24,10 +24,14 @@ export class OpenAICompatibleProvider extends OpenAICompatibleProviderBase): string | undefined { - return configuration.get('ai.openaicompatible.url') ?? undefined; + return ensureOrgConfiguredUrl(this.id, configuration.get('ai.openaicompatible.url')); } private async getOrPromptBaseUrl(silent: boolean, hasApiKey: boolean): Promise { + const orgConf = getOrgAIProviderOfType(this.id); + if (!orgConf.enabled) return undefined; + if (orgConf.url) return orgConf.url; + let url: string | undefined = this.getUrl(); if (silent || (url != null && hasApiKey)) return url; diff --git a/src/plus/ai/openAICompatibleProviderBase.ts b/src/plus/ai/openAICompatibleProviderBase.ts index 0197de3aca0a0..79ba510e1c9d7 100644 --- a/src/plus/ai/openAICompatibleProviderBase.ts +++ b/src/plus/ai/openAICompatibleProviderBase.ts @@ -10,7 +10,12 @@ import { startLogScope } from '../../system/logger.scope'; import type { ServerConnection } from '../gk/serverConnection'; import type { AIActionType, AIModel, AIProviderDescriptor } from './models/model'; import type { AIChatMessage, AIChatMessageRole, AIProvider, AIRequestResult } from './models/provider'; -import { getActionName, getOrPromptApiKey, getValidatedTemperature } from './utils/-webview/ai.utils'; +import { + getActionName, + getOrgAIProviderOfType, + getOrPromptApiKey, + getValidatedTemperature, +} from './utils/-webview/ai.utils'; export interface AIProviderConfig { url: string; @@ -36,6 +41,10 @@ export abstract class OpenAICompatibleProviderBase implem } async getApiKey(silent: boolean): Promise { + const orgConf = getOrgAIProviderOfType(this.id); + if (!orgConf.enabled) return undefined; + if (orgConf.key) return orgConf.key; + const { keyUrl, keyValidator } = this.config; return getOrPromptApiKey( diff --git a/src/plus/ai/openaiProvider.ts b/src/plus/ai/openaiProvider.ts index cc55d810b450f..6aa6cf4b3478d 100644 --- a/src/plus/ai/openaiProvider.ts +++ b/src/plus/ai/openaiProvider.ts @@ -3,7 +3,7 @@ import { configuration } from '../../system/-webview/configuration'; import type { AIActionType, AIModel } from './models/model'; import { openAIModels } from './models/model'; import { OpenAICompatibleProviderBase } from './openAICompatibleProviderBase'; -import { isAzureUrl } from './utils/-webview/ai.utils'; +import { ensureOrgConfiguredUrl, isAzureUrl } from './utils/-webview/ai.utils'; type OpenAIModel = AIModel; const models: OpenAIModel[] = openAIModels(provider); @@ -22,7 +22,10 @@ export class OpenAIProvider extends OpenAICompatibleProviderBase): string { - return configuration.get('ai.openai.url') || 'https://api.openai.com/v1/chat/completions'; + return ( + ensureOrgConfiguredUrl(this.id, configuration.get('ai.openai.url')) || + 'https://api.openai.com/v1/chat/completions' + ); } protected override getHeaders( diff --git a/src/plus/ai/utils/-webview/ai.utils.ts b/src/plus/ai/utils/-webview/ai.utils.ts index 4857e63f9f686..a9f0ca74720b6 100644 --- a/src/plus/ai/utils/-webview/ai.utils.ts +++ b/src/plus/ai/utils/-webview/ai.utils.ts @@ -8,6 +8,7 @@ import { getContext } from '../../../../system/-webview/context'; import { openSettingsEditor } from '../../../../system/-webview/vscode/editors'; import { formatNumeric } from '../../../../system/date'; import { getPossessiveForm, pluralize } from '../../../../system/string'; +import type { OrgAIConfig, OrgAIProvider } from '../../../gk/models/organization'; import { ensureAccountQuickPick } from '../../../gk/utils/-webview/acount.utils'; import type { AIActionType, AIModel } from '../../models/model'; @@ -170,6 +171,35 @@ export function isAzureUrl(url: string): boolean { return url.includes('.azure.com'); } +export function getOrgAIConfig(): OrgAIConfig { + return { + aiEnabled: getContext('gitlens:gk:organization:ai:enabled', true), + enforceAiProviders: getContext('gitlens:gk:organization:ai:enforceProviders', false), + aiProviders: getContext('gitlens:gk:organization:ai:providers', {}), + }; +} + +export function getOrgAIProviderOfType(type: AIProviders, orgAiConfig?: OrgAIConfig): OrgAIProvider { + orgAiConfig ??= getOrgAIConfig(); + if (!orgAiConfig.aiEnabled) return { type: type, enabled: false }; + if (!orgAiConfig.enforceAiProviders) return { type: type, enabled: true }; + return orgAiConfig.aiProviders[type] ?? { type: type, enabled: false }; +} + +export function isProviderEnabledByOrg(type: AIProviders, orgAiConfig?: OrgAIConfig): boolean { + return getOrgAIProviderOfType(type, orgAiConfig).enabled; +} + +/** + * If the input value (userUrl) matches to the org configuration it returns it. + */ +export function ensureOrgConfiguredUrl(type: AIProviders, userUrl: null | undefined | string): string | undefined { + const provider = getOrgAIProviderOfType(type); + if (!provider.enabled) return undefined; + + return provider.url || userUrl || undefined; +} + export async function ensureAccess(options?: { showPicker?: boolean }): Promise { const showPicker = options?.showPicker ?? false; diff --git a/src/plus/gk/models/organization.ts b/src/plus/gk/models/organization.ts index 7507af9a73ff4..dcc7e2e234c52 100644 --- a/src/plus/gk/models/organization.ts +++ b/src/plus/gk/models/organization.ts @@ -1,3 +1,5 @@ +import type { AIProviders } from '../../../constants.ai'; + export interface Organization { readonly id: string; readonly name: string; @@ -20,7 +22,10 @@ export interface OrganizationMember { } export interface OrganizationSettings { + aiEnabled: boolean; + enforceAiProviders: boolean; aiSettings: OrganizationSetting; + aiProviders: GkDevAIProviders; draftsSettings: OrganizationDraftsSettings; } @@ -39,3 +44,74 @@ export interface OrganizationDraftsSettings extends OrganizationSetting { } | undefined; } + +export type GkDevAIProviders = Partial>; + +export interface GkDevAIProvider { + enabled: boolean; + url?: string; + key?: string; +} + +export interface OrgAIProvider { + readonly type: AIProviders; + readonly enabled: boolean; + readonly url?: string; + readonly key?: string; +} + +export type OrgAIProviders = Partial>; +export interface OrgAIConfig { + readonly aiEnabled: boolean; + readonly enforceAiProviders: boolean; + readonly aiProviders: OrgAIProviders; +} + +export type GkDevAIProviderType = 'anthropic' | 'azure' | 'gitkraken_ai' | 'openai' | 'openai_compatible'; + +export function fromGkDevAIProviderType(type: GkDevAIProviderType): AIProviders; +export function fromGkDevAIProviderType(type: Exclude): never; +export function fromGkDevAIProviderType(type: unknown): AIProviders | never { + switch (type) { + case 'anthropic': + return 'anthropic'; + case 'azure': + return 'azure'; + case 'gitkraken_ai': + return 'gitkraken'; + case 'openai': + return 'openai'; + case 'openai_compatible': + return 'openaicompatible'; + case 'ollama': + return 'ollama'; + default: + throw new Error(`Unknown AI provider type: ${String(type)}`); + } +} + +function fromGkDevAIProvider(type: GkDevAIProviderType, provider: GkDevAIProvider): OrgAIProvider { + return { + type: fromGkDevAIProviderType(type), + enabled: provider.enabled, + url: provider.url, + key: provider.key, + }; +} + +export function fromGKDevAIProviders(providers?: GkDevAIProviders): OrgAIProviders { + const result: OrgAIProviders = {}; + if (providers == null) return result; + + Object.entries(providers).forEach(([type, provider]) => { + try { + result[fromGkDevAIProviderType(type as GkDevAIProviderType)] = fromGkDevAIProvider( + type as GkDevAIProviderType, + provider, + ); + } catch { + // ignore invalid provider, continue with others + } + }); + return result; +} diff --git a/src/plus/gk/organizationService.ts b/src/plus/gk/organizationService.ts index 24e386b3e54d5..9e5547eb22a44 100644 --- a/src/plus/gk/organizationService.ts +++ b/src/plus/gk/organizationService.ts @@ -11,6 +11,7 @@ import type { OrganizationSettings, OrganizationsResponse, } from './models/organization'; +import { fromGKDevAIProviders } from './models/organization'; import type { ServerConnection } from './serverConnection'; import type { SubscriptionChangeEvent } from './subscriptionService'; @@ -19,7 +20,9 @@ const organizationsCacheExpiration = 24 * 60 * 60 * 1000; // 1 day export class OrganizationService implements Disposable { private _disposable: Disposable; private _organizations: Organization[] | null | undefined; - private _organizationSettings: Map | undefined; + private _organizationSettings: + | Map + | undefined; private _organizationMembers: Map | undefined; constructor( @@ -125,11 +128,12 @@ export class OrganizationService implements Disposable { }); } - private onSubscriptionChanged(e: SubscriptionChangeEvent): void { + private async onSubscriptionChanged(e: SubscriptionChangeEvent): Promise { if (e.current?.account?.id == null) { this.updateOrganizations(undefined); } - void this.updateOrganizationPermissions(e.current?.activeOrganization?.id); + await this.clearAllStoredOrganizationsSettings(); + await this.updateOrganizationPermissions(e.current?.activeOrganization?.id); } private updateOrganizations(organizations: Organization[] | null | undefined): void { @@ -139,8 +143,25 @@ export class OrganizationService implements Disposable { private async updateOrganizationPermissions(orgId: string | undefined): Promise { const settings = orgId != null ? await this.getOrganizationSettings(orgId) : undefined; + let aiProviders; + try { + aiProviders = fromGKDevAIProviders(settings?.aiProviders); + } catch { + aiProviders = {}; + if (settings) { + settings.enforceAiProviders = false; + } + } + + const enforceAiProviders = settings?.enforceAiProviders ?? false; + const disabledByEnforcing = enforceAiProviders && !Object.values(aiProviders).some(p => p.enabled); - void setContext('gitlens:gk:organization:ai:enabled', settings?.aiSettings.enabled ?? true); + void setContext( + 'gitlens:gk:organization:ai:enabled', + (!disabledByEnforcing && settings?.aiSettings.enabled) ?? settings?.aiEnabled ?? true, + ); + void setContext('gitlens:gk:organization:ai:enforceProviders', enforceAiProviders); + void setContext('gitlens:gk:organization:ai:providers', aiProviders); void setContext('gitlens:gk:organization:drafts:byob', settings?.draftsSettings.bucket != null); void setContext('gitlens:gk:organization:drafts:enabled', settings?.draftsSettings.enabled ?? true); } @@ -202,7 +223,23 @@ export class OrganizationService implements Disposable { const id = orgId ?? (await this.getActiveOrganizationId()); if (id == null) return undefined; + if (!options?.force && !this._organizationSettings?.has(id)) { + const cachedOrg = this.getStoredOrganizationSettings(id); + if (cachedOrg) { + this._organizationSettings ??= new Map(); + this._organizationSettings.set(id, cachedOrg); + } + } + + if (this._organizationSettings?.has(id)) { + const org = this._organizationSettings.get(id); + if (org && Date.now() - org.lastValidatedDate.getTime() > organizationsCacheExpiration) { + this._organizationSettings.delete(id); + } + } + if (!this._organizationSettings?.has(id) || options?.force === true) { + await this.deleteStoredOrganizationSettings(id); const rsp = await this.connection.fetchGkApi( `v1/organizations/settings`, { method: 'GET' }, @@ -230,9 +267,43 @@ export class OrganizationService implements Disposable { if (this._organizationSettings == null) { this._organizationSettings = new Map(); } - this._organizationSettings.set(id, organizationResponse.data); + this._organizationSettings.set(id, { data: organizationResponse.data, lastValidatedDate: new Date() }); + await this.storeOrganizationSettings(id, organizationResponse.data, new Date()); } - return this._organizationSettings.get(id); + return this._organizationSettings.get(id)?.data; + } + + private async clearAllStoredOrganizationsSettings(): Promise { + return this.container.storage.deleteWithPrefix(`plus:organization:`); + } + + private async deleteStoredOrganizationSettings(id: string): Promise { + return this.container.storage.delete(`plus:organization:${id}:settings`); + } + + private getStoredOrganizationSettings( + id: string, + ): { data: OrganizationSettings; lastValidatedDate: Date } | undefined { + const result = this.container.storage.get(`plus:organization:${id}:settings`); + if (!result?.data) return undefined; + + const { lastValidatedAt, ...organizationSettings } = result.data; + + return { + data: organizationSettings, + lastValidatedDate: new Date(lastValidatedAt), + }; + } + + private async storeOrganizationSettings( + id: string, + settings: OrganizationSettings, + lastValidatedDate: Date, + ): Promise { + return this.container.storage.store(`plus:organization:${id}:settings`, { + v: 1, + data: { ...settings, lastValidatedAt: lastValidatedDate.getTime() }, + }); } }