Skip to content

Commit 51d03bd

Browse files
committed
Adds support for Hugging Chat models
1 parent 95ea55a commit 51d03bd

File tree

3 files changed

+381
-6
lines changed

3 files changed

+381
-6
lines changed

src/ai/aiProviderService.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import { supportedInVSCodeVersion } from '../system/vscode/utils';
1818
import type { TelemetryService } from '../telemetry/telemetry';
1919
import { AnthropicProvider } from './anthropicProvider';
2020
import { GeminiProvider } from './geminiProvider';
21+
import { HuggingChatProvider } from './huggingchatProvider';
2122
import { OpenAIProvider } from './openaiProvider';
2223
import { isVSCodeAIModel, VSCodeAIProvider } from './vscodeProvider';
2324
import { xAIProvider } from './xaiProvider';
@@ -47,6 +48,7 @@ const _supportedProviderTypes = new Map<AIProviders, AIProviderConstructor>([
4748
['openai', OpenAIProvider],
4849
['anthropic', AnthropicProvider],
4950
['gemini', GeminiProvider],
51+
['huggingchat', HuggingChatProvider],
5052
['xai', xAIProvider],
5153
]);
5254

src/ai/huggingchatProvider.ts

Lines changed: 363 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,363 @@
1+
import { fetch } from '@env/fetch';
2+
import type { CancellationToken } from 'vscode';
3+
import { window } from 'vscode';
4+
import type { HuggingChatModels } from '../constants.ai';
5+
import type { TelemetryEvents } from '../constants.telemetry';
6+
import type { Container } from '../container';
7+
import { CancellationError } from '../errors';
8+
import { sum } from '../system/iterable';
9+
import { interpolate } from '../system/string';
10+
import { configuration } from '../system/vscode/configuration';
11+
import type { Storage } from '../system/vscode/storage';
12+
import type { AIModel, AIProvider } from './aiProviderService';
13+
import { getApiKey as getApiKeyCore, getMaxCharacters } from './aiProviderService';
14+
import {
15+
generateCloudPatchMessageSystemPrompt,
16+
generateCloudPatchMessageUserPrompt,
17+
generateCodeSuggestMessageSystemPrompt,
18+
generateCodeSuggestMessageUserPrompt,
19+
generateCommitMessageSystemPrompt,
20+
generateCommitMessageUserPrompt,
21+
} from './prompts';
22+
23+
const provider = { id: 'huggingchat', name: 'Hugging Chat' } as const;
24+
25+
type HuggingChatModel = AIModel<typeof provider.id>;
26+
const models: HuggingChatModel[] = [
27+
{
28+
id: 'google/gemma-1.1-2b-it',
29+
name: 'Google Gemma 1.1 2B',
30+
maxTokens: 4096,
31+
provider: provider,
32+
},
33+
{
34+
id: 'HuggingFaceH4/starchat2-15b-v0.1',
35+
name: 'HuggingFace Starchat 2.1',
36+
maxTokens: 4096,
37+
provider: provider,
38+
},
39+
{
40+
id: 'meta-llama/Llama-3.2-3B-Instruct',
41+
name: 'Meta Llama 3.1 8B',
42+
maxTokens: 4096,
43+
provider: provider,
44+
},
45+
{
46+
id: 'microsoft/Phi-3-mini-4k-instruct',
47+
name: 'Microsoft Phi 3 Mini',
48+
maxTokens: 4096,
49+
provider: provider,
50+
},
51+
{
52+
id: 'mistralai/Mistral-Nemo-Instruct-2407',
53+
name: 'Mistral Nemo Instruct',
54+
maxTokens: 4096,
55+
provider: provider,
56+
},
57+
];
58+
59+
export class HuggingChatProvider implements AIProvider<typeof provider.id> {
60+
readonly id = provider.id;
61+
readonly name = provider.name;
62+
63+
constructor(private readonly container: Container) {}
64+
65+
dispose() {}
66+
67+
getModels(): Promise<readonly AIModel<typeof provider.id>[]> {
68+
return Promise.resolve(models);
69+
}
70+
71+
async generateMessage(
72+
model: HuggingChatModel,
73+
diff: string,
74+
reporting: TelemetryEvents['ai/generate'],
75+
promptConfig: {
76+
type: 'commit' | 'cloud-patch' | 'code-suggestion';
77+
systemPrompt: string;
78+
userPrompt: string;
79+
customInstructions?: string;
80+
},
81+
options?: { cancellation?: CancellationToken; context?: string },
82+
): Promise<string | undefined> {
83+
const apiKey = await getApiKey(this.container.storage);
84+
if (apiKey == null) return undefined;
85+
86+
let retries = 0;
87+
let maxCodeCharacters = getMaxCharacters(model, 2600);
88+
while (true) {
89+
const request: HuggingChatChatCompletionRequest = {
90+
model: model.id,
91+
messages: [
92+
{
93+
role: 'system',
94+
content: promptConfig.systemPrompt,
95+
},
96+
{
97+
role: 'user',
98+
content: interpolate(promptConfig.userPrompt, {
99+
diff: diff.substring(0, maxCodeCharacters),
100+
context: options?.context ?? '',
101+
instructions: promptConfig.customInstructions ?? '',
102+
}),
103+
},
104+
],
105+
};
106+
107+
reporting['retry.count'] = retries;
108+
reporting['input.length'] = (reporting['input.length'] ?? 0) + sum(request.messages, m => m.content.length);
109+
110+
const rsp = await this.fetch(apiKey, request, options?.cancellation);
111+
if (!rsp.ok) {
112+
if (rsp.status === 404) {
113+
throw new Error(
114+
`Unable to generate ${promptConfig.type} message: Your API key doesn't seem to have access to the selected '${model.id}' model`,
115+
);
116+
}
117+
if (rsp.status === 429) {
118+
throw new Error(
119+
`Unable to generate ${promptConfig.type} message: (${this.name}:${rsp.status}) Too many requests (rate limit exceeded) or your API key is associated with an expired trial`,
120+
);
121+
}
122+
123+
let json;
124+
try {
125+
json = (await rsp.json()) as { error?: { code: string; message: string } } | undefined;
126+
} catch {}
127+
128+
debugger;
129+
130+
if (retries++ < 2 && json?.error?.code === 'context_length_exceeded') {
131+
maxCodeCharacters -= 500 * retries;
132+
continue;
133+
}
134+
135+
throw new Error(
136+
`Unable to generate ${promptConfig.type} message: (${this.name}:${rsp.status}) ${
137+
json?.error?.message || rsp.statusText
138+
}`,
139+
);
140+
}
141+
142+
if (diff.length > maxCodeCharacters) {
143+
void window.showWarningMessage(
144+
`The diff of the changes had to be truncated to ${maxCodeCharacters} characters to fit within the Hugging Chat's limits.`,
145+
);
146+
}
147+
148+
const data: HuggingChatChatCompletionResponse = await rsp.json();
149+
const message = data.choices[0].message.content.trim();
150+
return message;
151+
}
152+
}
153+
154+
async generateDraftMessage(
155+
model: HuggingChatModel,
156+
diff: string,
157+
reporting: TelemetryEvents['ai/generate'],
158+
options?: {
159+
cancellation?: CancellationToken;
160+
context?: string;
161+
codeSuggestion?: boolean | undefined;
162+
},
163+
): Promise<string | undefined> {
164+
let codeSuggestion;
165+
if (options != null) {
166+
({ codeSuggestion, ...options } = options ?? {});
167+
}
168+
169+
return this.generateMessage(
170+
model,
171+
diff,
172+
reporting,
173+
codeSuggestion
174+
? {
175+
type: 'code-suggestion',
176+
systemPrompt: generateCodeSuggestMessageSystemPrompt,
177+
userPrompt: generateCodeSuggestMessageUserPrompt,
178+
customInstructions: configuration.get('experimental.generateCodeSuggestionMessagePrompt'),
179+
}
180+
: {
181+
type: 'cloud-patch',
182+
systemPrompt: generateCloudPatchMessageSystemPrompt,
183+
userPrompt: generateCloudPatchMessageUserPrompt,
184+
customInstructions: configuration.get('experimental.generateCloudPatchMessagePrompt'),
185+
},
186+
options,
187+
);
188+
}
189+
190+
async generateCommitMessage(
191+
model: HuggingChatModel,
192+
diff: string,
193+
reporting: TelemetryEvents['ai/generate'],
194+
options?: { cancellation?: CancellationToken; context?: string },
195+
): Promise<string | undefined> {
196+
return this.generateMessage(
197+
model,
198+
diff,
199+
reporting,
200+
{
201+
type: 'commit',
202+
systemPrompt: generateCommitMessageSystemPrompt,
203+
userPrompt: generateCommitMessageUserPrompt,
204+
customInstructions: configuration.get('experimental.generateCommitMessagePrompt'),
205+
},
206+
options,
207+
);
208+
}
209+
210+
async explainChanges(
211+
model: HuggingChatModel,
212+
message: string,
213+
diff: string,
214+
reporting: TelemetryEvents['ai/explain'],
215+
options?: { cancellation?: CancellationToken },
216+
): Promise<string | undefined> {
217+
const apiKey = await getApiKey(this.container.storage);
218+
if (apiKey == null) return undefined;
219+
220+
let retries = 0;
221+
let maxCodeCharacters = getMaxCharacters(model, 3000);
222+
while (true) {
223+
const code = diff.substring(0, maxCodeCharacters);
224+
225+
const request: HuggingChatChatCompletionRequest = {
226+
model: model.id,
227+
messages: [
228+
{
229+
role: 'user',
230+
content: `You are an advanced AI programming assistant tasked with summarizing code changes into an explanation that is both easy to understand and meaningful. Construct an explanation that:
231+
- Concisely synthesizes meaningful information from the provided code diff
232+
- Incorporates any additional context provided by the user to understand the rationale behind the code changes
233+
- Places the emphasis on the 'why' of the change, clarifying its benefits or addressing the problem that necessitated the change, beyond just detailing the 'what' has changed
234+
235+
Do not make any assumptions or invent details that are not supported by the code diff or the user-provided context.
236+
237+
Here is additional context provided by the author of the changes, which should provide some explanation to why these changes where made. Please strongly consider this information when generating your explanation:\n\n${message}
238+
239+
Now, kindly explain the following code diff in a way that would be clear to someone reviewing or trying to understand these changes:\n\n${code}
240+
241+
Remember to frame your explanation in a way that is suitable for a reviewer to quickly grasp the essence of the changes, the issues they resolve, and their implications on the codebase.`,
242+
},
243+
],
244+
};
245+
246+
reporting['retry.count'] = retries;
247+
reporting['input.length'] = (reporting['input.length'] ?? 0) + sum(request.messages, m => m.content.length);
248+
249+
const rsp = await this.fetch(apiKey, request, options?.cancellation);
250+
if (!rsp.ok) {
251+
if (rsp.status === 404) {
252+
throw new Error(
253+
`Unable to explain changes: Your API key doesn't seem to have access to the selected '${model.id}' model`,
254+
);
255+
}
256+
if (rsp.status === 429) {
257+
throw new Error(
258+
`Unable to explain changes: (${this.name}:${rsp.status}) Too many requests (rate limit exceeded) or your API key is associated with an expired trial`,
259+
);
260+
}
261+
262+
let json;
263+
try {
264+
json = (await rsp.json()) as { error?: { code: string; message: string } } | undefined;
265+
} catch {}
266+
267+
debugger;
268+
269+
if (retries++ < 2 && json?.error?.code === 'context_length_exceeded') {
270+
maxCodeCharacters -= 500 * retries;
271+
continue;
272+
}
273+
274+
throw new Error(
275+
`Unable to explain changes: (${this.name}:${rsp.status}) ${json?.error?.message || rsp.statusText}`,
276+
);
277+
}
278+
279+
if (diff.length > maxCodeCharacters) {
280+
void window.showWarningMessage(
281+
`The diff of the changes had to be truncated to ${maxCodeCharacters} characters to fit within the Hugging Chat's limits.`,
282+
);
283+
}
284+
285+
const data: HuggingChatChatCompletionResponse = await rsp.json();
286+
const summary = data.choices[0].message.content.trim();
287+
return summary;
288+
}
289+
}
290+
291+
private async fetch(
292+
apiKey: string,
293+
request: HuggingChatChatCompletionRequest,
294+
cancellation: CancellationToken | undefined,
295+
) {
296+
let aborter: AbortController | undefined;
297+
if (cancellation != null) {
298+
aborter = new AbortController();
299+
cancellation.onCancellationRequested(() => aborter?.abort());
300+
}
301+
302+
try {
303+
return await fetch(`https://api-inference.huggingface.co/models/${request.model}/v1/chat/completions`, {
304+
headers: {
305+
Accept: 'application/json',
306+
Authorization: `Bearer ${apiKey}`,
307+
'Content-Type': 'application/json',
308+
},
309+
method: 'POST',
310+
body: JSON.stringify(request),
311+
signal: aborter?.signal,
312+
});
313+
} catch (ex) {
314+
if (ex.name === 'AbortError') throw new CancellationError(ex);
315+
316+
throw ex;
317+
}
318+
}
319+
}
320+
321+
async function getApiKey(storage: Storage): Promise<string | undefined> {
322+
return getApiKeyCore(storage, {
323+
id: provider.id,
324+
name: provider.name,
325+
validator: v => /(?:sk-)?[a-zA-Z0-9]{32,}/.test(v),
326+
url: 'https://huggingface.co/settings/tokens',
327+
});
328+
}
329+
330+
interface HuggingChatChatCompletionRequest {
331+
model: HuggingChatModels;
332+
messages: { role: 'system' | 'user' | 'assistant'; content: string }[];
333+
temperature?: number;
334+
top_p?: number;
335+
n?: number;
336+
stream?: boolean;
337+
stop?: string | string[];
338+
max_tokens?: number;
339+
presence_penalty?: number;
340+
frequency_penalty?: number;
341+
logit_bias?: Record<string, number>;
342+
user?: string;
343+
}
344+
345+
interface HuggingChatChatCompletionResponse {
346+
id: string;
347+
object: 'chat.completion';
348+
created: number;
349+
model: string;
350+
choices: {
351+
index: number;
352+
message: {
353+
role: 'system' | 'user' | 'assistant';
354+
content: string;
355+
};
356+
finish_reason: string;
357+
}[];
358+
usage: {
359+
prompt_tokens: number;
360+
completion_tokens: number;
361+
total_tokens: number;
362+
};
363+
}

0 commit comments

Comments
 (0)