Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/constants.context.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import type { AnnotationStatus, Keys } from './constants';
import type { SubscriptionState } from './constants.subscription';
import type { CustomEditorTypes, GroupableTreeViewTypes, WebviewTypes, WebviewViewTypes } from './constants.views';
import type { Features } from './features';
import type { OrgAIProviders } from './plus/gk/models/organization';
import type { PromoKeys } from './plus/gk/models/promo';
import type { SubscriptionPlanIds } from './plus/gk/models/subscription';
import type { WalkthroughContextKeys } from './telemetry/walkthroughStateProvider';
Expand All @@ -14,6 +15,8 @@ export type ContextKeys = {
'gitlens:enabled': boolean;
'gitlens:gk:hasOrganizations': boolean;
'gitlens:gk:organization:ai:enabled': boolean;
'gitlens:gk:organization:ai:enforceProviders': boolean;
'gitlens:gk:organization:ai:providers': OrgAIProviders;
'gitlens:gk:organization:drafts:byob': boolean;
'gitlens:gk:organization:drafts:enabled': boolean;
'gitlens:hasVirtualFolders': boolean;
Expand Down
5 changes: 5 additions & 0 deletions src/constants.storage.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import type { GroupableTreeViewTypes } from './constants.views';
import type { Environment } from './container';
import type { FeaturePreviews } from './features';
import type { GitRevisionRangeNotation } from './git/models/revision';
import type { OrganizationSettings } from './plus/gk/models/organization';
import type { PaidSubscriptionPlanIds, Subscription } from './plus/gk/models/subscription';
import type { IntegrationConnectedKey } from './plus/integrations/models/integration';
import type { DeepLinkServiceState } from './uris/deepLinks/deepLink';
Expand Down Expand Up @@ -89,6 +90,10 @@ export type GlobalStorage = {
'views:scm:grouped:welcome:dismissed': boolean;
'integrations:configured': StoredIntegrationConfigurations;
} & { [key in `plus:preview:${FeaturePreviews}:usages`]: StoredFeaturePreviewUsagePeriod[] } & {
[key in `plus:organization:${string}:settings`]: Stored<
(OrganizationSettings & { lastValidatedAt: number }) | undefined
>;
} & {
[key in `provider:authentication:skip:${string}`]: boolean;
} & { [key in `gk:${string}:checkin`]: Stored<StoredGKCheckInResponse> } & {
[key in `gk:${string}:organizations`]: Stored<StoredOrganization[]>;
Expand Down
19 changes: 15 additions & 4 deletions src/plus/ai/aiProviderService.ts
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ import type {
PromptTemplateType,
} from './models/promptTemplates';
import type { AIChatMessage, AIProvider, AIRequestResult } from './models/provider';
import { ensureAccess } from './utils/-webview/ai.utils';
import { ensureAccess, getOrgAIConfig, isProviderEnabledByOrg } from './utils/-webview/ai.utils';
import { getLocalPromptTemplate, resolvePrompt } from './utils/-webview/prompt.utils';

export interface AIResult {
Expand Down Expand Up @@ -325,12 +325,16 @@ export class AIProviderService implements Disposable {

let chosenProviderId: AIProviders | undefined;
let chosenModel: AIModel | undefined;
const orgAiConf = getOrgAIConfig();

if (!options?.force) {
const vsCodeModels = await this.getModels('vscode');
if (vsCodeModels.length !== 0) {
if (isProviderEnabledByOrg('vscode', orgAiConf) && vsCodeModels.length !== 0) {
chosenProviderId = 'vscode';
} else if ((await this.container.subscription.getSubscription()).account?.verified) {
} else if (
isProviderEnabledByOrg('gitkraken', orgAiConf) &&
(await this.container.subscription.getSubscription()).account?.verified
) {
chosenProviderId = 'gitkraken';
const gitkrakenModels = await this.getModels('gitkraken');
chosenModel = gitkrakenModels.find(m => m.default);
Expand Down Expand Up @@ -379,9 +383,10 @@ export class AIProviderService implements Disposable {
}

async getProvidersConfiguration(): Promise<Map<AIProviders, AIProviderDescriptorWithConfiguration>> {
const orgAiConfig = getOrgAIConfig();
const promises = await Promise.allSettled(
map(
supportedAIProviders.values(),
[...supportedAIProviders.values()].filter(p => isProviderEnabledByOrg(p.id, orgAiConfig)),
async p =>
[
p.id,
Expand Down Expand Up @@ -421,6 +426,12 @@ export class AIProviderService implements Disposable {
providerId = model.provider.id;
}

if (providerId && !isProviderEnabledByOrg(providerId)) {
this._provider = undefined;
this._model = undefined;
return undefined;
}

let changed = false;

if (providerId !== this._provider?.id) {
Expand Down
8 changes: 6 additions & 2 deletions src/plus/ai/azureProvider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import { configuration } from '../../system/-webview/configuration';
import type { AIActionType, AIModel } from './models/model';
import { openAIModels } from './models/model';
import { OpenAICompatibleProviderBase } from './openAICompatibleProviderBase';
import { isAzureUrl } from './utils/-webview/ai.utils';
import { ensureOrgConfiguredUrl, getOrgAIProviderOfType, isAzureUrl } from './utils/-webview/ai.utils';

type AzureModel = AIModel<typeof provider.id>;
const models: AzureModel[] = openAIModels(provider);
Expand All @@ -24,10 +24,14 @@ export class AzureProvider extends OpenAICompatibleProviderBase<typeof provider.
}

protected getUrl(_model?: AIModel<typeof provider.id>): string | undefined {
return configuration.get('ai.azure.url') ?? undefined;
return ensureOrgConfiguredUrl(this.id, configuration.get('ai.azure.url'));
}

private async getOrPromptBaseUrl(silent: boolean, hasApiKey: boolean): Promise<string | undefined> {
const orgConf = getOrgAIProviderOfType(this.id);
if (!orgConf.enabled) return undefined;
if (orgConf.url) return orgConf.url;

let url: string | undefined = this.getUrl();

if (silent || (url != null && hasApiKey)) return url;
Expand Down
23 changes: 16 additions & 7 deletions src/plus/ai/ollamaProvider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import { configuration } from '../../system/-webview/configuration';
import type { AIActionType, AIModel } from './models/model';
import type { AIChatMessage, AIRequestResult } from './models/provider';
import { OpenAICompatibleProviderBase } from './openAICompatibleProviderBase';
import { ensureAccount } from './utils/-webview/ai.utils';
import { ensureAccount, ensureOrgConfiguredUrl, getOrgAIProviderOfType } from './utils/-webview/ai.utils';

type OllamaModel = AIModel<typeof provider.id>;

Expand All @@ -20,8 +20,12 @@ export class OllamaProvider extends OpenAICompatibleProviderBase<typeof provider
};

override async configured(silent: boolean): Promise<boolean> {
const url = await this.getOrPromptBaseUrl(silent);
if (url === undefined) {
return false;
}
// Ollama doesn't require an API key, but we'll check if the base URL is reachable
return this.validateUrl(await this.getOrPromptBaseUrl(silent), silent);
return this.validateUrl(url, silent);
}

override async getApiKey(silent: boolean): Promise<string | undefined> {
Expand Down Expand Up @@ -77,7 +81,11 @@ export class OllamaProvider extends OpenAICompatibleProviderBase<typeof provider
return [];
}

private async getOrPromptBaseUrl(silent: boolean): Promise<string> {
private async getOrPromptBaseUrl(silent: boolean): Promise<string | undefined> {
const orgConf = getOrgAIProviderOfType(this.id);
if (!orgConf.enabled) return undefined;
if (orgConf.url) return orgConf.url;

let url = configuration.get('ai.ollama.url') ?? undefined;
if (url) {
if (silent) return url;
Expand Down Expand Up @@ -169,13 +177,14 @@ export class OllamaProvider extends OpenAICompatibleProviderBase<typeof provider
}
}

private getBaseUrl(): string {
private getBaseUrl(): string | undefined {
// Get base URL from configuration or use default
return configuration.get('ai.ollama.url') || defaultBaseUrl;
return ensureOrgConfiguredUrl(this.id, configuration.get('ai.ollama.url') || defaultBaseUrl);
}

protected getUrl(_model: AIModel<typeof provider.id>): string {
return `${this.getBaseUrl()}/api/chat`;
protected getUrl(_model: AIModel<typeof provider.id>): string | undefined {
const url = this.getBaseUrl();
return url ? `${url}/api/chat` : undefined;
}

protected override getHeaders<TAction extends AIActionType>(
Expand Down
8 changes: 6 additions & 2 deletions src/plus/ai/openAICompatibleProvider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import { configuration } from '../../system/-webview/configuration';
import type { AIModel } from './models/model';
import { openAIModels } from './models/model';
import { OpenAICompatibleProviderBase } from './openAICompatibleProviderBase';
import { isAzureUrl } from './utils/-webview/ai.utils';
import { ensureOrgConfiguredUrl, getOrgAIProviderOfType, isAzureUrl } from './utils/-webview/ai.utils';

type OpenAICompatibleModel = AIModel<typeof provider.id>;
const models: OpenAICompatibleModel[] = openAIModels(provider);
Expand All @@ -24,10 +24,14 @@ export class OpenAICompatibleProvider extends OpenAICompatibleProviderBase<typeo
}

protected getUrl(_model?: AIModel<typeof provider.id>): string | undefined {
return configuration.get('ai.openaicompatible.url') ?? undefined;
return ensureOrgConfiguredUrl(this.id, configuration.get('ai.openaicompatible.url'));
}

private async getOrPromptBaseUrl(silent: boolean, hasApiKey: boolean): Promise<string | undefined> {
const orgConf = getOrgAIProviderOfType(this.id);
if (!orgConf.enabled) return undefined;
if (orgConf.url) return orgConf.url;

let url: string | undefined = this.getUrl();

if (silent || (url != null && hasApiKey)) return url;
Expand Down
11 changes: 10 additions & 1 deletion src/plus/ai/openAICompatibleProviderBase.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,12 @@ import { startLogScope } from '../../system/logger.scope';
import type { ServerConnection } from '../gk/serverConnection';
import type { AIActionType, AIModel, AIProviderDescriptor } from './models/model';
import type { AIChatMessage, AIChatMessageRole, AIProvider, AIRequestResult } from './models/provider';
import { getActionName, getOrPromptApiKey, getValidatedTemperature } from './utils/-webview/ai.utils';
import {
getActionName,
getOrgAIProviderOfType,
getOrPromptApiKey,
getValidatedTemperature,
} from './utils/-webview/ai.utils';

export interface AIProviderConfig {
url: string;
Expand All @@ -36,6 +41,10 @@ export abstract class OpenAICompatibleProviderBase<T extends AIProviders> implem
}

async getApiKey(silent: boolean): Promise<string | undefined> {
const orgConf = getOrgAIProviderOfType(this.id);
if (!orgConf.enabled) return undefined;
if (orgConf.key) return orgConf.key;

const { keyUrl, keyValidator } = this.config;

return getOrPromptApiKey(
Expand Down
7 changes: 5 additions & 2 deletions src/plus/ai/openaiProvider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import { configuration } from '../../system/-webview/configuration';
import type { AIActionType, AIModel } from './models/model';
import { openAIModels } from './models/model';
import { OpenAICompatibleProviderBase } from './openAICompatibleProviderBase';
import { isAzureUrl } from './utils/-webview/ai.utils';
import { ensureOrgConfiguredUrl, isAzureUrl } from './utils/-webview/ai.utils';

type OpenAIModel = AIModel<typeof provider.id>;
const models: OpenAIModel[] = openAIModels(provider);
Expand All @@ -22,7 +22,10 @@ export class OpenAIProvider extends OpenAICompatibleProviderBase<typeof provider
}

protected getUrl(_model: AIModel<typeof provider.id>): string {
return configuration.get('ai.openai.url') || 'https://api.openai.com/v1/chat/completions';
return (
ensureOrgConfiguredUrl(this.id, configuration.get('ai.openai.url')) ||
'https://api.openai.com/v1/chat/completions'
);
}

protected override getHeaders<TAction extends AIActionType>(
Expand Down
30 changes: 30 additions & 0 deletions src/plus/ai/utils/-webview/ai.utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import { getContext } from '../../../../system/-webview/context';
import { openSettingsEditor } from '../../../../system/-webview/vscode/editors';
import { formatNumeric } from '../../../../system/date';
import { getPossessiveForm, pluralize } from '../../../../system/string';
import type { OrgAIConfig, OrgAIProvider } from '../../../gk/models/organization';
import { ensureAccountQuickPick } from '../../../gk/utils/-webview/acount.utils';
import type { AIActionType, AIModel } from '../../models/model';

Expand Down Expand Up @@ -170,6 +171,35 @@ export function isAzureUrl(url: string): boolean {
return url.includes('.azure.com');
}

export function getOrgAIConfig(): OrgAIConfig {
return {
aiEnabled: getContext('gitlens:gk:organization:ai:enabled', true),
enforceAiProviders: getContext('gitlens:gk:organization:ai:enforceProviders', false),
aiProviders: getContext('gitlens:gk:organization:ai:providers', {}),
};
}

export function getOrgAIProviderOfType(type: AIProviders, orgAiConfig?: OrgAIConfig): OrgAIProvider {
orgAiConfig ??= getOrgAIConfig();
if (!orgAiConfig.aiEnabled) return { type: type, enabled: false };
if (!orgAiConfig.enforceAiProviders) return { type: type, enabled: true };
return orgAiConfig.aiProviders[type] ?? { type: type, enabled: false };
}

export function isProviderEnabledByOrg(type: AIProviders, orgAiConfig?: OrgAIConfig): boolean {
return getOrgAIProviderOfType(type, orgAiConfig).enabled;
}

/**
* If the input value (userUrl) matches to the org configuration it returns it.
*/
export function ensureOrgConfiguredUrl(type: AIProviders, userUrl: null | undefined | string): string | undefined {
const provider = getOrgAIProviderOfType(type);
if (!provider.enabled) return undefined;

return provider.url || userUrl || undefined;
}

export async function ensureAccess(options?: { showPicker?: boolean }): Promise<boolean> {
const showPicker = options?.showPicker ?? false;

Expand Down
76 changes: 76 additions & 0 deletions src/plus/gk/models/organization.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import type { AIProviders } from '../../../constants.ai';

export interface Organization {
readonly id: string;
readonly name: string;
Expand All @@ -20,7 +22,10 @@ export interface OrganizationMember {
}

export interface OrganizationSettings {
aiEnabled: boolean;
enforceAiProviders: boolean;
aiSettings: OrganizationSetting;
aiProviders: GkDevAIProviders;
draftsSettings: OrganizationDraftsSettings;
}

Expand All @@ -39,3 +44,74 @@ export interface OrganizationDraftsSettings extends OrganizationSetting {
}
| undefined;
}

export type GkDevAIProviders = Partial<Record<GkDevAIProviderType, GkDevAIProvider>>;

export interface GkDevAIProvider {
enabled: boolean;
url?: string;
key?: string;
}

export interface OrgAIProvider {
readonly type: AIProviders;
readonly enabled: boolean;
readonly url?: string;
readonly key?: string;
}

export type OrgAIProviders = Partial<Record<AIProviders, OrgAIProvider | undefined>>;
export interface OrgAIConfig {
readonly aiEnabled: boolean;
readonly enforceAiProviders: boolean;
readonly aiProviders: OrgAIProviders;
}

export type GkDevAIProviderType = 'anthropic' | 'azure' | 'gitkraken_ai' | 'openai' | 'openai_compatible';

export function fromGkDevAIProviderType(type: GkDevAIProviderType): AIProviders;
export function fromGkDevAIProviderType(type: Exclude<unknown, GkDevAIProviderType>): never;
export function fromGkDevAIProviderType(type: unknown): AIProviders | never {
switch (type) {
case 'anthropic':
return 'anthropic';
case 'azure':
return 'azure';
case 'gitkraken_ai':
return 'gitkraken';
case 'openai':
return 'openai';
case 'openai_compatible':
return 'openaicompatible';
case 'ollama':
return 'ollama';
default:
throw new Error(`Unknown AI provider type: ${String(type)}`);
}
}

function fromGkDevAIProvider(type: GkDevAIProviderType, provider: GkDevAIProvider): OrgAIProvider {
return {
type: fromGkDevAIProviderType(type),
enabled: provider.enabled,
url: provider.url,
key: provider.key,
};
}

export function fromGKDevAIProviders(providers?: GkDevAIProviders): OrgAIProviders {
const result: OrgAIProviders = {};
if (providers == null) return result;

Object.entries(providers).forEach(([type, provider]) => {
try {
result[fromGkDevAIProviderType(type as GkDevAIProviderType)] = fromGkDevAIProvider(
type as GkDevAIProviderType,
provider,
);
} catch {
// ignore invalid provider, continue with others
}
});
return result;
}
Loading