|
| 1 | +import { fetch } from '@env/fetch'; |
| 2 | +import type { CancellationToken } from 'vscode'; |
| 3 | +import { window } from 'vscode'; |
| 4 | +import type { xAIModels } from '../constants.ai'; |
| 5 | +import type { TelemetryEvents } from '../constants.telemetry'; |
| 6 | +import type { Container } from '../container'; |
| 7 | +import { CancellationError } from '../errors'; |
| 8 | +import { sum } from '../system/iterable'; |
| 9 | +import { interpolate } from '../system/string'; |
| 10 | +import { configuration } from '../system/vscode/configuration'; |
| 11 | +import type { Storage } from '../system/vscode/storage'; |
| 12 | +import type { AIModel, AIProvider } from './aiProviderService'; |
| 13 | +import { getApiKey as getApiKeyCore, getMaxCharacters } from './aiProviderService'; |
| 14 | +import { |
| 15 | + generateCloudPatchMessageSystemPrompt, |
| 16 | + generateCloudPatchMessageUserPrompt, |
| 17 | + generateCodeSuggestMessageSystemPrompt, |
| 18 | + generateCodeSuggestMessageUserPrompt, |
| 19 | + generateCommitMessageSystemPrompt, |
| 20 | + generateCommitMessageUserPrompt, |
| 21 | +} from './prompts'; |
| 22 | + |
| 23 | +const provider = { id: 'xai', name: 'xAI' } as const; |
| 24 | + |
| 25 | +type xAIModel = AIModel<typeof provider.id>; |
| 26 | +const models: xAIModel[] = [ |
| 27 | + { |
| 28 | + id: 'grok-beta', |
| 29 | + name: 'Grok Beta', |
| 30 | + maxTokens: 131072, |
| 31 | + provider: provider, |
| 32 | + default: true, |
| 33 | + }, |
| 34 | +]; |
| 35 | + |
| 36 | +export class xAIProvider implements AIProvider<typeof provider.id> { |
| 37 | + readonly id = provider.id; |
| 38 | + readonly name = provider.name; |
| 39 | + |
| 40 | + constructor(private readonly container: Container) {} |
| 41 | + |
| 42 | + dispose() {} |
| 43 | + |
| 44 | + getModels(): Promise<readonly AIModel<typeof provider.id>[]> { |
| 45 | + return Promise.resolve(models); |
| 46 | + } |
| 47 | + |
| 48 | + async generateMessage( |
| 49 | + model: xAIModel, |
| 50 | + diff: string, |
| 51 | + reporting: TelemetryEvents['ai/generate'], |
| 52 | + promptConfig: { |
| 53 | + type: 'commit' | 'cloud-patch' | 'code-suggestion'; |
| 54 | + systemPrompt: string; |
| 55 | + userPrompt: string; |
| 56 | + customInstructions?: string; |
| 57 | + }, |
| 58 | + options?: { cancellation?: CancellationToken; context?: string }, |
| 59 | + ): Promise<string | undefined> { |
| 60 | + const apiKey = await getApiKey(this.container.storage); |
| 61 | + if (apiKey == null) return undefined; |
| 62 | + |
| 63 | + let retries = 0; |
| 64 | + let maxCodeCharacters = getMaxCharacters(model, 2600); |
| 65 | + while (true) { |
| 66 | + const request: xAIChatCompletionRequest = { |
| 67 | + model: model.id, |
| 68 | + messages: [ |
| 69 | + { |
| 70 | + role: 'system', |
| 71 | + content: promptConfig.systemPrompt, |
| 72 | + }, |
| 73 | + { |
| 74 | + role: 'user', |
| 75 | + content: interpolate(promptConfig.userPrompt, { |
| 76 | + diff: diff.substring(0, maxCodeCharacters), |
| 77 | + context: options?.context ?? '', |
| 78 | + instructions: promptConfig.customInstructions ?? '', |
| 79 | + }), |
| 80 | + }, |
| 81 | + ], |
| 82 | + }; |
| 83 | + |
| 84 | + reporting['retry.count'] = retries; |
| 85 | + reporting['input.length'] = (reporting['input.length'] ?? 0) + sum(request.messages, m => m.content.length); |
| 86 | + |
| 87 | + const rsp = await this.fetch(apiKey, request, options?.cancellation); |
| 88 | + if (!rsp.ok) { |
| 89 | + if (rsp.status === 404) { |
| 90 | + throw new Error( |
| 91 | + `Unable to generate ${promptConfig.type} message: Your API key doesn't seem to have access to the selected '${model.id}' model`, |
| 92 | + ); |
| 93 | + } |
| 94 | + if (rsp.status === 429) { |
| 95 | + throw new Error( |
| 96 | + `Unable to generate ${promptConfig.type} message: (${this.name}:${rsp.status}) Too many requests (rate limit exceeded) or your API key is associated with an expired trial`, |
| 97 | + ); |
| 98 | + } |
| 99 | + |
| 100 | + let json; |
| 101 | + try { |
| 102 | + json = (await rsp.json()) as { error?: { code: string; message: string } } | undefined; |
| 103 | + } catch {} |
| 104 | + |
| 105 | + debugger; |
| 106 | + |
| 107 | + if (retries++ < 2 && json?.error?.code === 'context_length_exceeded') { |
| 108 | + maxCodeCharacters -= 500 * retries; |
| 109 | + continue; |
| 110 | + } |
| 111 | + |
| 112 | + throw new Error( |
| 113 | + `Unable to generate ${promptConfig.type} message: (${this.name}:${rsp.status}) ${ |
| 114 | + json?.error?.message || rsp.statusText |
| 115 | + }`, |
| 116 | + ); |
| 117 | + } |
| 118 | + |
| 119 | + if (diff.length > maxCodeCharacters) { |
| 120 | + void window.showWarningMessage( |
| 121 | + `The diff of the changes had to be truncated to ${maxCodeCharacters} characters to fit within the xAI's limits.`, |
| 122 | + ); |
| 123 | + } |
| 124 | + |
| 125 | + const data: xAIChatCompletionResponse = await rsp.json(); |
| 126 | + const message = data.choices[0].message.content.trim(); |
| 127 | + return message; |
| 128 | + } |
| 129 | + } |
| 130 | + |
| 131 | + async generateDraftMessage( |
| 132 | + model: xAIModel, |
| 133 | + diff: string, |
| 134 | + reporting: TelemetryEvents['ai/generate'], |
| 135 | + options?: { |
| 136 | + cancellation?: CancellationToken; |
| 137 | + context?: string; |
| 138 | + codeSuggestion?: boolean | undefined; |
| 139 | + }, |
| 140 | + ): Promise<string | undefined> { |
| 141 | + let codeSuggestion; |
| 142 | + if (options != null) { |
| 143 | + ({ codeSuggestion, ...options } = options ?? {}); |
| 144 | + } |
| 145 | + |
| 146 | + return this.generateMessage( |
| 147 | + model, |
| 148 | + diff, |
| 149 | + reporting, |
| 150 | + codeSuggestion |
| 151 | + ? { |
| 152 | + type: 'code-suggestion', |
| 153 | + systemPrompt: generateCodeSuggestMessageSystemPrompt, |
| 154 | + userPrompt: generateCodeSuggestMessageUserPrompt, |
| 155 | + customInstructions: configuration.get('experimental.generateCodeSuggestionMessagePrompt'), |
| 156 | + } |
| 157 | + : { |
| 158 | + type: 'cloud-patch', |
| 159 | + systemPrompt: generateCloudPatchMessageSystemPrompt, |
| 160 | + userPrompt: generateCloudPatchMessageUserPrompt, |
| 161 | + customInstructions: configuration.get('experimental.generateCloudPatchMessagePrompt'), |
| 162 | + }, |
| 163 | + options, |
| 164 | + ); |
| 165 | + } |
| 166 | + |
| 167 | + async generateCommitMessage( |
| 168 | + model: xAIModel, |
| 169 | + diff: string, |
| 170 | + reporting: TelemetryEvents['ai/generate'], |
| 171 | + options?: { cancellation?: CancellationToken; context?: string }, |
| 172 | + ): Promise<string | undefined> { |
| 173 | + return this.generateMessage( |
| 174 | + model, |
| 175 | + diff, |
| 176 | + reporting, |
| 177 | + { |
| 178 | + type: 'commit', |
| 179 | + systemPrompt: generateCommitMessageSystemPrompt, |
| 180 | + userPrompt: generateCommitMessageUserPrompt, |
| 181 | + customInstructions: configuration.get('experimental.generateCommitMessagePrompt'), |
| 182 | + }, |
| 183 | + options, |
| 184 | + ); |
| 185 | + } |
| 186 | + |
| 187 | + async explainChanges( |
| 188 | + model: xAIModel, |
| 189 | + message: string, |
| 190 | + diff: string, |
| 191 | + reporting: TelemetryEvents['ai/explain'], |
| 192 | + options?: { cancellation?: CancellationToken }, |
| 193 | + ): Promise<string | undefined> { |
| 194 | + const apiKey = await getApiKey(this.container.storage); |
| 195 | + if (apiKey == null) return undefined; |
| 196 | + |
| 197 | + let retries = 0; |
| 198 | + let maxCodeCharacters = getMaxCharacters(model, 3000); |
| 199 | + while (true) { |
| 200 | + const code = diff.substring(0, maxCodeCharacters); |
| 201 | + |
| 202 | + const request: xAIChatCompletionRequest = { |
| 203 | + model: model.id, |
| 204 | + messages: [ |
| 205 | + { |
| 206 | + role: 'system', |
| 207 | + content: `You are an advanced AI programming assistant tasked with summarizing code changes into an explanation that is both easy to understand and meaningful. Construct an explanation that: |
| 208 | +- Concisely synthesizes meaningful information from the provided code diff |
| 209 | +- Incorporates any additional context provided by the user to understand the rationale behind the code changes |
| 210 | +- Places the emphasis on the 'why' of the change, clarifying its benefits or addressing the problem that necessitated the change, beyond just detailing the 'what' has changed |
| 211 | +
|
| 212 | +Do not make any assumptions or invent details that are not supported by the code diff or the user-provided context.`, |
| 213 | + }, |
| 214 | + { |
| 215 | + role: 'user', |
| 216 | + content: `Here is additional context provided by the author of the changes, which should provide some explanation to why these changes where made. Please strongly consider this information when generating your explanation:\n\n${message}`, |
| 217 | + }, |
| 218 | + { |
| 219 | + role: 'user', |
| 220 | + content: `Now, kindly explain the following code diff in a way that would be clear to someone reviewing or trying to understand these changes:\n\n${code}`, |
| 221 | + }, |
| 222 | + { |
| 223 | + role: 'user', |
| 224 | + content: |
| 225 | + 'Remember to frame your explanation in a way that is suitable for a reviewer to quickly grasp the essence of the changes, the issues they resolve, and their implications on the codebase.', |
| 226 | + }, |
| 227 | + ], |
| 228 | + }; |
| 229 | + |
| 230 | + reporting['retry.count'] = retries; |
| 231 | + reporting['input.length'] = (reporting['input.length'] ?? 0) + sum(request.messages, m => m.content.length); |
| 232 | + |
| 233 | + const rsp = await this.fetch(apiKey, request, options?.cancellation); |
| 234 | + if (!rsp.ok) { |
| 235 | + if (rsp.status === 404) { |
| 236 | + throw new Error( |
| 237 | + `Unable to explain changes: Your API key doesn't seem to have access to the selected '${model.id}' model`, |
| 238 | + ); |
| 239 | + } |
| 240 | + if (rsp.status === 429) { |
| 241 | + throw new Error( |
| 242 | + `Unable to explain changes: (${this.name}:${rsp.status}) Too many requests (rate limit exceeded) or your API key is associated with an expired trial`, |
| 243 | + ); |
| 244 | + } |
| 245 | + |
| 246 | + let json; |
| 247 | + try { |
| 248 | + json = (await rsp.json()) as { error?: { code: string; message: string } } | undefined; |
| 249 | + } catch {} |
| 250 | + |
| 251 | + debugger; |
| 252 | + |
| 253 | + if (retries++ < 2 && json?.error?.code === 'context_length_exceeded') { |
| 254 | + maxCodeCharacters -= 500 * retries; |
| 255 | + continue; |
| 256 | + } |
| 257 | + |
| 258 | + throw new Error( |
| 259 | + `Unable to explain changes: (${this.name}:${rsp.status}) ${json?.error?.message || rsp.statusText}`, |
| 260 | + ); |
| 261 | + } |
| 262 | + |
| 263 | + if (diff.length > maxCodeCharacters) { |
| 264 | + void window.showWarningMessage( |
| 265 | + `The diff of the changes had to be truncated to ${maxCodeCharacters} characters to fit within the xAI's limits.`, |
| 266 | + ); |
| 267 | + } |
| 268 | + |
| 269 | + const data: xAIChatCompletionResponse = await rsp.json(); |
| 270 | + const summary = data.choices[0].message.content.trim(); |
| 271 | + return summary; |
| 272 | + } |
| 273 | + } |
| 274 | + |
| 275 | + private async fetch( |
| 276 | + apiKey: string, |
| 277 | + request: xAIChatCompletionRequest, |
| 278 | + cancellation: CancellationToken | undefined, |
| 279 | + ) { |
| 280 | + let aborter: AbortController | undefined; |
| 281 | + if (cancellation != null) { |
| 282 | + aborter = new AbortController(); |
| 283 | + cancellation.onCancellationRequested(() => aborter?.abort()); |
| 284 | + } |
| 285 | + |
| 286 | + try { |
| 287 | + return await fetch('https://api.x.ai/v1/chat/completions', { |
| 288 | + headers: { |
| 289 | + Accept: 'application/json', |
| 290 | + Authorization: `Bearer ${apiKey}`, |
| 291 | + 'Content-Type': 'application/json', |
| 292 | + }, |
| 293 | + method: 'POST', |
| 294 | + body: JSON.stringify(request), |
| 295 | + signal: aborter?.signal, |
| 296 | + }); |
| 297 | + } catch (ex) { |
| 298 | + if (ex.name === 'AbortError') throw new CancellationError(ex); |
| 299 | + |
| 300 | + throw ex; |
| 301 | + } |
| 302 | + } |
| 303 | +} |
| 304 | + |
| 305 | +async function getApiKey(storage: Storage): Promise<string | undefined> { |
| 306 | + return getApiKeyCore(storage, { |
| 307 | + id: provider.id, |
| 308 | + name: provider.name, |
| 309 | + validator: v => /(?:sk-)?[a-zA-Z0-9]{32,}/.test(v), |
| 310 | + url: 'https://console.x.ai/', |
| 311 | + }); |
| 312 | +} |
| 313 | + |
| 314 | +// eslint-disable-next-line @typescript-eslint/naming-convention |
| 315 | +interface xAIChatCompletionRequest { |
| 316 | + model: xAIModels; |
| 317 | + messages: { role: 'system' | 'user' | 'assistant'; content: string }[]; |
| 318 | + temperature?: number; |
| 319 | + top_p?: number; |
| 320 | + n?: number; |
| 321 | + stream?: boolean; |
| 322 | + stop?: string | string[]; |
| 323 | + max_tokens?: number; |
| 324 | + presence_penalty?: number; |
| 325 | + frequency_penalty?: number; |
| 326 | + logit_bias?: Record<string, number>; |
| 327 | + user?: string; |
| 328 | +} |
| 329 | + |
| 330 | +// eslint-disable-next-line @typescript-eslint/naming-convention |
| 331 | +interface xAIChatCompletionResponse { |
| 332 | + id: string; |
| 333 | + object: 'chat.completion'; |
| 334 | + created: number; |
| 335 | + model: string; |
| 336 | + choices: { |
| 337 | + index: number; |
| 338 | + message: { |
| 339 | + role: 'system' | 'user' | 'assistant'; |
| 340 | + content: string; |
| 341 | + }; |
| 342 | + finish_reason: string; |
| 343 | + }[]; |
| 344 | + usage: { |
| 345 | + prompt_tokens: number; |
| 346 | + completion_tokens: number; |
| 347 | + total_tokens: number; |
| 348 | + }; |
| 349 | +} |
0 commit comments