Skip to content

Commit ecaa83a

Browse files
committed
Adds support for specifying the model temperature
1 parent f81c2a9 commit ecaa83a

File tree

7 files changed

+58
-9
lines changed

7 files changed

+58
-9
lines changed

package.json

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3897,6 +3897,18 @@
38973897
"preview"
38983898
]
38993899
},
3900+
"gitlens.ai.modelOptions.temperature": {
3901+
"type": "number",
3902+
"default": 0.7,
3903+
"minimum": 0,
3904+
"maximum": 2,
3905+
"markdownDescription": "Specifies the temperature, a measure of output randomness, to use for the AI model. Higher values result in more randomness, e.g. creativity, while lower values are more deterministic",
3906+
"scope": "window",
3907+
"order": 90,
3908+
"tags": [
3909+
"preview"
3910+
]
3911+
},
39003912
"gitlens.ai.explainChanges.customInstructions": {
39013913
"type": "string",
39023914
"default": null,

src/ai/aiProviderService.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ export interface AIModel<
4545

4646
readonly default?: boolean;
4747
readonly hidden?: boolean;
48+
49+
readonly temperature?: number;
4850
}
4951

5052
interface AIProviderConstructor<Provider extends AIProviders = AIProviders> {
@@ -705,3 +707,7 @@ export function showDiffTruncationWarning(maxCodeCharacters: number, model: AIMo
705707
)} limits.`,
706708
);
707709
}
710+
711+
export function getValidatedTemperature(): number {
712+
return Math.max(0, Math.min(configuration.get('ai.modelOptions.temperature'), 2));
713+
}

src/ai/anthropicProvider.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import type { CancellationToken } from 'vscode';
22
import type { AIModel } from './aiProviderService';
3-
import { getMaxCharacters } from './aiProviderService';
3+
import { getMaxCharacters, getValidatedTemperature } from './aiProviderService';
44
import type { ChatMessage } from './openAICompatibleProvider';
55
import { OpenAICompatibleProvider } from './openAICompatibleProvider';
66

@@ -119,6 +119,7 @@ export class AnthropicProvider extends OpenAICompatibleProvider<typeof provider.
119119
system: system.content,
120120
stream: false,
121121
max_tokens: Math.min(outputTokens, model.maxTokens.output),
122+
temperature: model.temperature ?? getValidatedTemperature(),
122123
};
123124

124125
const rsp = await this.fetchCore(model, apiKey, request, cancellation);

src/ai/geminiProvider.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,8 @@ export class GeminiProvider extends OpenAICompatibleProvider<typeof provider.id>
6666
request: object,
6767
cancellation: CancellationToken | undefined,
6868
) {
69-
if ('max_tokens' in request) {
70-
const { max_tokens: _, ...rest } = request;
69+
if ('max_completion_tokens' in request) {
70+
const { max_completion_tokens: _, ...rest } = request;
7171
request = rest;
7272
}
7373
return super.fetchCore(model, apiKey, request, cancellation);

src/ai/openAICompatibleProvider.ts

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,12 @@ import { configuration } from '../system/-webview/configuration';
88
import { sum } from '../system/iterable';
99
import { interpolate } from '../system/string';
1010
import type { AIModel, AIProvider } from './aiProviderService';
11-
import { getMaxCharacters, getOrPromptApiKey, showDiffTruncationWarning } from './aiProviderService';
11+
import {
12+
getMaxCharacters,
13+
getOrPromptApiKey,
14+
getValidatedTemperature,
15+
showDiffTruncationWarning,
16+
} from './aiProviderService';
1217
import {
1318
explainChangesUserPrompt,
1419
generateCloudPatchMessageUserPrompt,
@@ -215,7 +220,8 @@ export abstract class OpenAICompatibleProvider<T extends AIProviders> implements
215220
model: model.id,
216221
messages: messages(maxCodeCharacters, retries),
217222
stream: false,
218-
max_tokens: Math.min(outputTokens, model.maxTokens.output),
223+
max_completion_tokens: Math.min(outputTokens, model.maxTokens.output),
224+
temperature: model.temperature ?? getValidatedTemperature(),
219225
};
220226

221227
const rsp = await this.fetchCore(model, apiKey, request, cancellation);
@@ -295,7 +301,7 @@ interface ChatCompletionRequest {
295301

296302
frequency_penalty?: number;
297303
logit_bias?: Record<string, number>;
298-
max_tokens?: number;
304+
max_completion_tokens?: number;
299305
n?: number;
300306
presence_penalty?: number;
301307
stop?: string | string[];

src/ai/vscodeProvider.ts

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ import { configuration } from '../system/-webview/configuration';
66
import { sum } from '../system/iterable';
77
import { capitalize, getPossessiveForm, interpolate } from '../system/string';
88
import type { AIModel, AIProvider } from './aiProviderService';
9-
import { getMaxCharacters, showDiffTruncationWarning } from './aiProviderService';
9+
import { getMaxCharacters, getValidatedTemperature, showDiffTruncationWarning } from './aiProviderService';
1010
import {
1111
explainChangesUserPrompt,
1212
generateCloudPatchMessageUserPrompt,
@@ -21,6 +21,9 @@ export function isVSCodeAIModel(model: AIModel): model is AIModel<typeof provide
2121
return model.provider.id === provider.id;
2222
}
2323

24+
const accessJustification =
25+
'GitLens leverages Copilot for AI-powered features to improve your workflow and development experience.';
26+
2427
export class VSCodeAIProvider implements AIProvider<typeof provider.id> {
2528
readonly id = provider.id;
2629

@@ -85,7 +88,14 @@ export class VSCodeAIProvider implements AIProvider<typeof provider.id> {
8588
reporting['input.length'] = (reporting['input.length'] ?? 0) + sum(messages, m => m.content.length);
8689

8790
try {
88-
const rsp = await chatModel.sendRequest(messages, {}, cancellation);
91+
const rsp = await chatModel.sendRequest(
92+
messages,
93+
{
94+
justification: accessJustification,
95+
modelOptions: { temperature: model.temperature ?? getValidatedTemperature() },
96+
},
97+
cancellation,
98+
);
8999

90100
if (diff.length > maxCodeCharacters) {
91101
showDiffTruncationWarning(maxCodeCharacters, model);
@@ -219,7 +229,14 @@ export class VSCodeAIProvider implements AIProvider<typeof provider.id> {
219229
reporting['input.length'] = (reporting['input.length'] ?? 0) + sum(messages, m => m.content.length);
220230

221231
try {
222-
const rsp = await chatModel.sendRequest(messages, {}, cancellation);
232+
const rsp = await chatModel.sendRequest(
233+
messages,
234+
{
235+
justification: accessJustification,
236+
modelOptions: model.temperature != null ? { temperature: model.temperature } : undefined,
237+
},
238+
cancellation,
239+
);
223240

224241
if (diff.length > maxCodeCharacters) {
225242
showDiffTruncationWarning(maxCodeCharacters, model);
@@ -235,6 +252,10 @@ export class VSCodeAIProvider implements AIProvider<typeof provider.id> {
235252
debugger;
236253
let message = ex instanceof Error ? ex.message : String(ex);
237254

255+
if (ex instanceof Error && 'code' in ex && ex.code === 'NoPermissions') {
256+
throw new Error(`User denied access to ${model.provider.name}`);
257+
}
258+
238259
if (ex instanceof Error && 'cause' in ex && ex.cause instanceof Error) {
239260
message += `\n${ex.cause.message}`;
240261

src/config.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,9 @@ interface AIConfig {
222222
readonly customInstructions: string;
223223
};
224224
readonly model: SupportedAIModels | null;
225+
readonly modelOptions: {
226+
readonly temperature: number;
227+
};
225228
readonly openai: {
226229
readonly url: string | null;
227230
};

0 commit comments

Comments
 (0)