@@ -12,7 +12,9 @@ use std::time::Duration;
12
12
13
13
use amzn_codewhisperer_client:: Client as CodewhispererClient ;
14
14
use amzn_codewhisperer_client:: operation:: create_subscription_token:: CreateSubscriptionTokenOutput ;
15
+ use amzn_codewhisperer_client:: types:: Origin :: Cli ;
15
16
use amzn_codewhisperer_client:: types:: {
17
+ Model ,
16
18
OptOutPreference ,
17
19
SubscriptionStatus ,
18
20
TelemetryEvent ,
@@ -32,6 +34,7 @@ pub use error::ApiClientError;
32
34
use parking_lot:: Mutex ;
33
35
pub use profile:: list_available_profiles;
34
36
use serde_json:: Map ;
37
+ use tokio:: sync:: RwLock ;
35
38
use tracing:: {
36
39
debug,
37
40
error,
@@ -66,13 +69,28 @@ pub const X_AMZN_CODEWHISPERER_OPT_OUT_HEADER: &str = "x-amzn-codewhisperer-opto
66
69
// TODO(bskiser): confirm timeout is updated to an appropriate value?
67
70
const DEFAULT_TIMEOUT_DURATION : Duration = Duration :: from_secs ( 60 * 5 ) ;
68
71
72
/// Result of listing the models available to the caller: the full model list
/// plus the service-designated default model.
#[derive(Clone, Debug)]
pub struct ModelListResult {
    /// All models returned by the `ListAvailableModels` API (across all pages).
    pub models: Vec<Model>,
    /// The default model, taken from the first page of results.
    pub default_model: Model,
}
77
+
78
/// Convenience conversion so callers can destructure a [`ModelListResult`]
/// into a `(models, default_model)` tuple.
impl From<ModelListResult> for (Vec<Model>, Model) {
    fn from(v: ModelListResult) -> Self {
        (v.models, v.default_model)
    }
}
83
+
84
+ type ModelCache = Arc < RwLock < Option < ModelListResult > > > ;
85
+
69
86
/// Client wrapper around the CodeWhisperer service clients.
#[derive(Clone, Debug)]
pub struct ApiClient {
    // Generated non-streaming CodeWhisperer service client.
    client: CodewhispererClient,
    // Streaming client; presumably set per auth mode — TODO confirm against constructors.
    streaming_client: Option<CodewhispererStreamingClient>,
    // SigV4-authenticated streaming client variant — NOTE(review): only one of the
    // two streaming clients appears populated at a time; verify in the constructors.
    sigv4_streaming_client: Option<QDeveloperStreamingClient>,
    // Canned chat responses, populated when the Q_MOCK_CHAT_RESPONSE env var is set.
    mock_client: Option<Arc<Mutex<std::vec::IntoIter<Vec<ChatResponseStream>>>>>,
    // Optional auth profile; its ARN is forwarded on list-models requests.
    profile: Option<AuthProfile>,
    // Lazily-populated cache used by `list_available_models_cached`.
    model_cache: ModelCache,
}
77
95
78
96
impl ApiClient {
@@ -112,6 +130,7 @@ impl ApiClient {
112
130
sigv4_streaming_client : None ,
113
131
mock_client : None ,
114
132
profile : None ,
133
+ model_cache : Arc :: new ( RwLock :: new ( None ) ) ,
115
134
} ;
116
135
117
136
if let Ok ( json) = env. get ( "Q_MOCK_CHAT_RESPONSE" ) {
@@ -181,6 +200,7 @@ impl ApiClient {
181
200
sigv4_streaming_client,
182
201
mock_client : None ,
183
202
profile,
203
+ model_cache : Arc :: new ( RwLock :: new ( None ) ) ,
184
204
} )
185
205
}
186
206
@@ -234,6 +254,82 @@ impl ApiClient {
234
254
Ok ( profiles)
235
255
}
236
256
257
+ pub async fn list_available_models ( & self ) -> Result < ModelListResult , ApiClientError > {
258
+ if cfg ! ( test) {
259
+ let m = Model :: builder ( )
260
+ . model_id ( "model-1" )
261
+ . description ( "Test Model 1" )
262
+ . build ( )
263
+ . unwrap ( ) ;
264
+
265
+ return Ok ( ModelListResult {
266
+ models : vec ! [ m. clone( ) ] ,
267
+ default_model : m,
268
+ } ) ;
269
+ }
270
+
271
+ let mut models = Vec :: new ( ) ;
272
+ let mut default_model = None ;
273
+ let request = self
274
+ . client
275
+ . list_available_models ( )
276
+ . set_origin ( Some ( Cli ) )
277
+ . set_profile_arn ( self . profile . as_ref ( ) . map ( |p| p. arn . clone ( ) ) ) ;
278
+ let mut paginator = request. into_paginator ( ) . send ( ) ;
279
+
280
+ while let Some ( result) = paginator. next ( ) . await {
281
+ let models_output = result?;
282
+ models. extend ( models_output. models ( ) . iter ( ) . cloned ( ) ) ;
283
+
284
+ if default_model. is_none ( ) {
285
+ default_model = Some ( models_output. default_model ( ) . clone ( ) ) ;
286
+ }
287
+ }
288
+ let default_model = default_model. ok_or_else ( || ApiClientError :: DefaultModelNotFound ) ?;
289
+ Ok ( ModelListResult { models, default_model } )
290
+ }
291
+
292
+ pub async fn list_available_models_cached ( & self ) -> Result < ModelListResult , ApiClientError > {
293
+ {
294
+ let cache = self . model_cache . read ( ) . await ;
295
+ if let Some ( cached) = cache. as_ref ( ) {
296
+ tracing:: debug!( "Returning cached model list" ) ;
297
+ return Ok ( cached. clone ( ) ) ;
298
+ }
299
+ }
300
+
301
+ tracing:: debug!( "Cache miss, fetching models from list_available_models API" ) ;
302
+ let result = self . list_available_models ( ) . await ?;
303
+ {
304
+ let mut cache = self . model_cache . write ( ) . await ;
305
+ * cache = Some ( result. clone ( ) ) ;
306
+ }
307
+ Ok ( result)
308
+ }
309
+
310
+ pub async fn invalidate_model_cache ( & self ) {
311
+ let mut cache = self . model_cache . write ( ) . await ;
312
+ * cache = None ;
313
+ tracing:: info!( "Model cache invalidated" ) ;
314
+ }
315
+
316
+ pub async fn get_available_models ( & self , _region : & str ) -> Result < ModelListResult , ApiClientError > {
317
+ let res = self . list_available_models_cached ( ) . await ?;
318
+ // TODO: Once we have access to gpt-oss, add back.
319
+ // if region == "us-east-1" {
320
+ // let gpt_oss = Model::builder()
321
+ // .model_id("OPENAI_GPT_OSS_120B_1_0")
322
+ // .model_name("openai-gpt-oss-120b-preview")
323
+ // .token_limits(TokenLimits::builder().max_input_tokens(128_000).build())
324
+ // .build()
325
+ // .map_err(ApiClientError::from)?;
326
+
327
+ // models.push(gpt_oss);
328
+ // }
329
+
330
+ Ok ( res)
331
+ }
332
+
237
333
pub async fn create_subscription_token ( & self ) -> Result < CreateSubscriptionTokenOutput , ApiClientError > {
238
334
if cfg ! ( test) {
239
335
return Ok ( CreateSubscriptionTokenOutput :: builder ( )
0 commit comments