11import type { CancellationToken , Disposable , MessageItem , ProgressOptions , QuickInputButton } from 'vscode' ;
22import { env , ThemeIcon , Uri , window } from 'vscode' ;
3- import type { AIModels , AIProviders , SupportedAIModels , VSCodeAIModels } from '../constants.ai' ;
3+ import type { AIProviders , SupportedAIModels , VSCodeAIModels } from '../constants.ai' ;
44import type { AIGenerateDraftEventData , Sources , TelemetryEvents } from '../constants.telemetry' ;
55import type { Container } from '../container' ;
66import { CancellationError } from '../errors' ;
@@ -33,10 +33,7 @@ export interface AIResult {
3333 body : string ;
3434}
3535
36- export interface AIModel <
37- Provider extends AIProviders = AIProviders ,
38- Model extends AIModels < Provider > = AIModels < Provider > ,
39- > {
36+ export interface AIModel < Provider extends AIProviders = AIProviders , Model extends string = string > {
4037 readonly id : Model ;
4138 readonly name : string ;
4239 readonly maxTokens : { input : number ; output : number } ;
@@ -55,38 +52,39 @@ interface AIProviderConstructor<Provider extends AIProviders = AIProviders> {
5552 new ( container : Container ) : AIProvider < Provider > ;
5653}
5754
55+ // Order matters for sorting the picker
5856const _supportedProviderTypes = new Map < AIProviders , AIProviderConstructor > ( [
5957 ...( supportedInVSCodeVersion ( 'language-models' ) ? [ [ 'vscode' , VSCodeAIProvider ] ] : ( [ ] as any ) ) ,
6058 [ 'openai' , OpenAIProvider ] ,
6159 [ 'anthropic' , AnthropicProvider ] ,
62- [ 'deepseek' , DeepSeekProvider ] ,
6360 [ 'gemini' , GeminiProvider ] ,
61+ [ 'deepseek' , DeepSeekProvider ] ,
62+ [ 'xai' , xAIProvider ] ,
6463 [ 'github' , GitHubModelsProvider ] ,
6564 [ 'huggingface' , HuggingFaceProvider ] ,
66- [ 'xai' , xAIProvider ] ,
6765] ) ;
6866
6967export interface AIProvider < Provider extends AIProviders = AIProviders > extends Disposable {
7068 readonly id : Provider ;
7169 readonly name : string ;
7270
73- getModels ( ) : Promise < readonly AIModel < Provider , AIModels < Provider > > [ ] > ;
71+ getModels ( ) : Promise < readonly AIModel < Provider > [ ] > ;
7472
7573 explainChanges (
76- model : AIModel < Provider , AIModels < Provider > > ,
74+ model : AIModel < Provider > ,
7775 message : string ,
7876 diff : string ,
7977 reporting : TelemetryEvents [ 'ai/explain' ] ,
8078 options ?: { cancellation ?: CancellationToken } ,
8179 ) : Promise < string | undefined > ;
8280 generateCommitMessage (
83- model : AIModel < Provider , AIModels < Provider > > ,
81+ model : AIModel < Provider > ,
8482 diff : string ,
8583 reporting : TelemetryEvents [ 'ai/generate' ] ,
8684 options ?: { cancellation ?: CancellationToken ; context ?: string } ,
8785 ) : Promise < string | undefined > ;
8886 generateDraftMessage (
89- model : AIModel < Provider , AIModels < Provider > > ,
87+ model : AIModel < Provider > ,
9088 diff : string ,
9189 reporting : TelemetryEvents [ 'ai/generate' ] ,
9290 options ?: { cancellation ?: CancellationToken ; context ?: string ; codeSuggestion ?: boolean } ,
@@ -107,10 +105,10 @@ export class AIProviderService implements Disposable {
107105 return this . _provider ?. id ;
108106 }
109107
110- private getConfiguredModel ( ) : { provider : AIProviders ; model : AIModels } | undefined {
108+ private getConfiguredModel ( ) : { provider : AIProviders ; model : string } | undefined {
111109 const qualifiedModelId = configuration . get ( 'ai.model' ) ?? undefined ;
112110 if ( qualifiedModelId != null ) {
113- let [ providerId , modelId ] = qualifiedModelId . split ( ':' ) as [ AIProviders , AIModels ] ;
111+ let [ providerId , modelId ] = qualifiedModelId . split ( ':' ) as [ AIProviders , string ] ;
114112 if ( providerId != null && this . supports ( providerId ) ) {
115113 if ( modelId != null ) {
116114 return { provider : providerId , model : modelId } ;
@@ -150,10 +148,10 @@ export class AIProviderService implements Disposable {
150148 }
151149
152150 private getOrUpdateModel ( model : AIModel ) : Promise < AIModel | undefined > ;
153- private getOrUpdateModel < T extends AIProviders > ( providerId : T , modelId : AIModels < T > ) : Promise < AIModel | undefined > ;
151+ private getOrUpdateModel < T extends AIProviders > ( providerId : T , modelId : string ) : Promise < AIModel | undefined > ;
154152 private async getOrUpdateModel (
155153 modelOrProviderId : AIModel | AIProviders ,
156- modelId ?: AIModels ,
154+ modelId ?: string ,
157155 ) : Promise < AIModel | undefined > {
158156 let providerId : AIProviders ;
159157 let model : AIModel | undefined ;
@@ -552,7 +550,7 @@ export class AIProviderService implements Disposable {
552550
553551async function confirmAIProviderToS < Provider extends AIProviders > (
554552 service : AIProviderService ,
555- model : AIModel < Provider , AIModels < Provider > > ,
553+ model : AIModel < Provider > ,
556554 storage : Storage ,
557555) : Promise < boolean > {
558556 const confirmed =
@@ -596,9 +594,9 @@ async function confirmAIProviderToS<Provider extends AIProviders>(
596594 return false ;
597595}
598596
599- export function getMaxCharacters ( model : AIModel , outputLength : number ) : number {
597+ export function getMaxCharacters ( model : AIModel , outputLength : number , overrideInputTokens ?: number ) : number {
600598 const tokensPerCharacter = 3.1 ;
601- const max = model . maxTokens . input * tokensPerCharacter - outputLength / tokensPerCharacter ;
599+ const max = ( overrideInputTokens ?? model . maxTokens . input ) * tokensPerCharacter - outputLength / tokensPerCharacter ;
602600 return Math . floor ( max - max * 0.1 ) ;
603601}
604602
0 commit comments