@@ -68,7 +68,18 @@ pub const X_AMZN_CODEWHISPERER_OPT_OUT_HEADER: &str = "x-amzn-codewhisperer-opto
6868// TODO(bskiser): confirm timeout is updated to an appropriate value?
6969const DEFAULT_TIMEOUT_DURATION : Duration = Duration :: from_secs ( 60 * 5 ) ;
7070
71- type ModelListResult = ( Vec < Model > , Model ) ;
71+ #[ derive( Clone , Debug ) ]
72+ pub struct ModelListResult {
73+ pub models : Vec < Model > ,
74+ pub default_model : Model ,
75+ }
76+
77+ impl From < ModelListResult > for ( Vec < Model > , Model ) {
78+ fn from ( v : ModelListResult ) -> Self {
79+ ( v. models , v. default_model )
80+ }
81+ }
82+
7283type ModelCache = Arc < RwLock < Option < ModelListResult > > > ;
7384
7485#[ derive( Clone , Debug ) ]
@@ -240,22 +251,18 @@ impl ApiClient {
240251 Ok ( profiles)
241252 }
242253
243- pub async fn list_available_models ( & self ) -> Result < ( Vec < Model > , Model ) , ApiClientError > {
254+ pub async fn list_available_models ( & self ) -> Result < ModelListResult , ApiClientError > {
244255 if cfg ! ( test) {
245- return Ok ( (
246- vec ! [
247- Model :: builder( )
248- . model_id( "model-1" )
249- . description( "Test Model 1" )
250- . build( )
251- . unwrap( ) ,
252- ] ,
253- Model :: builder ( )
254- . model_id ( "model-1" )
255- . description ( "Test Model 1" )
256- . build ( )
257- . unwrap ( ) ,
258- ) ) ;
256+ let m = Model :: builder ( )
257+ . model_id ( "model-1" )
258+ . description ( "Test Model 1" )
259+ . build ( )
260+ . unwrap ( ) ;
261+
262+ return Ok ( ModelListResult {
263+ models : vec ! [ m. clone( ) ] ,
264+ default_model : m,
265+ } ) ;
259266 }
260267
261268 let mut models = Vec :: new ( ) ;
@@ -276,10 +283,10 @@ impl ApiClient {
276283 }
277284 }
278285 let default_model = default_model. ok_or_else ( || ApiClientError :: DefaultModelNotFound ) ?;
279- Ok ( ( models, default_model) )
286+ Ok ( ModelListResult { models, default_model } )
280287 }
281288
282- pub async fn list_available_models_cached ( & self ) -> Result < ( Vec < Model > , Model ) , ApiClientError > {
289+ pub async fn list_available_models_cached ( & self ) -> Result < ModelListResult , ApiClientError > {
283290 {
284291 let cache = self . model_cache . read ( ) . await ;
285292 if let Some ( cached) = cache. as_ref ( ) {
@@ -303,9 +310,8 @@ impl ApiClient {
303310 tracing:: info!( "Model cache invalidated" ) ;
304311 }
305312
306- pub async fn get_available_models ( & self , _region : & str ) -> Result < ( Vec < Model > , Model ) , ApiClientError > {
307- let ( models, default_model) = self . list_available_models_cached ( ) . await ?;
308-
313+ pub async fn get_available_models ( & self , _region : & str ) -> Result < ModelListResult , ApiClientError > {
314+ let res = self . list_available_models_cached ( ) . await ?;
309315 // TODO: Once we have access to gpt-oss, add back.
310316 // if region == "us-east-1" {
311317 // let gpt_oss = Model::builder()
@@ -318,7 +324,7 @@ impl ApiClient {
318324 // models.push(gpt_oss);
319325 // }
320326
321- Ok ( ( models , default_model ) )
327+ Ok ( res . into ( ) )
322328 }
323329
324330 pub async fn create_subscription_token ( & self ) -> Result < CreateSubscriptionTokenOutput , ApiClientError > {
0 commit comments