Skip to content

Commit 95ea55a

Browse files
committed
Adds xAI model support
1 parent 5927411 commit 95ea55a

File tree

4 files changed

+360
-3
lines changed

4 files changed

+360
-3
lines changed

package.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3714,7 +3714,8 @@
37143714
"google:gemini-1.5-pro-latest",
37153715
"google:gemini-1.5-flash-latest",
37163716
"google:gemini-1.0-pro",
3717-
"vscode"
3717+
"vscode",
3718+
"xai:grok-beta"
37183719
],
37193720
"enumDescriptions": [
37203721
"OpenAI GPT-4 Omni",

src/ai/aiProviderService.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import { AnthropicProvider } from './anthropicProvider';
2020
import { GeminiProvider } from './geminiProvider';
2121
import { OpenAIProvider } from './openaiProvider';
2222
import { isVSCodeAIModel, VSCodeAIProvider } from './vscodeProvider';
23+
import { xAIProvider } from './xaiProvider';
2324

2425
export interface AIModel<
2526
Provider extends AIProviders = AIProviders,
@@ -46,6 +47,7 @@ const _supportedProviderTypes = new Map<AIProviders, AIProviderConstructor>([
4647
['openai', OpenAIProvider],
4748
['anthropic', AnthropicProvider],
4849
['gemini', GeminiProvider],
50+
['xai', xAIProvider],
4951
]);
5052

5153
export interface AIProvider<Provider extends AIProviders = AIProviders> extends Disposable {

src/ai/xaiProvider.ts

Lines changed: 349 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,349 @@
1+
import { fetch } from '@env/fetch';
2+
import type { CancellationToken } from 'vscode';
3+
import { window } from 'vscode';
4+
import type { xAIModels } from '../constants.ai';
5+
import type { TelemetryEvents } from '../constants.telemetry';
6+
import type { Container } from '../container';
7+
import { CancellationError } from '../errors';
8+
import { sum } from '../system/iterable';
9+
import { interpolate } from '../system/string';
10+
import { configuration } from '../system/vscode/configuration';
11+
import type { Storage } from '../system/vscode/storage';
12+
import type { AIModel, AIProvider } from './aiProviderService';
13+
import { getApiKey as getApiKeyCore, getMaxCharacters } from './aiProviderService';
14+
import {
15+
generateCloudPatchMessageSystemPrompt,
16+
generateCloudPatchMessageUserPrompt,
17+
generateCodeSuggestMessageSystemPrompt,
18+
generateCodeSuggestMessageUserPrompt,
19+
generateCommitMessageSystemPrompt,
20+
generateCommitMessageUserPrompt,
21+
} from './prompts';
22+
23+
/** Identifier and display name shared by everything in this module that refers to the xAI provider. */
const provider = { id: 'xai', name: 'xAI' } as const;

/** An `AIModel` specialised to the xAI provider id. */
type xAIModel = AIModel<typeof provider.id>;

/** The single model currently listed for xAI; it is flagged as the default. */
const grokBeta: xAIModel = {
	id: 'grok-beta',
	name: 'Grok Beta',
	maxTokens: 131072,
	provider,
	default: true,
};

/** All xAI models this provider reports from `getModels()`. */
const models: xAIModel[] = [grokBeta];
35+
36+
/**
 * AI provider that sends chat-completion requests to `https://api.x.ai/v1/chat/completions`.
 *
 * Both `generateMessage` and `explainChanges` follow the same flow:
 *  1. resolve the API key (returning `undefined` if none is available),
 *  2. build a request containing the diff truncated to a character budget,
 *  3. on a `context_length_exceeded` error, shrink the budget and retry (at most twice),
 *  4. on any other failure, throw an `Error` describing it,
 *  5. on success, return the trimmed content of the first choice.
 */
export class xAIProvider implements AIProvider<typeof provider.id> {
	readonly id = provider.id;
	readonly name = provider.name;

	constructor(private readonly container: Container) {}

	/** Nothing to release; present to satisfy the `AIProvider` contract. */
	dispose() {}

	/** Returns the statically declared list of xAI models. */
	getModels(): Promise<readonly AIModel<typeof provider.id>[]> {
		return Promise.resolve(models);
	}

	/**
	 * Generates a message (commit, cloud patch, or code suggestion) for the given diff.
	 *
	 * @param model The xAI model to query.
	 * @param diff The diff text; it is truncated to the current character budget before sending.
	 * @param reporting Telemetry bag; `retry.count` and `input.length` are updated on every attempt.
	 * @param promptConfig The system/user prompts, the message type (used in error text), and optional custom instructions.
	 * @param options Optional cancellation token and extra context interpolated into the user prompt.
	 * @returns The trimmed message, or `undefined` when no API key is available.
	 * @throws Error when the request fails or the response contains no message.
	 */
	async generateMessage(
		model: xAIModel,
		diff: string,
		reporting: TelemetryEvents['ai/generate'],
		promptConfig: {
			type: 'commit' | 'cloud-patch' | 'code-suggestion';
			systemPrompt: string;
			userPrompt: string;
			customInstructions?: string;
		},
		options?: { cancellation?: CancellationToken; context?: string },
	): Promise<string | undefined> {
		const apiKey = await getApiKey(this.container.storage);
		if (apiKey == null) return undefined;

		const failurePrefix = `Unable to generate ${promptConfig.type} message`;

		let retries = 0;
		let maxCodeCharacters = getMaxCharacters(model, 2600);
		while (true) {
			const request: xAIChatCompletionRequest = {
				model: model.id,
				messages: [
					{
						role: 'system',
						content: promptConfig.systemPrompt,
					},
					{
						role: 'user',
						content: interpolate(promptConfig.userPrompt, {
							diff: diff.substring(0, maxCodeCharacters),
							context: options?.context ?? '',
							instructions: promptConfig.customInstructions ?? '',
						}),
					},
				],
			};

			reporting['retry.count'] = retries;
			// Accumulates across attempts, so retries add to the reported input length
			reporting['input.length'] = (reporting['input.length'] ?? 0) + sum(request.messages, m => m.content.length);

			const rsp = await this.fetch(apiKey, request, options?.cancellation);
			if (!rsp.ok) {
				if (rsp.status === 404) {
					throw new Error(
						`${failurePrefix}: Your API key doesn't seem to have access to the selected '${model.id}' model`,
					);
				}
				if (rsp.status === 429) {
					throw new Error(
						`${failurePrefix}: (${this.name}:${rsp.status}) Too many requests (rate limit exceeded) or your API key is associated with an expired trial`,
					);
				}

				const json = await this.readErrorBody(rsp);

				// Shrink the diff budget (by 500, then 1000 characters) and try again, at most twice
				if (retries++ < 2 && json?.error?.code === 'context_length_exceeded') {
					maxCodeCharacters -= 500 * retries;
					continue;
				}

				throw new Error(`${failurePrefix}: (${this.name}:${rsp.status}) ${json?.error?.message || rsp.statusText}`);
			}

			if (diff.length > maxCodeCharacters) {
				void window.showWarningMessage(
					`The diff of the changes had to be truncated to ${maxCodeCharacters} characters to fit within the xAI's limits.`,
				);
			}

			const data: xAIChatCompletionResponse | undefined = await rsp.json();
			return this.getCompletionContent(data, failurePrefix);
		}
	}

	/**
	 * Generates a draft message: a code-suggestion message when `options.codeSuggestion` is truthy,
	 * otherwise a cloud-patch message. Delegates to `generateMessage`.
	 */
	async generateDraftMessage(
		model: xAIModel,
		diff: string,
		reporting: TelemetryEvents['ai/generate'],
		options?: {
			cancellation?: CancellationToken;
			context?: string;
			codeSuggestion?: boolean | undefined;
		},
	): Promise<string | undefined> {
		let codeSuggestion;
		if (options != null) {
			// Strip `codeSuggestion` so only the cancellation/context options are forwarded
			({ codeSuggestion, ...options } = options);
		}

		return this.generateMessage(
			model,
			diff,
			reporting,
			codeSuggestion
				? {
						type: 'code-suggestion',
						systemPrompt: generateCodeSuggestMessageSystemPrompt,
						userPrompt: generateCodeSuggestMessageUserPrompt,
						customInstructions: configuration.get('experimental.generateCodeSuggestionMessagePrompt'),
				  }
				: {
						type: 'cloud-patch',
						systemPrompt: generateCloudPatchMessageSystemPrompt,
						userPrompt: generateCloudPatchMessageUserPrompt,
						customInstructions: configuration.get('experimental.generateCloudPatchMessagePrompt'),
				  },
			options,
		);
	}

	/** Generates a commit message for the diff by delegating to `generateMessage` with the commit prompts. */
	async generateCommitMessage(
		model: xAIModel,
		diff: string,
		reporting: TelemetryEvents['ai/generate'],
		options?: { cancellation?: CancellationToken; context?: string },
	): Promise<string | undefined> {
		return this.generateMessage(
			model,
			diff,
			reporting,
			{
				type: 'commit',
				systemPrompt: generateCommitMessageSystemPrompt,
				userPrompt: generateCommitMessageUserPrompt,
				customInstructions: configuration.get('experimental.generateCommitMessagePrompt'),
			},
			options,
		);
	}

	/**
	 * Asks the model to explain a set of changes.
	 *
	 * @param model The xAI model to query.
	 * @param message Author-provided context that is sent as its own user message.
	 * @param diff The diff text; it is truncated to the current character budget before sending.
	 * @param reporting Telemetry bag; `retry.count` and `input.length` are updated on every attempt.
	 * @param options Optional cancellation token.
	 * @returns The trimmed explanation, or `undefined` when no API key is available.
	 * @throws Error when the request fails or the response contains no message.
	 */
	async explainChanges(
		model: xAIModel,
		message: string,
		diff: string,
		reporting: TelemetryEvents['ai/explain'],
		options?: { cancellation?: CancellationToken },
	): Promise<string | undefined> {
		const apiKey = await getApiKey(this.container.storage);
		if (apiKey == null) return undefined;

		const failurePrefix = 'Unable to explain changes';

		let retries = 0;
		let maxCodeCharacters = getMaxCharacters(model, 3000);
		while (true) {
			const code = diff.substring(0, maxCodeCharacters);

			const request: xAIChatCompletionRequest = {
				model: model.id,
				messages: [
					{
						role: 'system',
						content: `You are an advanced AI programming assistant tasked with summarizing code changes into an explanation that is both easy to understand and meaningful. Construct an explanation that:
- Concisely synthesizes meaningful information from the provided code diff
- Incorporates any additional context provided by the user to understand the rationale behind the code changes
- Places the emphasis on the 'why' of the change, clarifying its benefits or addressing the problem that necessitated the change, beyond just detailing the 'what' has changed

Do not make any assumptions or invent details that are not supported by the code diff or the user-provided context.`,
					},
					{
						role: 'user',
						content: `Here is additional context provided by the author of the changes, which should provide some explanation to why these changes where made. Please strongly consider this information when generating your explanation:\n\n${message}`,
					},
					{
						role: 'user',
						content: `Now, kindly explain the following code diff in a way that would be clear to someone reviewing or trying to understand these changes:\n\n${code}`,
					},
					{
						role: 'user',
						content:
							'Remember to frame your explanation in a way that is suitable for a reviewer to quickly grasp the essence of the changes, the issues they resolve, and their implications on the codebase.',
					},
				],
			};

			reporting['retry.count'] = retries;
			// Accumulates across attempts, so retries add to the reported input length
			reporting['input.length'] = (reporting['input.length'] ?? 0) + sum(request.messages, m => m.content.length);

			const rsp = await this.fetch(apiKey, request, options?.cancellation);
			if (!rsp.ok) {
				if (rsp.status === 404) {
					throw new Error(
						`${failurePrefix}: Your API key doesn't seem to have access to the selected '${model.id}' model`,
					);
				}
				if (rsp.status === 429) {
					throw new Error(
						`${failurePrefix}: (${this.name}:${rsp.status}) Too many requests (rate limit exceeded) or your API key is associated with an expired trial`,
					);
				}

				const json = await this.readErrorBody(rsp);

				// Shrink the diff budget (by 500, then 1000 characters) and try again, at most twice
				if (retries++ < 2 && json?.error?.code === 'context_length_exceeded') {
					maxCodeCharacters -= 500 * retries;
					continue;
				}

				throw new Error(`${failurePrefix}: (${this.name}:${rsp.status}) ${json?.error?.message || rsp.statusText}`);
			}

			if (diff.length > maxCodeCharacters) {
				void window.showWarningMessage(
					`The diff of the changes had to be truncated to ${maxCodeCharacters} characters to fit within the xAI's limits.`,
				);
			}

			const data: xAIChatCompletionResponse | undefined = await rsp.json();
			return this.getCompletionContent(data, failurePrefix);
		}
	}

	/**
	 * Best-effort parse of an error response body.
	 * Returns `undefined` when the body is not valid JSON so callers can fall back to the HTTP status text.
	 */
	private async readErrorBody(rsp: {
		json(): Promise<unknown>;
	}): Promise<{ error?: { code: string; message: string } } | undefined> {
		try {
			return (await rsp.json()) as { error?: { code: string; message: string } } | undefined;
		} catch {
			// Deliberately ignored: a non-JSON error body is reported via the status text instead
			return undefined;
		}
	}

	/**
	 * Extracts the trimmed content of the first choice from a successful response.
	 *
	 * @throws Error (prefixed with `failurePrefix`) when the payload has no first choice with string content,
	 * rather than failing with an opaque `TypeError` on property access.
	 */
	private getCompletionContent(data: xAIChatCompletionResponse | undefined, failurePrefix: string): string {
		const content: unknown = data?.choices?.[0]?.message?.content;
		if (typeof content !== 'string') {
			throw new Error(`${failurePrefix}: (${this.name}) The response did not contain a message`);
		}
		return content.trim();
	}

	/**
	 * POSTs the request to the xAI chat completions endpoint.
	 *
	 * When a cancellation token is supplied, a cancellation request aborts the underlying fetch;
	 * the resulting `AbortError` is rethrown as a `CancellationError`. All other errors propagate unchanged.
	 */
	private async fetch(
		apiKey: string,
		request: xAIChatCompletionRequest,
		cancellation: CancellationToken | undefined,
	) {
		let aborter: AbortController | undefined;
		if (cancellation != null) {
			aborter = new AbortController();
			// NOTE(review): the subscription returned here is never disposed — confirm the token's owner cleans it up
			cancellation.onCancellationRequested(() => aborter?.abort());
		}

		try {
			return await fetch('https://api.x.ai/v1/chat/completions', {
				headers: {
					Accept: 'application/json',
					Authorization: `Bearer ${apiKey}`,
					'Content-Type': 'application/json',
				},
				method: 'POST',
				body: JSON.stringify(request),
				signal: aborter?.signal,
			});
		} catch (ex) {
			if (ex instanceof Error && ex.name === 'AbortError') throw new CancellationError(ex);

			throw ex;
		}
	}
}
304+
305+
/**
 * Resolves the xAI API key by delegating to the shared `getApiKey` helper from `aiProviderService`,
 * passing this provider's id and name, a key validator, and the xAI console URL.
 *
 * NOTE(review): the validator pattern is unanchored, so any input that merely contains 32 or more
 * consecutive alphanumeric characters passes, and the optional `sk-` prefix has no effect on the result.
 * Confirm against the actual xAI key format before tightening.
 */
async function getApiKey(storage: Storage): Promise<string | undefined> {
	const keyPattern = /(?:sk-)?[a-zA-Z0-9]{32,}/;
	const { id, name } = provider;

	return getApiKeyCore(storage, {
		id,
		name,
		validator: candidate => keyPattern.test(candidate),
		url: 'https://console.x.ai/',
	});
}
313+
314+
/**
 * Shape of the JSON body this module POSTs to the chat completions endpoint.
 * Only `model` and `messages` are required; the remaining fields are optional tuning parameters.
 */
// eslint-disable-next-line @typescript-eslint/naming-convention
interface xAIChatCompletionRequest {
	/** Id of the model to run. */
	model: xAIModels;
	/** Conversation messages, in the order they are sent. */
	messages: Array<{
		role: 'system' | 'user' | 'assistant';
		content: string;
	}>;

	temperature?: number;
	top_p?: number;
	n?: number;
	stream?: boolean;
	stop?: string | string[];
	max_tokens?: number;
	presence_penalty?: number;
	frequency_penalty?: number;
	logit_bias?: Record<string, number>;
	user?: string;
}
329+
330+
/**
 * Shape this module expects from a successful chat completions response.
 * The provider reads `choices[0].message.content`; the other fields are declared but not read in this file.
 */
// eslint-disable-next-line @typescript-eslint/naming-convention
interface xAIChatCompletionResponse {
	id: string;
	object: 'chat.completion';
	created: number;
	model: string;

	/** Generated completions. */
	choices: Array<{
		index: number;
		message: {
			role: 'system' | 'user' | 'assistant';
			content: string;
		};
		finish_reason: string;
	}>;

	/** Token accounting reported with the response. */
	usage: {
		prompt_tokens: number;
		completion_tokens: number;
		total_tokens: number;
	};
}

0 commit comments

Comments
 (0)