@@ -12,7 +12,9 @@ use std::time::Duration;
 
 use amzn_codewhisperer_client::Client as CodewhispererClient;
 use amzn_codewhisperer_client::operation::create_subscription_token::CreateSubscriptionTokenOutput;
+use amzn_codewhisperer_client::types::Origin::Cli;
 use amzn_codewhisperer_client::types::{
+    Model,
     OptOutPreference,
     SubscriptionStatus,
     TelemetryEvent,
@@ -32,6 +34,7 @@ pub use error::ApiClientError;
 use parking_lot::Mutex;
 pub use profile::list_available_profiles;
 use serde_json::Map;
+use tokio::sync::RwLock;
 use tracing::{
     debug,
     error,
@@ -66,13 +69,28 @@ pub const X_AMZN_CODEWHISPERER_OPT_OUT_HEADER: &str = "x-amzn-codewhisperer-opto
 // TODO(bskiser): confirm timeout is updated to an appropriate value?
 const DEFAULT_TIMEOUT_DURATION: Duration = Duration::from_secs(60 * 5);
 
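+/// Result of listing the available models: every model returned by the service plus the
+/// service-designated default model.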
+#[derive(Clone, Debug)]
+pub struct ModelListResult {
+    pub models: Vec<Model>,
+    pub default_model: Model,
+}
+
+impl From<ModelListResult> for (Vec<Model>, Model) {
+    fn from(v: ModelListResult) -> Self {
+        (v.models, v.default_model)
+    }
+}
+
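+/// Shared, lazily populated cache for the model list. The `Arc` lets every clone of
+/// `ApiClient` see the same cache; the `tokio::sync::RwLock` allows concurrent reads.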
+type ModelCache = Arc<RwLock<Option<ModelListResult>>>;
+
 #[derive(Clone, Debug)]
 pub struct ApiClient {
     client: CodewhispererClient,
     streaming_client: Option<CodewhispererStreamingClient>,
     sigv4_streaming_client: Option<QDeveloperStreamingClient>,
     mock_client: Option<Arc<Mutex<std::vec::IntoIter<Vec<ChatResponseStream>>>>>,
     profile: Option<AuthProfile>,
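+    /// Cached result of `list_available_models`, populated on first use.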
+    model_cache: ModelCache,
 }
 
 impl ApiClient {
@@ -112,6 +130,7 @@ impl ApiClient {
             sigv4_streaming_client: None,
             mock_client: None,
             profile: None,
+            model_cache: Arc::new(RwLock::new(None)),
         };
 
         if let Ok(json) = env.get("Q_MOCK_CHAT_RESPONSE") {
@@ -181,6 +200,7 @@ impl ApiClient {
             sigv4_streaming_client,
             mock_client: None,
             profile,
+            model_cache: Arc::new(RwLock::new(None)),
         })
     }
 
@@ -234,6 +254,82 @@ impl ApiClient {
         Ok(profiles)
     }
 
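+    /// Fetches the full model list from the `ListAvailableModels` API, following pagination
+    /// and capturing the service-reported default model. Under `cfg!(test)` it returns a
+    /// single stub model instead of calling the service.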
+    pub async fn list_available_models(&self) -> Result<ModelListResult, ApiClientError> {
+        if cfg!(test) {
+            let m = Model::builder()
+                .model_id("model-1")
+                .description("Test Model 1")
+                .build()
+                .unwrap();
+
+            return Ok(ModelListResult {
+                models: vec![m.clone()],
+                default_model: m,
+            });
+        }
+
+        let mut models = Vec::new();
+        let mut default_model = None;
+        let request = self
+            .client
+            .list_available_models()
+            .set_origin(Some(Cli))
+            .set_profile_arn(self.profile.as_ref().map(|p| p.arn.clone()));
+        let mut paginator = request.into_paginator().send();
+
+        while let Some(result) = paginator.next().await {
+            let models_output = result?;
+            models.extend(models_output.models().iter().cloned());
+
+            if default_model.is_none() {
+                default_model = Some(models_output.default_model().clone());
+            }
+        }
+        let default_model = default_model.ok_or_else(|| ApiClientError::DefaultModelNotFound)?;
+        Ok(ModelListResult { models, default_model })
+    }
+
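+    /// Returns the cached model list if one is present, otherwise calls
+    /// `list_available_models` and stores the result for subsequent callers.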
+    pub async fn list_available_models_cached(&self) -> Result<ModelListResult, ApiClientError> {
+        {
+            let cache = self.model_cache.read().await;
+            if let Some(cached) = cache.as_ref() {
+                tracing::debug!("Returning cached model list");
+                return Ok(cached.clone());
+            }
+        }
+
+        tracing::debug!("Cache miss, fetching models from list_available_models API");
+        let result = self.list_available_models().await?;
+        {
+            let mut cache = self.model_cache.write().await;
+            *cache = Some(result.clone());
+        }
+        Ok(result)
+    }
+
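+    /// Clears the cached model list so the next call to `list_available_models_cached`
+    /// re-fetches from the service.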
+    pub async fn invalidate_model_cache(&self) {
+        let mut cache = self.model_cache.write().await;
+        *cache = None;
+        tracing::info!("Model cache invalidated");
+    }
+
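+    /// Wrapper around `list_available_models_cached`. The region parameter is currently
+    /// unused; it is kept for the region-specific model additions noted in the TODO below.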
+    pub async fn get_available_models(&self, _region: &str) -> Result<ModelListResult, ApiClientError> {
+        let res = self.list_available_models_cached().await?;
+        // TODO: Once we have access to gpt-oss, add back.
+        // if region == "us-east-1" {
+        //     let gpt_oss = Model::builder()
+        //         .model_id("OPENAI_GPT_OSS_120B_1_0")
+        //         .model_name("openai-gpt-oss-120b-preview")
+        //         .token_limits(TokenLimits::builder().max_input_tokens(128_000).build())
+        //         .build()
+        //         .map_err(ApiClientError::from)?;
+
+        //     models.push(gpt_oss);
+        // }
+
+        Ok(res)
+    }
+
     pub async fn create_subscription_token(&self) -> Result<CreateSubscriptionTokenOutput, ApiClientError> {
         if cfg!(test) {
             return Ok(CreateSubscriptionTokenOutput::builder()