Skip to content

Commit 72a95a4

Browse files
committed
replace tuple with ModelListResult struct; cache struct; keep tuple API via From impl
1 parent c4bfb19 commit 72a95a4

File tree

2 files changed

+31
-25
lines changed

2 files changed

+31
-25
lines changed

crates/chat-cli/src/api_client/mod.rs

Lines changed: 28 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,18 @@ pub const X_AMZN_CODEWHISPERER_OPT_OUT_HEADER: &str = "x-amzn-codewhisperer-opto
6868
// TODO(bskiser): confirm timeout is updated to an appropriate value?
6969
const DEFAULT_TIMEOUT_DURATION: Duration = Duration::from_secs(60 * 5);
7070

71-
type ModelListResult = (Vec<Model>, Model);
71+
#[derive(Clone, Debug)]
72+
pub struct ModelListResult {
73+
pub models: Vec<Model>,
74+
pub default_model: Model,
75+
}
76+
77+
impl From<ModelListResult> for (Vec<Model>, Model) {
78+
fn from(v: ModelListResult) -> Self {
79+
(v.models, v.default_model)
80+
}
81+
}
82+
7283
type ModelCache = Arc<RwLock<Option<ModelListResult>>>;
7384

7485
#[derive(Clone, Debug)]
@@ -240,22 +251,18 @@ impl ApiClient {
240251
Ok(profiles)
241252
}
242253

243-
pub async fn list_available_models(&self) -> Result<(Vec<Model>, Model), ApiClientError> {
254+
pub async fn list_available_models(&self) -> Result<ModelListResult, ApiClientError> {
244255
if cfg!(test) {
245-
return Ok((
246-
vec![
247-
Model::builder()
248-
.model_id("model-1")
249-
.description("Test Model 1")
250-
.build()
251-
.unwrap(),
252-
],
253-
Model::builder()
254-
.model_id("model-1")
255-
.description("Test Model 1")
256-
.build()
257-
.unwrap(),
258-
));
256+
let m = Model::builder()
257+
.model_id("model-1")
258+
.description("Test Model 1")
259+
.build()
260+
.unwrap();
261+
262+
return Ok(ModelListResult {
263+
models: vec![m.clone()],
264+
default_model: m,
265+
});
259266
}
260267

261268
let mut models = Vec::new();
@@ -276,10 +283,10 @@ impl ApiClient {
276283
}
277284
}
278285
let default_model = default_model.ok_or_else(|| ApiClientError::DefaultModelNotFound)?;
279-
Ok((models, default_model))
286+
Ok(ModelListResult { models, default_model })
280287
}
281288

282-
pub async fn list_available_models_cached(&self) -> Result<(Vec<Model>, Model), ApiClientError> {
289+
pub async fn list_available_models_cached(&self) -> Result<ModelListResult, ApiClientError> {
283290
{
284291
let cache = self.model_cache.read().await;
285292
if let Some(cached) = cache.as_ref() {
@@ -303,9 +310,8 @@ impl ApiClient {
303310
tracing::info!("Model cache invalidated");
304311
}
305312

306-
pub async fn get_available_models(&self, _region: &str) -> Result<(Vec<Model>, Model), ApiClientError> {
307-
let (models, default_model) = self.list_available_models_cached().await?;
308-
313+
pub async fn get_available_models(&self, _region: &str) -> Result<ModelListResult, ApiClientError> {
314+
let res = self.list_available_models_cached().await?;
309315
// TODO: Once we have access to gpt-oss, add back.
310316
// if region == "us-east-1" {
311317
// let gpt_oss = Model::builder()
@@ -318,7 +324,7 @@ impl ApiClient {
318324
// models.push(gpt_oss);
319325
// }
320326

321-
Ok((models, default_model))
327+
Ok(res.into())
322328
}
323329

324330
pub async fn create_subscription_token(&self) -> Result<CreateSubscriptionTokenOutput, ApiClientError> {

crates/chat-cli/src/cli/chat/cli/model.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -225,9 +225,9 @@ pub async fn get_available_models(os: &Os) -> Result<(Vec<ModelInfo>, ModelInfo)
225225
let region = endpoint.region().as_ref();
226226

227227
match os.client.get_available_models(region).await {
228-
Ok((api_models, api_default)) => {
229-
let models: Vec<ModelInfo> = api_models.iter().map(ModelInfo::from_api_model).collect();
230-
let default_model = ModelInfo::from_api_model(&api_default);
228+
Ok(api_res) => {
229+
let models: Vec<ModelInfo> = api_res.models.iter().map(ModelInfo::from_api_model).collect();
230+
let default_model = ModelInfo::from_api_model(&api_res.default_model);
231231

232232
tracing::debug!("Successfully fetched {} models from API", models.len());
233233
Ok((models, default_model))

0 commit comments

Comments
 (0)