Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions crates/chat-cli/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ homepage.workspace = true
publish.workspace = true
version.workspace = true
license.workspace = true
default-run = "chat_cli"

[lints]
workspace = true
Expand Down
77 changes: 32 additions & 45 deletions crates/chat-cli/src/api_client/clients/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,20 @@ use tracing::error;

use super::shared::bearer_sdk_config;
use crate::api_client::interceptor::opt_out::OptOutInterceptor;
use crate::api_client::profile::Profile;
use crate::api_client::{
ApiClientError,
Endpoint,
};
use crate::auth::AuthError;
use crate::auth::builder_id::BearerResolver;
use crate::aws_common::{
UserAgentOverrideInterceptor,
app_name,
};
use crate::database::{
AuthProfile,
Database,
};

mod inner {
use amzn_codewhisperer_client::Client as CodewhispererClient;
Expand All @@ -32,76 +36,60 @@ mod inner {
#[derive(Clone, Debug)]
pub struct Client {
inner: inner::Inner,
profile_arn: Option<String>,
profile: Option<AuthProfile>,
}

impl Client {
pub async fn new() -> Result<Client, ApiClientError> {
pub async fn new(database: &mut Database, endpoint: Option<Endpoint>) -> Result<Client, AuthError> {
if cfg!(test) {
return Ok(Self {
inner: inner::Inner::Mock,
profile_arn: None,
profile: None,
});
}

let endpoint = Endpoint::load_codewhisperer();
Ok(Self::new_codewhisperer_client(&endpoint).await)
}

pub async fn new_codewhisperer_client(endpoint: &Endpoint) -> Self {
let conf_builder: amzn_codewhisperer_client::config::Builder = (&bearer_sdk_config(endpoint).await).into();
let endpoint = endpoint.unwrap_or(Endpoint::load_codewhisperer(database));
let conf_builder: amzn_codewhisperer_client::config::Builder =
(&bearer_sdk_config(database, &endpoint).await).into();
let conf = conf_builder
.http_client(crate::aws_common::http_client::client())
.interceptor(OptOutInterceptor::new())
.interceptor(OptOutInterceptor::new(database))
.interceptor(UserAgentOverrideInterceptor::new())
.bearer_token_resolver(BearerResolver)
.bearer_token_resolver(BearerResolver::new(database).await?)
.app_name(app_name())
.endpoint_url(endpoint.url())
.build();

let inner = inner::Inner::Codewhisperer(CodewhispererClient::from_conf(conf));

let profile_arn = match crate::settings::state::get_value("api.codewhisperer.profile") {
Ok(Some(profile)) => match profile.get("arn") {
Some(arn) => match arn.as_str() {
Some(arn) => Some(arn.to_string()),
None => {
error!("Stored arn is not a string. Instead it was: {arn}");
None
},
},
None => {
error!("Stored profile does not contain an arn. Instead it was: {profile}");
None
},
},
Ok(None) => None,
let profile = match database.get_auth_profile() {
Ok(profile) => profile,
Err(err) => {
error!("Failed to retrieve profile: {}", err);
error!("Failed to get auth profile: {err}");
None
},
};

Self { inner, profile_arn }
Ok(Self { inner, profile })
}

// .telemetry_event(TelemetryEvent::UserTriggerDecisionEvent(user_trigger_decision_event))
// .user_context(user_context)
// .opt_out_preference(opt_out_preference)
pub async fn send_telemetry_event(
&self,
telemetry_event: TelemetryEvent,
user_context: UserContext,
opt_out: OptOutPreference,
telemetry_enabled: bool,
) -> Result<(), ApiClientError> {
match &self.inner {
inner::Inner::Codewhisperer(client) => {
let _ = client
.send_telemetry_event()
.telemetry_event(telemetry_event)
.user_context(user_context)
.opt_out_preference(opt_out)
.set_profile_arn(self.profile_arn.clone())
.opt_out_preference(match telemetry_enabled {
true => OptOutPreference::OptIn,
false => OptOutPreference::OptOut,
})
.set_profile_arn(self.profile.as_ref().map(|p| p.arn.clone()))
.send()
.await;
Ok(())
Expand All @@ -110,23 +98,23 @@ impl Client {
}
}

pub async fn list_available_profiles(&self) -> Result<Vec<Profile>, ApiClientError> {
pub async fn list_available_profiles(&self) -> Result<Vec<AuthProfile>, ApiClientError> {
match &self.inner {
inner::Inner::Codewhisperer(client) => {
let mut profiles = vec![];
let mut client = client.list_available_profiles().into_paginator().send();
while let Some(profiles_output) = client.next().await {
profiles.extend(profiles_output?.profiles().iter().cloned().map(Profile::from));
profiles.extend(profiles_output?.profiles().iter().cloned().map(AuthProfile::from));
}

Ok(profiles)
},
inner::Inner::Mock => Ok(vec![
Profile {
AuthProfile {
arn: "my:arn:1".to_owned(),
profile_name: "MyProfile".to_owned(),
},
Profile {
AuthProfile {
arn: "my:arn:2".to_owned(),
profile_name: "MyOtherProfile".to_owned(),
},
Expand All @@ -147,15 +135,14 @@ mod tests {

#[tokio::test]
async fn create_clients() {
let endpoint = Endpoint::load_codewhisperer();

let _ = Client::new().await;
let _ = Client::new_codewhisperer_client(&endpoint).await;
let mut database = crate::database::Database::new().await.unwrap();
let _ = Client::new(&mut database, None).await;
}

#[tokio::test]
async fn test_mock() {
let client = Client::new().await.unwrap();
let mut database = crate::database::Database::new().await.unwrap();
let client = Client::new(&mut database, None).await.unwrap();
client
.send_telemetry_event(
TelemetryEvent::ChatAddMessageEvent(
Expand All @@ -171,7 +158,7 @@ mod tests {
.product("<product>")
.build()
.unwrap(),
OptOutPreference::OptIn,
false,
)
.await
.unwrap();
Expand Down
26 changes: 16 additions & 10 deletions crates/chat-cli/src/api_client/clients/shared.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,16 @@ use crate::api_client::{
Endpoint,
};
use crate::aws_common::behavior_version;
use crate::database::Database;
use crate::database::settings::Setting;

// TODO(bskiser): confirm timeout is updated to an appropriate value?
const DEFAULT_TIMEOUT_DURATION: Duration = Duration::from_secs(60 * 5);

pub(crate) fn timeout_config() -> TimeoutConfig {
let timeout = crate::settings::settings::get_int("api.timeout")
.ok()
.flatten()
pub fn timeout_config(database: &Database) -> TimeoutConfig {
let timeout = database
.settings
.get_int(Setting::ApiTimeout)
.and_then(|i| i.try_into().ok())
.map_or(DEFAULT_TIMEOUT_DURATION, Duration::from_millis);

Expand All @@ -39,27 +41,31 @@ pub(crate) fn stalled_stream_protection_config() -> StalledStreamProtectionConfi
.build()
}

async fn base_sdk_config(region: Region, credentials_provider: impl ProvideCredentials + 'static) -> SdkConfig {
async fn base_sdk_config(
database: &Database,
region: Region,
credentials_provider: impl ProvideCredentials + 'static,
) -> SdkConfig {
aws_config::defaults(behavior_version())
.region(region)
.credentials_provider(credentials_provider)
.timeout_config(timeout_config())
.timeout_config(timeout_config(database))
.retry_config(RetryConfig::adaptive())
.load()
.await
}

pub(crate) async fn bearer_sdk_config(endpoint: &Endpoint) -> SdkConfig {
pub async fn bearer_sdk_config(database: &Database, endpoint: &Endpoint) -> SdkConfig {
let credentials = Credentials::new("xxx", "xxx", None, None, "xxx");
base_sdk_config(endpoint.region().clone(), credentials).await
base_sdk_config(database, endpoint.region().clone(), credentials).await
}

pub(crate) async fn sigv4_sdk_config(endpoint: &Endpoint) -> Result<SdkConfig, ApiClientError> {
pub async fn sigv4_sdk_config(database: &Database, endpoint: &Endpoint) -> Result<SdkConfig, ApiClientError> {
let credentials_chain = CredentialsChain::new().await;

if let Err(err) = credentials_chain.provide_credentials().await {
return Err(ApiClientError::Credentials(err));
};

Ok(base_sdk_config(endpoint.region().clone(), credentials_chain).await)
Ok(base_sdk_config(database, endpoint.region().clone(), credentials_chain).await)
}
81 changes: 39 additions & 42 deletions crates/chat-cli/src/api_client/clients/streaming_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ use crate::aws_common::{
UserAgentOverrideInterceptor,
app_name,
};
use crate::database::{
AuthProfile,
Database,
};

mod inner {
use std::sync::{
Expand All @@ -53,72 +57,63 @@ mod inner {
#[derive(Clone, Debug)]
pub struct StreamingClient {
inner: inner::Inner,
profile_arn: Option<String>,
profile: Option<AuthProfile>,
}

impl StreamingClient {
pub async fn new() -> Result<Self, ApiClientError> {
let client = if crate::util::system_info::in_cloudshell()
|| std::env::var("Q_USE_SENDMESSAGE").is_ok_and(|v| !v.is_empty())
{
Self::new_qdeveloper_client(&Endpoint::load_q()).await?
} else {
Self::new_codewhisperer_client(&Endpoint::load_codewhisperer()).await
};
Ok(client)
pub async fn new(database: &mut Database) -> Result<Self, ApiClientError> {
Ok(
if crate::util::system_info::in_cloudshell()
|| std::env::var("Q_USE_SENDMESSAGE").is_ok_and(|v| !v.is_empty())
{
Self::new_qdeveloper_client(database, &Endpoint::load_q(database)).await?
} else {
Self::new_codewhisperer_client(database, &Endpoint::load_codewhisperer(database)).await?
},
)
}

pub fn mock(events: Vec<Vec<ChatResponseStream>>) -> Self {
Self {
inner: inner::Inner::Mock(Arc::new(Mutex::new(events.into_iter()))),
profile_arn: None,
profile: None,
}
}

pub async fn new_codewhisperer_client(endpoint: &Endpoint) -> Self {
pub async fn new_codewhisperer_client(
database: &mut Database,
endpoint: &Endpoint,
) -> Result<Self, ApiClientError> {
let conf_builder: amzn_codewhisperer_streaming_client::config::Builder =
(&bearer_sdk_config(endpoint).await).into();
(&bearer_sdk_config(database, endpoint).await).into();
let conf = conf_builder
.http_client(crate::aws_common::http_client::client())
.interceptor(OptOutInterceptor::new())
.interceptor(OptOutInterceptor::new(database))
.interceptor(UserAgentOverrideInterceptor::new())
.bearer_token_resolver(BearerResolver)
.bearer_token_resolver(BearerResolver::new(database).await?)
.app_name(app_name())
.endpoint_url(endpoint.url())
.stalled_stream_protection(stalled_stream_protection_config())
.build();
let inner = inner::Inner::Codewhisperer(CodewhispererStreamingClient::from_conf(conf));

let profile_arn = match crate::settings::state::get_value("api.codewhisperer.profile") {
Ok(Some(profile)) => match profile.get("arn") {
Some(arn) => match arn.as_str() {
Some(arn) => Some(arn.to_string()),
None => {
error!("Stored arn is not a string. Instead it was: {arn}");
None
},
},
None => {
error!("Stored profile does not contain an arn. Instead it was: {profile}");
None
},
},
Ok(None) => None,
let profile = match database.get_auth_profile() {
Ok(profile) => profile,
Err(err) => {
error!("Failed to retrieve profile: {}", err);
error!("Failed to get auth profile: {err}");
None
},
};

Self { inner, profile_arn }
Ok(Self { inner, profile })
}

pub async fn new_qdeveloper_client(endpoint: &Endpoint) -> Result<Self, ApiClientError> {
pub async fn new_qdeveloper_client(database: &Database, endpoint: &Endpoint) -> Result<Self, ApiClientError> {
let conf_builder: amzn_qdeveloper_streaming_client::config::Builder =
(&sigv4_sdk_config(endpoint).await?).into();
(&sigv4_sdk_config(database, endpoint).await?).into();
let conf = conf_builder
.http_client(crate::aws_common::http_client::client())
.interceptor(OptOutInterceptor::new())
.interceptor(OptOutInterceptor::new(database))
.interceptor(UserAgentOverrideInterceptor::new())
.app_name(app_name())
.endpoint_url(endpoint.url())
Expand All @@ -127,7 +122,7 @@ impl StreamingClient {
let client = QDeveloperStreamingClient::from_conf(conf);
Ok(Self {
inner: inner::Inner::QDeveloper(client),
profile_arn: None,
profile: None,
})
}

Expand Down Expand Up @@ -162,7 +157,7 @@ impl StreamingClient {
let response = client
.generate_assistant_response()
.conversation_state(conversation_state)
.set_profile_arn(self.profile_arn.clone())
.set_profile_arn(self.profile.as_ref().map(|p| p.arn.clone()))
.send()
.await;

Expand Down Expand Up @@ -267,11 +262,12 @@ mod tests {

#[tokio::test]
async fn create_clients() {
let endpoint = Endpoint::load_codewhisperer();
let mut database = Database::new().await.unwrap();
let endpoint = Endpoint::load_codewhisperer(&database);

let _ = StreamingClient::new().await;
let _ = StreamingClient::new_codewhisperer_client(&endpoint).await;
let _ = StreamingClient::new_qdeveloper_client(&endpoint).await;
let _ = StreamingClient::new(&mut database).await;
let _ = StreamingClient::new_codewhisperer_client(&mut database, &endpoint).await;
let _ = StreamingClient::new_qdeveloper_client(&database, &endpoint).await;
}

#[tokio::test]
Expand Down Expand Up @@ -311,7 +307,8 @@ mod tests {
#[ignore]
#[tokio::test]
async fn assistant_response() {
let client = StreamingClient::new().await.unwrap();
let mut database = Database::new().await.unwrap();
let client = StreamingClient::new(&mut database).await.unwrap();
let mut response = client
.send_message(ConversationState {
conversation_id: None,
Expand Down
Loading
Loading