@@ -12,9 +12,7 @@ use std::time::Duration;
 
 use amzn_codewhisperer_client::Client as CodewhispererClient;
 use amzn_codewhisperer_client::operation::create_subscription_token::CreateSubscriptionTokenOutput;
-use amzn_codewhisperer_client::types::Origin::Cli;
 use amzn_codewhisperer_client::types::{
-    Model,
     OptOutPreference,
     SubscriptionStatus,
     TelemetryEvent,
@@ -34,7 +32,6 @@ pub use error::ApiClientError;
 use parking_lot::Mutex;
 pub use profile::list_available_profiles;
 use serde_json::Map;
-use tokio::sync::RwLock;
 use tracing::{
     debug,
     error,
@@ -69,28 +66,13 @@ pub const X_AMZN_CODEWHISPERER_OPT_OUT_HEADER: &str = "x-amzn-codewhisperer-opto
 // TODO(bskiser): confirm timeout is updated to an appropriate value?
 const DEFAULT_TIMEOUT_DURATION: Duration = Duration::from_secs(60 * 5);
 
-#[derive(Clone, Debug)]
-pub struct ModelListResult {
-    pub models: Vec<Model>,
-    pub default_model: Model,
-}
-
-impl From<ModelListResult> for (Vec<Model>, Model) {
-    fn from(v: ModelListResult) -> Self {
-        (v.models, v.default_model)
-    }
-}
-
-type ModelCache = Arc<RwLock<Option<ModelListResult>>>;
-
 #[derive(Clone, Debug)]
 pub struct ApiClient {
     client: CodewhispererClient,
     streaming_client: Option<CodewhispererStreamingClient>,
     sigv4_streaming_client: Option<QDeveloperStreamingClient>,
     mock_client: Option<Arc<Mutex<std::vec::IntoIter<Vec<ChatResponseStream>>>>>,
     profile: Option<AuthProfile>,
-    model_cache: ModelCache,
 }
 
 impl ApiClient {
@@ -130,7 +112,6 @@ impl ApiClient {
             sigv4_streaming_client: None,
             mock_client: None,
             profile: None,
-            model_cache: Arc::new(RwLock::new(None)),
         };
 
         if let Ok(json) = env.get("Q_MOCK_CHAT_RESPONSE") {
@@ -200,7 +181,6 @@ impl ApiClient {
             sigv4_streaming_client,
             mock_client: None,
             profile,
-            model_cache: Arc::new(RwLock::new(None)),
         })
     }
 
@@ -254,82 +234,6 @@ impl ApiClient {
         Ok(profiles)
     }
 
-    pub async fn list_available_models(&self) -> Result<ModelListResult, ApiClientError> {
-        if cfg!(test) {
-            let m = Model::builder()
-                .model_id("model-1")
-                .description("Test Model 1")
-                .build()
-                .unwrap();
-
-            return Ok(ModelListResult {
-                models: vec![m.clone()],
-                default_model: m,
-            });
-        }
-
-        let mut models = Vec::new();
-        let mut default_model = None;
-        let request = self
-            .client
-            .list_available_models()
-            .set_origin(Some(Cli))
-            .set_profile_arn(self.profile.as_ref().map(|p| p.arn.clone()));
-        let mut paginator = request.into_paginator().send();
-
-        while let Some(result) = paginator.next().await {
-            let models_output = result?;
-            models.extend(models_output.models().iter().cloned());
-
-            if default_model.is_none() {
-                default_model = Some(models_output.default_model().clone());
-            }
-        }
-        let default_model = default_model.ok_or_else(|| ApiClientError::DefaultModelNotFound)?;
-        Ok(ModelListResult { models, default_model })
-    }
-
-    pub async fn list_available_models_cached(&self) -> Result<ModelListResult, ApiClientError> {
-        {
-            let cache = self.model_cache.read().await;
-            if let Some(cached) = cache.as_ref() {
-                tracing::debug!("Returning cached model list");
-                return Ok(cached.clone());
-            }
-        }
-
-        tracing::debug!("Cache miss, fetching models from list_available_models API");
-        let result = self.list_available_models().await?;
-        {
-            let mut cache = self.model_cache.write().await;
-            *cache = Some(result.clone());
-        }
-        Ok(result)
-    }
-
-    pub async fn invalidate_model_cache(&self) {
-        let mut cache = self.model_cache.write().await;
-        *cache = None;
-        tracing::info!("Model cache invalidated");
-    }
-
-    pub async fn get_available_models(&self, _region: &str) -> Result<ModelListResult, ApiClientError> {
-        let res = self.list_available_models_cached().await?;
-        // TODO: Once we have access to gpt-oss, add back.
-        // if region == "us-east-1" {
-        //     let gpt_oss = Model::builder()
-        //         .model_id("OPENAI_GPT_OSS_120B_1_0")
-        //         .model_name("openai-gpt-oss-120b-preview")
-        //         .token_limits(TokenLimits::builder().max_input_tokens(128_000).build())
-        //         .build()
-        //         .map_err(ApiClientError::from)?;
-
-        //     models.push(gpt_oss);
-        // }
-
-        Ok(res)
-    }
-
     pub async fn create_subscription_token(&self) -> Result<CreateSubscriptionTokenOutput, ApiClientError> {
         if cfg!(test) {
             return Ok(CreateSubscriptionTokenOutput::builder()