1+ use amzn_codewhisperer_client:: types:: Model ;
12use clap:: Args ;
23use crossterm:: style:: {
34 self ,
@@ -20,35 +21,6 @@ use crate::cli::chat::{
2021 ChatState ,
2122} ;
2223use crate :: os:: Os ;
23-
24- pub struct ModelOption {
25- /// Display name
26- pub name : & ' static str ,
27- /// Actual model id to send in the API
28- pub model_id : & ' static str ,
29- /// Size of the model's context window, in tokens
30- pub context_window_tokens : usize ,
31- }
32-
33- const MODEL_OPTIONS : [ ModelOption ; 2 ] = [
34- ModelOption {
35- name : "claude-4-sonnet" ,
36- model_id : "CLAUDE_SONNET_4_20250514_V1_0" ,
37- context_window_tokens : 200_000 ,
38- } ,
39- ModelOption {
40- name : "claude-3.7-sonnet" ,
41- model_id : "CLAUDE_3_7_SONNET_20250219_V1_0" ,
42- context_window_tokens : 200_000 ,
43- } ,
44- ] ;
45-
46- const GPT_OSS_120B : ModelOption = ModelOption {
47- name : "openai-gpt-oss-120b-preview" ,
48- model_id : "OPENAI_GPT_OSS_120B_1_0" ,
49- context_window_tokens : 128_000 ,
50- } ;
51-
5224#[ deny( missing_docs) ]
5325#[ derive( Debug , PartialEq , Args ) ]
5426pub struct ModelArgs ;
@@ -65,11 +37,7 @@ pub async fn select_model(os: &Os, session: &mut ChatSession) -> Result<Option<C
6537 queue ! ( session. stderr, style:: Print ( "\n " ) ) ?;
6638
6739 // Fetch available models from service
68- let ( models, _default_model) = os
69- . client
70- . list_available_models_cached ( )
71- . await
72- . map_err ( |e| ChatError :: Custom ( format ! ( "Failed to fetch available models: {}" , e) . into ( ) ) ) ?;
40+ let ( models, _default_model) = get_available_models ( os) . await ?;
7341
7442 if models. is_empty ( ) {
7543 queue ! (
@@ -82,15 +50,16 @@ pub async fn select_model(os: &Os, session: &mut ChatSession) -> Result<Option<C
8250 }
8351
8452 let active_model_id = session. conversation . model . as_deref ( ) ;
85- let model_options = get_model_options ( os) . await ?;
8653
87- let labels: Vec < String > = model_options
54+ let labels: Vec < String > = models
8855 . iter ( )
8956 . map ( |model| {
57+ let display_name = model. model_name ( ) . unwrap_or ( model. model_id ( ) ) ;
58+
9059 if Some ( model. model_id ( ) ) == active_model_id {
91- format ! ( "{} (active)" , model . model_id ( ) )
60+ format ! ( "{} (active)" , display_name )
9261 } else {
93- model . model_id ( ) . to_owned ( )
62+ display_name . to_owned ( )
9463 }
9564 } )
9665 . collect ( ) ;
@@ -119,11 +88,12 @@ pub async fn select_model(os: &Os, session: &mut ChatSession) -> Result<Option<C
11988 let selected = & models[ index] ;
12089 let model_id_str = selected. model_id . clone ( ) ;
12190 session. conversation . model = Some ( model_id_str. clone ( ) ) ;
91+ let display_name = selected. model_name ( ) . unwrap_or ( selected. model_id ( ) ) ;
12292
12393 queue ! (
12494 session. stderr,
12595 style:: Print ( "\n " ) ,
126- style:: Print ( format!( " Using {}\n \n " , model_id_str ) ) ,
96+ style:: Print ( format!( " Using {}\n \n " , display_name ) ) ,
12797 style:: ResetColor ,
12898 style:: SetForegroundColor ( Color :: Reset ) ,
12999 style:: SetBackgroundColor ( Color :: Reset ) ,
@@ -160,60 +130,41 @@ pub async fn default_model_id(os: &Os) -> String {
160130 "claude-4-sonnet" . to_string ( )
161131}
162132
163- /// Returns the available models for use.
164- pub async fn get_model_options ( os : & Os ) -> Result < Vec < ModelOption > , ChatError > {
165- let mut model_options = MODEL_OPTIONS . into_iter ( ) . collect :: < Vec < _ > > ( ) ;
166-
167- // GPT OSS is only accessible in IAD.
133+ /// Get available models with caching support
134+ pub async fn get_available_models ( os : & Os ) -> Result < ( Vec < Model > , Option < Model > ) , ChatError > {
168135 let endpoint = Endpoint :: configured_value ( & os. database ) ;
169- if endpoint. region ( ) . as_ref ( ) != "us-east-1" {
170- return Ok ( model_options) ;
171- }
136+ let region = endpoint. region ( ) . as_ref ( ) ;
172137
173- model_options. push ( GPT_OSS_120B ) ;
174- Ok ( model_options)
138+ os. client
139+ . get_available_models ( region)
140+ . await
141+ . map_err ( |e| ChatError :: Custom ( format ! ( "Failed to fetch available models: {}" , e) . into ( ) ) )
175142}
176143
177144/// Returns the context window length in tokens for the given model_id.
178- pub fn context_window_tokens ( model_id : Option < & str > ) -> usize {
145+ /// Uses cached model data when available
146+ pub async fn context_window_tokens ( model_id : Option < & str > , os : & Os ) -> usize {
179147 const DEFAULT_CONTEXT_WINDOW_LENGTH : usize = 200_000 ;
180148
149+ // If no model_id provided, return default
181150 let Some ( model_id) = model_id else {
182151 return DEFAULT_CONTEXT_WINDOW_LENGTH ;
183152 } ;
184153
185- MODEL_OPTIONS
186- . iter ( )
187- . chain ( std:: iter:: once ( & GPT_OSS_120B ) )
188- . find ( |m| m. model_id == model_id)
189- . map_or ( DEFAULT_CONTEXT_WINDOW_LENGTH , |m| m. context_window_tokens )
190- }
191-
192- /// Returns the available models for use.
193- pub async fn get_model_options ( os : & Os ) -> Result < Vec < ModelOption > , ChatError > {
194- let mut model_options = MODEL_OPTIONS . into_iter ( ) . collect :: < Vec < _ > > ( ) ;
195-
196- // GPT OSS is only accessible in IAD.
197- let endpoint = Endpoint :: configured_value ( & os. database ) ;
198- if endpoint. region ( ) . as_ref ( ) != "us-east-1" {
199- return Ok ( model_options) ;
200- }
201-
202- model_options. push ( GPT_OSS_120B ) ;
203- Ok ( model_options)
204- }
205-
206- /// Returns the context window length in tokens for the given model_id.
207- pub fn context_window_tokens ( model_id : Option < & str > ) -> usize {
208- const DEFAULT_CONTEXT_WINDOW_LENGTH : usize = 200_000 ;
209-
210- let Some ( model_id) = model_id else {
211- return DEFAULT_CONTEXT_WINDOW_LENGTH ;
154+ // Try to get from cached models first
155+ let ( models, _) = match get_available_models ( os) . await {
156+ Ok ( models) => models,
157+ Err ( _) => {
158+ // If we can't get models, return default
159+ return DEFAULT_CONTEXT_WINDOW_LENGTH ;
160+ } ,
212161 } ;
213162
214- MODEL_OPTIONS
163+ models
215164 . iter ( )
216- . chain ( std:: iter:: once ( & GPT_OSS_120B ) )
217- . find ( |m| m. model_id == model_id)
218- . map_or ( DEFAULT_CONTEXT_WINDOW_LENGTH , |m| m. context_window_tokens )
165+ . find ( |m| m. model_id ( ) == model_id)
166+ . and_then ( |m| m. token_limits ( ) )
167+ . and_then ( |limits| limits. max_input_tokens ( ) )
168+ . map ( |tokens| tokens as usize )
169+ . unwrap_or ( DEFAULT_CONTEXT_WINDOW_LENGTH )
219170}
0 commit comments