
Commit 754c4d5

feat: integrate backend API and remove hardcoded model options (#2419)
1 parent 0c82fc6 commit 754c4d5

9 files changed: +324, -146 lines changed


crates/chat-cli/src/api_client/error.rs

Lines changed: 15 additions & 0 deletions
@@ -1,6 +1,7 @@
 use amzn_codewhisperer_client::operation::create_subscription_token::CreateSubscriptionTokenError;
 use amzn_codewhisperer_client::operation::generate_completions::GenerateCompletionsError;
 use amzn_codewhisperer_client::operation::list_available_customizations::ListAvailableCustomizationsError;
+use amzn_codewhisperer_client::operation::list_available_models::ListAvailableModelsError;
 use amzn_codewhisperer_client::operation::list_available_profiles::ListAvailableProfilesError;
 use amzn_codewhisperer_client::operation::send_telemetry_event::SendTelemetryEventError;
 pub use amzn_codewhisperer_streaming_client::operation::generate_assistant_response::GenerateAssistantResponseError;
@@ -93,6 +94,12 @@ pub enum ApiClientError {
     // Credential errors
     #[error("failed to load credentials: {}", .0)]
     Credentials(CredentialsError),
+
+    #[error(transparent)]
+    ListAvailableModelsError(#[from] SdkError<ListAvailableModelsError, HttpResponse>),
+
+    #[error("No default model found in the ListAvailableModels API response")]
+    DefaultModelNotFound,
 }

 impl ApiClientError {
@@ -116,6 +123,8 @@ impl ApiClientError {
             Self::ModelOverloadedError { status_code, .. } => *status_code,
             Self::MonthlyLimitReached { status_code } => *status_code,
             Self::Credentials(_e) => None,
+            Self::ListAvailableModelsError(e) => sdk_status_code(e),
+            Self::DefaultModelNotFound => None,
         }
     }
 }
@@ -141,6 +150,8 @@ impl ReasonCode for ApiClientError {
             Self::ModelOverloadedError { .. } => "ModelOverloadedError".to_string(),
             Self::MonthlyLimitReached { .. } => "MonthlyLimitReached".to_string(),
             Self::Credentials(_) => "CredentialsError".to_string(),
+            Self::ListAvailableModelsError(e) => sdk_error_code(e),
+            Self::DefaultModelNotFound => "DefaultModelNotFound".to_string(),
         }
     }
 }
@@ -188,6 +199,10 @@ mod tests {
                 ListAvailableCustomizationsError::unhandled("<unhandled>"),
                 response(),
             )),
+            ApiClientError::ListAvailableModelsError(SdkError::service_error(
+                ListAvailableModelsError::unhandled("<unhandled>"),
+                response(),
+            )),
            ApiClientError::ListAvailableServices(SdkError::service_error(
                ListCustomizationsError::unhandled("<unhandled>"),
                response(),
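Note (illustration, not part of the commit): the two new variants rely on standard `thiserror` attributes. `#[error(transparent)]` with `#[from]` forwards the wrapped SDK error's message and source, while `DefaultModelNotFound` carries its own message, mirroring the status-code and reason-code mappings above. A minimal, self-contained sketch of the same pattern, using hypothetical stand-in types rather than the crate's real `SdkError`/`HttpResponse`:

```rust
use thiserror::Error;

// Hypothetical stand-in for the wrapped SDK-level error.
#[derive(Debug, Error)]
#[error("service call failed: {message}")]
struct ListModelsSdkError {
    message: String,
}

#[derive(Debug, Error)]
enum ClientError {
    // `transparent` delegates Display/source to the inner error;
    // `#[from]` lets `?` convert it automatically.
    #[error(transparent)]
    ListAvailableModels(#[from] ListModelsSdkError),

    // Plain variant with its own message, like `DefaultModelNotFound`.
    #[error("no default model found in the ListAvailableModels response")]
    DefaultModelNotFound,
}

fn list_models() -> Result<Vec<String>, ClientError> {
    // Simulate an SDK failure; `?` converts it via the `#[from]` impl.
    let sdk_result: Result<Vec<String>, ListModelsSdkError> =
        Err(ListModelsSdkError { message: "throttled".into() });
    let models = sdk_result?;
    Ok(models)
}

fn main() {
    match list_models() {
        Ok(models) => println!("{models:?}"),
        Err(ClientError::DefaultModelNotFound) => eprintln!("no default model"),
        Err(e) => eprintln!("API error: {e}"),
    }
}
```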

crates/chat-cli/src/api_client/mod.rs

Lines changed: 96 additions & 0 deletions
@@ -12,7 +12,9 @@ use std::time::Duration;

 use amzn_codewhisperer_client::Client as CodewhispererClient;
 use amzn_codewhisperer_client::operation::create_subscription_token::CreateSubscriptionTokenOutput;
+use amzn_codewhisperer_client::types::Origin::Cli;
 use amzn_codewhisperer_client::types::{
+    Model,
     OptOutPreference,
     SubscriptionStatus,
     TelemetryEvent,
@@ -32,6 +34,7 @@ pub use error::ApiClientError;
 use parking_lot::Mutex;
 pub use profile::list_available_profiles;
 use serde_json::Map;
+use tokio::sync::RwLock;
 use tracing::{
     debug,
     error,
@@ -66,13 +69,28 @@ pub const X_AMZN_CODEWHISPERER_OPT_OUT_HEADER: &str = "x-amzn-codewhisperer-opto
 // TODO(bskiser): confirm timeout is updated to an appropriate value?
 const DEFAULT_TIMEOUT_DURATION: Duration = Duration::from_secs(60 * 5);

+#[derive(Clone, Debug)]
+pub struct ModelListResult {
+    pub models: Vec<Model>,
+    pub default_model: Model,
+}
+
+impl From<ModelListResult> for (Vec<Model>, Model) {
+    fn from(v: ModelListResult) -> Self {
+        (v.models, v.default_model)
+    }
+}
+
+type ModelCache = Arc<RwLock<Option<ModelListResult>>>;
+
 #[derive(Clone, Debug)]
 pub struct ApiClient {
     client: CodewhispererClient,
     streaming_client: Option<CodewhispererStreamingClient>,
     sigv4_streaming_client: Option<QDeveloperStreamingClient>,
     mock_client: Option<Arc<Mutex<std::vec::IntoIter<Vec<ChatResponseStream>>>>>,
     profile: Option<AuthProfile>,
+    model_cache: ModelCache,
 }

 impl ApiClient {
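Aside (illustration, not part of the commit): the `From<ModelListResult> for (Vec<Model>, Model)` impl lets call sites that still expect a `(models, default_model)` tuple keep working via `.into()`. A small sketch with a hypothetical stand-in `Model` type:

```rust
// Stand-in types for illustration; the real `Model` comes from amzn_codewhisperer_client.
#[derive(Clone, Debug)]
struct Model {
    model_id: String,
}

#[derive(Clone, Debug)]
struct ModelListResult {
    models: Vec<Model>,
    default_model: Model,
}

impl From<ModelListResult> for (Vec<Model>, Model) {
    fn from(v: ModelListResult) -> Self {
        (v.models, v.default_model)
    }
}

fn main() {
    let default = Model { model_id: "model-1".into() };
    let result = ModelListResult {
        models: vec![default.clone()],
        default_model: default,
    };

    // Existing tuple-based call sites can destructure the struct with `.into()`.
    let (models, default_model): (Vec<Model>, Model) = result.into();
    println!("{} models, default = {}", models.len(), default_model.model_id);
}
```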
@@ -112,6 +130,7 @@ impl ApiClient {
             sigv4_streaming_client: None,
             mock_client: None,
             profile: None,
+            model_cache: Arc::new(RwLock::new(None)),
         };

         if let Ok(json) = env.get("Q_MOCK_CHAT_RESPONSE") {
@@ -181,6 +200,7 @@ impl ApiClient {
             sigv4_streaming_client,
             mock_client: None,
             profile,
+            model_cache: Arc::new(RwLock::new(None)),
         })
     }

@@ -234,6 +254,82 @@ impl ApiClient {
         Ok(profiles)
     }

+    pub async fn list_available_models(&self) -> Result<ModelListResult, ApiClientError> {
+        if cfg!(test) {
+            let m = Model::builder()
+                .model_id("model-1")
+                .description("Test Model 1")
+                .build()
+                .unwrap();
+
+            return Ok(ModelListResult {
+                models: vec![m.clone()],
+                default_model: m,
+            });
+        }
+
+        let mut models = Vec::new();
+        let mut default_model = None;
+        let request = self
+            .client
+            .list_available_models()
+            .set_origin(Some(Cli))
+            .set_profile_arn(self.profile.as_ref().map(|p| p.arn.clone()));
+        let mut paginator = request.into_paginator().send();
+
+        while let Some(result) = paginator.next().await {
+            let models_output = result?;
+            models.extend(models_output.models().iter().cloned());
+
+            if default_model.is_none() {
+                default_model = Some(models_output.default_model().clone());
+            }
+        }
+        let default_model = default_model.ok_or_else(|| ApiClientError::DefaultModelNotFound)?;
+        Ok(ModelListResult { models, default_model })
+    }
+
+    pub async fn list_available_models_cached(&self) -> Result<ModelListResult, ApiClientError> {
+        {
+            let cache = self.model_cache.read().await;
+            if let Some(cached) = cache.as_ref() {
+                tracing::debug!("Returning cached model list");
+                return Ok(cached.clone());
+            }
+        }
+
+        tracing::debug!("Cache miss, fetching models from list_available_models API");
+        let result = self.list_available_models().await?;
+        {
+            let mut cache = self.model_cache.write().await;
+            *cache = Some(result.clone());
+        }
+        Ok(result)
+    }
+
+    pub async fn invalidate_model_cache(&self) {
+        let mut cache = self.model_cache.write().await;
+        *cache = None;
+        tracing::info!("Model cache invalidated");
+    }
+
+    pub async fn get_available_models(&self, _region: &str) -> Result<ModelListResult, ApiClientError> {
+        let res = self.list_available_models_cached().await?;
+        // TODO: Once we have access to gpt-oss, add back.
+        // if region == "us-east-1" {
+        //     let gpt_oss = Model::builder()
+        //         .model_id("OPENAI_GPT_OSS_120B_1_0")
+        //         .model_name("openai-gpt-oss-120b-preview")
+        //         .token_limits(TokenLimits::builder().max_input_tokens(128_000).build())
+        //         .build()
+        //         .map_err(ApiClientError::from)?;
+
+        //     models.push(gpt_oss);
+        // }
+
+        Ok(res)
+    }
+
     pub async fn create_subscription_token(&self) -> Result<CreateSubscriptionTokenOutput, ApiClientError> {
         if cfg!(test) {
             return Ok(CreateSubscriptionTokenOutput::builder()
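Aside (illustration, not part of the commit): `list_available_models_cached` is a read-through cache over `Arc<RwLock<Option<ModelListResult>>>` — check under a read lock, fetch on a miss without holding the lock, then store under a short write lock; `invalidate_model_cache` simply resets it to `None`. A minimal, runnable sketch of that pattern, with a hypothetical `fetch_models` standing in for the real paginated ListAvailableModels call:

```rust
use std::sync::Arc;

use tokio::sync::RwLock;

// Stand-in for ModelListResult; illustration only.
#[derive(Clone, Debug)]
struct ModelList {
    models: Vec<String>,
    default_model: String,
}

#[derive(Clone)]
struct Client {
    cache: Arc<RwLock<Option<ModelList>>>,
}

impl Client {
    // Stand-in for the real paginated API call.
    async fn fetch_models(&self) -> ModelList {
        ModelList {
            models: vec!["model-1".to_string()],
            default_model: "model-1".to_string(),
        }
    }

    async fn models_cached(&self) -> ModelList {
        // Fast path: read lock only, released before any fetch.
        {
            let cache = self.cache.read().await;
            if let Some(cached) = cache.as_ref() {
                return cached.clone();
            }
        }

        // Miss: fetch without holding the lock, then store under a write lock.
        let fresh = self.fetch_models().await;
        *self.cache.write().await = Some(fresh.clone());
        fresh
    }

    async fn invalidate(&self) {
        *self.cache.write().await = None;
    }
}

#[tokio::main]
async fn main() {
    let client = Client { cache: Arc::new(RwLock::new(None)) };
    let first = client.models_cached().await;  // cache miss, fetches
    let second = client.models_cached().await; // served from cache
    println!("{} / {}", first.default_model, second.models.len());
    client.invalidate().await;                 // e.g. after a profile switch
}
```

Because the fetch happens outside the lock, two concurrent misses may both call the API; the later write simply overwrites the earlier one, which is harmless for an idempotent list call.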

crates/chat-cli/src/auth/builder_id.rs

Lines changed: 1 addition & 0 deletions
@@ -480,6 +480,7 @@ impl BuilderIdToken {

     /// Check if the token is for the internal amzn start URL (`https://amzn.awsapps.com/start`),
     /// this implies the user will use midway for private specs
+    #[allow(dead_code)]
     pub fn is_amzn_user(&self) -> bool {
         matches!(&self.start_url, Some(url) if url == AMZN_START_URL)
     }

crates/chat-cli/src/cli/chat/cli/context.rs

Lines changed: 1 addition & 1 deletion
@@ -222,7 +222,7 @@ impl ContextSubcommand {
             execute!(session.stderr, style::Print(format!("{}\n\n", "▔".repeat(3))),)?;
         }

-        let context_files_max_size = calc_max_context_files_size(session.conversation.model.as_deref());
+        let context_files_max_size = calc_max_context_files_size(session.conversation.model_info.as_ref());
         let mut files_as_vec = profile_context_files
             .iter()
             .map(|(path, content, _)| (path.clone(), content.clone()))
