Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 0 additions & 15 deletions crates/chat-cli/src/api_client/error.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use amzn_codewhisperer_client::operation::create_subscription_token::CreateSubscriptionTokenError;
use amzn_codewhisperer_client::operation::generate_completions::GenerateCompletionsError;
use amzn_codewhisperer_client::operation::list_available_customizations::ListAvailableCustomizationsError;
use amzn_codewhisperer_client::operation::list_available_models::ListAvailableModelsError;
use amzn_codewhisperer_client::operation::list_available_profiles::ListAvailableProfilesError;
use amzn_codewhisperer_client::operation::send_telemetry_event::SendTelemetryEventError;
pub use amzn_codewhisperer_streaming_client::operation::generate_assistant_response::GenerateAssistantResponseError;
Expand Down Expand Up @@ -94,12 +93,6 @@ pub enum ApiClientError {
// Credential errors
#[error("failed to load credentials: {}", .0)]
Credentials(CredentialsError),

#[error(transparent)]
ListAvailableModelsError(#[from] SdkError<ListAvailableModelsError, HttpResponse>),

#[error("No default model found in the ListAvailableModels API response")]
DefaultModelNotFound,
}

impl ApiClientError {
Expand All @@ -123,8 +116,6 @@ impl ApiClientError {
Self::ModelOverloadedError { status_code, .. } => *status_code,
Self::MonthlyLimitReached { status_code } => *status_code,
Self::Credentials(_e) => None,
Self::ListAvailableModelsError(e) => sdk_status_code(e),
Self::DefaultModelNotFound => None,
}
}
}
Expand All @@ -150,8 +141,6 @@ impl ReasonCode for ApiClientError {
Self::ModelOverloadedError { .. } => "ModelOverloadedError".to_string(),
Self::MonthlyLimitReached { .. } => "MonthlyLimitReached".to_string(),
Self::Credentials(_) => "CredentialsError".to_string(),
Self::ListAvailableModelsError(e) => sdk_error_code(e),
Self::DefaultModelNotFound => "DefaultModelNotFound".to_string(),
}
}
}
Expand Down Expand Up @@ -199,10 +188,6 @@ mod tests {
ListAvailableCustomizationsError::unhandled("<unhandled>"),
response(),
)),
ApiClientError::ListAvailableModelsError(SdkError::service_error(
ListAvailableModelsError::unhandled("<unhandled>"),
response(),
)),
ApiClientError::ListAvailableServices(SdkError::service_error(
ListCustomizationsError::unhandled("<unhandled>"),
response(),
Expand Down
96 changes: 0 additions & 96 deletions crates/chat-cli/src/api_client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,7 @@ use std::time::Duration;

use amzn_codewhisperer_client::Client as CodewhispererClient;
use amzn_codewhisperer_client::operation::create_subscription_token::CreateSubscriptionTokenOutput;
use amzn_codewhisperer_client::types::Origin::Cli;
use amzn_codewhisperer_client::types::{
Model,
OptOutPreference,
SubscriptionStatus,
TelemetryEvent,
Expand All @@ -34,7 +32,6 @@ pub use error::ApiClientError;
use parking_lot::Mutex;
pub use profile::list_available_profiles;
use serde_json::Map;
use tokio::sync::RwLock;
use tracing::{
debug,
error,
Expand Down Expand Up @@ -69,28 +66,13 @@ pub const X_AMZN_CODEWHISPERER_OPT_OUT_HEADER: &str = "x-amzn-codewhisperer-opto
// TODO(bskiser): confirm timeout is updated to an appropriate value?
const DEFAULT_TIMEOUT_DURATION: Duration = Duration::from_secs(60 * 5);

/// The models available to the client together with the default model.
///
/// Produced by `ApiClient::list_available_models` and stored as the value of the model cache.
#[derive(Clone, Debug)]
pub struct ModelListResult {
/// All models collected from the `ListAvailableModels` response pages.
pub models: Vec<Model>,
/// The default model, taken from the first response page that was received.
pub default_model: Model,
}

impl From<ModelListResult> for (Vec<Model>, Model) {
fn from(v: ModelListResult) -> Self {
(v.models, v.default_model)
}
}

type ModelCache = Arc<RwLock<Option<ModelListResult>>>;

/// Handle bundling the service clients, the optional auth profile, and a model-list cache.
///
/// The type is `Clone`; `mock_client` and `model_cache` are behind `Arc`, so clones share them.
#[derive(Clone, Debug)]
pub struct ApiClient {
/// Client used for the request/response calls, e.g. `list_available_models`.
client: CodewhispererClient,
/// NOTE(review): presumably the streaming client for non-SigV4 auth; its use is not visible
/// here — confirm.
streaming_client: Option<CodewhispererStreamingClient>,
/// NOTE(review): presumably the streaming client used with SigV4 credentials — confirm.
sigv4_streaming_client: Option<QDeveloperStreamingClient>,
/// Shared iterator over batches of canned `ChatResponseStream` events; `None` when unset.
/// NOTE(review): looks like it is populated from `Q_MOCK_CHAT_RESPONSE` — confirm.
mock_client: Option<Arc<Mutex<std::vec::IntoIter<Vec<ChatResponseStream>>>>>,
/// Optional auth profile; its ARN is attached to the `list_available_models` request.
profile: Option<AuthProfile>,
/// Model list cache: filled by `list_available_models_cached`, cleared by
/// `invalidate_model_cache`.
model_cache: ModelCache,
}

impl ApiClient {
Expand Down Expand Up @@ -130,7 +112,6 @@ impl ApiClient {
sigv4_streaming_client: None,
mock_client: None,
profile: None,
model_cache: Arc::new(RwLock::new(None)),
};

if let Ok(json) = env.get("Q_MOCK_CHAT_RESPONSE") {
Expand Down Expand Up @@ -200,7 +181,6 @@ impl ApiClient {
sigv4_streaming_client,
mock_client: None,
profile,
model_cache: Arc::new(RwLock::new(None)),
})
}

Expand Down Expand Up @@ -254,82 +234,6 @@ impl ApiClient {
Ok(profiles)
}

/// Fetches every model available to this client, together with the default model.
///
/// All pages of the `ListAvailableModels` response are drained. The request is sent with the
/// `Cli` origin and, when a profile is set, that profile's ARN. The default model is taken from
/// the first page received.
///
/// Under `cfg!(test)` no request is made: a single fixed model (`model-1`) is returned as both
/// the only list entry and the default.
///
/// # Errors
///
/// Returns the error of the first page that fails, or
/// [`ApiClientError::DefaultModelNotFound`] when the paginator yields no pages at all.
pub async fn list_available_models(&self) -> Result<ModelListResult, ApiClientError> {
    if cfg!(test) {
        let m = Model::builder()
            .model_id("model-1")
            .description("Test Model 1")
            .build()
            .unwrap();

        return Ok(ModelListResult {
            models: vec![m.clone()],
            default_model: m,
        });
    }

    let mut models = Vec::new();
    let mut default_model = None;
    let request = self
        .client
        .list_available_models()
        .set_origin(Some(Cli))
        .set_profile_arn(self.profile.as_ref().map(|p| p.arn.clone()));
    let mut paginator = request.into_paginator().send();

    while let Some(result) = paginator.next().await {
        let models_output = result?;
        models.extend(models_output.models().iter().cloned());

        // Only the first page's default is kept; later pages never overwrite it.
        if default_model.is_none() {
            default_model = Some(models_output.default_model().clone());
        }
    }

    // The error is a unit variant, so it is built eagerly instead of through a closure.
    let default_model = default_model.ok_or(ApiClientError::DefaultModelNotFound)?;
    Ok(ModelListResult { models, default_model })
}

/// Returns the model list from the cache, fetching and storing it first on a cache miss.
///
/// # Errors
///
/// Propagates any error from `list_available_models`; the cache is left untouched in that case.
pub async fn list_available_models_cached(&self) -> Result<ModelListResult, ApiClientError> {
    // The read guard is a temporary of this statement, so it is released before any fetch.
    let cached = self.model_cache.read().await.clone();
    if let Some(hit) = cached {
        tracing::debug!("Returning cached model list");
        return Ok(hit);
    }

    tracing::debug!("Cache miss, fetching models from list_available_models API");
    let fetched = self.list_available_models().await?;

    // The write lock is taken only for the assignment, not across the fetch above.
    *self.model_cache.write().await = Some(fetched.clone());
    Ok(fetched)
}

/// Empties the model cache, so the next `list_available_models_cached` call fetches again.
pub async fn invalidate_model_cache(&self) {
    let mut slot = self.model_cache.write().await;
    slot.take();
    tracing::info!("Model cache invalidated");
}

/// Returns the available models via `list_available_models_cached`.
///
/// `_region` is currently unused; it is only referenced by the disabled gpt-oss code below.
///
/// # Errors
///
/// Propagates any error from `list_available_models_cached`.
pub async fn get_available_models(&self, _region: &str) -> Result<ModelListResult, ApiClientError> {
    // TODO: Once we have access to gpt-oss, add back.
    // if region == "us-east-1" {
    //     let gpt_oss = Model::builder()
    //         .model_id("OPENAI_GPT_OSS_120B_1_0")
    //         .model_name("openai-gpt-oss-120b-preview")
    //         .token_limits(TokenLimits::builder().max_input_tokens(128_000).build())
    //         .build()
    //         .map_err(ApiClientError::from)?;

    //     models.push(gpt_oss);
    // }
    self.list_available_models_cached().await
}

pub async fn create_subscription_token(&self) -> Result<CreateSubscriptionTokenOutput, ApiClientError> {
if cfg!(test) {
return Ok(CreateSubscriptionTokenOutput::builder()
Expand Down
1 change: 0 additions & 1 deletion crates/chat-cli/src/auth/builder_id.rs
Original file line number Diff line number Diff line change
Expand Up @@ -480,7 +480,6 @@ impl BuilderIdToken {

/// Returns `true` when this token was issued for the internal amzn start URL
/// (`https://amzn.awsapps.com/start`), which implies the user will use midway for private specs.
#[allow(dead_code)]
pub fn is_amzn_user(&self) -> bool {
    self.start_url
        .as_ref()
        .is_some_and(|url| url == AMZN_START_URL)
}
Expand Down
2 changes: 1 addition & 1 deletion crates/chat-cli/src/cli/chat/cli/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ impl ContextSubcommand {
execute!(session.stderr, style::Print(format!("{}\n\n", "▔".repeat(3))),)?;
}

let context_files_max_size = calc_max_context_files_size(session.conversation.model_info.as_ref());
let context_files_max_size = calc_max_context_files_size(session.conversation.model.as_deref());
let mut files_as_vec = profile_context_files
.iter()
.map(|(path, content, _)| (path.clone(), content.clone()))
Expand Down
Loading
Loading