Skip to content

Commit 0a3c44b

Browse files
committed
finish rewrite
1 parent 0a86079 commit 0a3c44b

File tree

24 files changed

+276
-453
lines changed

24 files changed

+276
-453
lines changed

crates/chat-cli/src/api_client/clients/client.rs

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ use crate::api_client::{
1212
ApiClientError,
1313
Endpoint,
1414
};
15+
use crate::auth::AuthError;
1516
use crate::auth::builder_id::BearerResolver;
1617
use crate::aws_common::{
1718
UserAgentOverrideInterceptor,
@@ -40,22 +41,22 @@ pub struct Client {
4041
}
4142

4243
impl Client {
43-
pub async fn new(database: &mut Database, endpoint: Option<Endpoint>) -> Client {
44+
pub async fn new(database: &mut Database, endpoint: Option<Endpoint>) -> Result<Client, AuthError> {
4445
if cfg!(test) {
45-
return Self {
46+
return Ok(Self {
4647
inner: inner::Inner::Mock,
4748
profile: None,
48-
};
49+
});
4950
}
5051

5152
let endpoint = endpoint.unwrap_or(Endpoint::load_codewhisperer(database));
5253
let conf_builder: amzn_codewhisperer_client::config::Builder =
5354
(&bearer_sdk_config(database, &endpoint).await).into();
5455
let conf = conf_builder
55-
.http_client(crate::aws_common::http_client::client(database))
56+
.http_client(crate::aws_common::http_client::client())
5657
.interceptor(OptOutInterceptor::new(database))
5758
.interceptor(UserAgentOverrideInterceptor::new())
58-
.bearer_token_resolver(BearerResolver)
59+
.bearer_token_resolver(BearerResolver::new(database).await?)
5960
.app_name(app_name())
6061
.endpoint_url(endpoint.url())
6162
.build();
@@ -70,7 +71,7 @@ impl Client {
7071
},
7172
};
7273

73-
Self { inner, profile }
74+
Ok(Self { inner, profile })
7475
}
7576

7677
pub async fn send_telemetry_event(
@@ -142,7 +143,7 @@ mod tests {
142143
#[tokio::test]
143144
async fn test_mock() {
144145
let mut database = crate::database::Database::new().await.unwrap();
145-
let client = Client::new(&mut database, None).await;
146+
let client = Client::new(&mut database, None).await.unwrap();
146147
client
147148
.send_telemetry_event(
148149
TelemetryEvent::ChatAddMessageEvent(

crates/chat-cli/src/api_client/clients/streaming_client.rs

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -67,9 +67,9 @@ impl StreamingClient {
6767
if crate::util::system_info::in_cloudshell()
6868
|| std::env::var("Q_USE_SENDMESSAGE").is_ok_and(|v| !v.is_empty())
6969
{
70-
Self::new_qdeveloper_client(database, &Endpoint::load_q()).await?
70+
Self::new_qdeveloper_client(database, &Endpoint::load_q(database)).await?
7171
} else {
72-
Self::new_codewhisperer_client(database, &Endpoint::load_codewhisperer(&database)).await
72+
Self::new_codewhisperer_client(database, &Endpoint::load_codewhisperer(database)).await?
7373
},
7474
)
7575
}
@@ -81,14 +81,17 @@ impl StreamingClient {
8181
}
8282
}
8383

84-
pub async fn new_codewhisperer_client(database: &mut Database, endpoint: &Endpoint) -> Self {
84+
pub async fn new_codewhisperer_client(
85+
database: &mut Database,
86+
endpoint: &Endpoint,
87+
) -> Result<Self, ApiClientError> {
8588
let conf_builder: amzn_codewhisperer_streaming_client::config::Builder =
8689
(&bearer_sdk_config(database, endpoint).await).into();
8790
let conf = conf_builder
88-
.http_client(crate::aws_common::http_client::client(database))
91+
.http_client(crate::aws_common::http_client::client())
8992
.interceptor(OptOutInterceptor::new(database))
9093
.interceptor(UserAgentOverrideInterceptor::new())
91-
.bearer_token_resolver(BearerResolver)
94+
.bearer_token_resolver(BearerResolver::new(database).await?)
9295
.app_name(app_name())
9396
.endpoint_url(endpoint.url())
9497
.stalled_stream_protection(stalled_stream_protection_config())
@@ -103,14 +106,14 @@ impl StreamingClient {
103106
},
104107
};
105108

106-
Self { inner, profile }
109+
Ok(Self { inner, profile })
107110
}
108111

109112
pub async fn new_qdeveloper_client(database: &Database, endpoint: &Endpoint) -> Result<Self, ApiClientError> {
110113
let conf_builder: amzn_qdeveloper_streaming_client::config::Builder =
111114
(&sigv4_sdk_config(database, endpoint).await?).into();
112115
let conf = conf_builder
113-
.http_client(crate::aws_common::http_client::client(database))
116+
.http_client(crate::aws_common::http_client::client())
114117
.interceptor(OptOutInterceptor::new(database))
115118
.interceptor(UserAgentOverrideInterceptor::new())
116119
.app_name(app_name())
@@ -265,7 +268,7 @@ mod tests {
265268

266269
let _ = StreamingClient::new(&mut database).await;
267270
let _ = StreamingClient::new_codewhisperer_client(&mut database, &endpoint).await;
268-
let _ = StreamingClient::new_qdeveloper_client(&mut database, &endpoint).await;
271+
let _ = StreamingClient::new_qdeveloper_client(&database, &endpoint).await;
269272
}
270273

271274
#[tokio::test]

crates/chat-cli/src/api_client/endpoints.rs

Lines changed: 32 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ use crate::api_client::consts::{
1313
PROD_Q_ENDPOINT_URL,
1414
};
1515
use crate::database::Database;
16+
use crate::database::settings::Setting;
17+
use crate::database::state::StateDatabase;
1618

1719
#[derive(Debug, Clone, PartialEq, Eq)]
1820
pub struct Endpoint {
@@ -35,36 +37,29 @@ impl Endpoint {
3537
};
3638

3739
pub fn load_codewhisperer(database: &Database) -> Self {
38-
let (endpoint, region) =
39-
if let Ok(Some(Value::Object(o))) = database.settings.get_value("api.codewhisperer.service") {
40-
// The following branch is evaluated in case the user has set their own endpoint.
41-
(
42-
o.get("endpoint").and_then(|v| v.as_str()).map(|v| v.to_owned()),
43-
o.get("region").and_then(|v| v.as_str()).map(|v| v.to_owned()),
44-
)
45-
} else if let Ok(Some(Value::Object(o))) =
46-
crate::database::persistent_state::get_value("api.codewhisperer.profile")
40+
let (endpoint, region) = if let Some(Value::Object(o)) = database.settings.get(Setting::ApiCodeWhispererService)
41+
{
42+
// The following branch is evaluated in case the user has set their own endpoint.
43+
(
44+
o.get("endpoint").and_then(|v| v.as_str()).map(|v| v.to_owned()),
45+
o.get("region").and_then(|v| v.as_str()).map(|v| v.to_owned()),
46+
)
47+
} else if let Ok(Some(profile)) = database.get_auth_profile() {
48+
// The following branch is evaluated in the case of user profile being set.
49+
let region = profile.arn.split(':').nth(3).unwrap_or_default().to_owned();
50+
match Self::CODEWHISPERER_ENDPOINTS
51+
.iter()
52+
.find(|e| e.region().as_ref() == region)
4753
{
48-
// The following branch is evaluated in the case of user profile being set.
49-
match o.get("arn").and_then(|v| v.as_str()).map(|v| v.to_owned()) {
50-
Some(arn) => {
51-
let region = arn.split(':').nth(3).unwrap_or_default().to_owned();
52-
match Self::CODEWHISPERER_ENDPOINTS
53-
.iter()
54-
.find(|e| e.region().as_ref() == region)
55-
{
56-
Some(endpoint) => (Some(endpoint.url().to_owned()), Some(region)),
57-
None => {
58-
error!("Failed to find endpoint for region: {region}");
59-
(None, None)
60-
},
61-
}
62-
},
63-
None => (None, None),
64-
}
65-
} else {
66-
(None, None)
67-
};
54+
Some(endpoint) => (Some(endpoint.url().to_owned()), Some(region)),
55+
None => {
56+
error!("Failed to find endpoint for region: {region}");
57+
(None, None)
58+
},
59+
}
60+
} else {
61+
(None, None)
62+
};
6863

6964
match (endpoint, region) {
7065
(Some(endpoint), Some(region)) => Self {
@@ -75,9 +70,9 @@ impl Endpoint {
7570
}
7671
}
7772

78-
pub fn load_q() -> Self {
79-
match crate::database::settings::get_value("api.q.service") {
80-
Ok(Some(Value::Object(o))) => {
73+
pub fn load_q(database: &Database) -> Self {
74+
match database.settings.get(Setting::ApiQService) {
75+
Some(Value::Object(o)) => {
8176
let endpoint = o.get("endpoint").and_then(|v| v.as_str());
8277
let region = o.get("region").and_then(|v| v.as_str());
8378

@@ -108,10 +103,11 @@ mod tests {
108103

109104
use super::*;
110105

111-
#[test]
112-
fn test_endpoints() {
113-
let _ = Endpoint::load_codewhisperer();
114-
let _ = Endpoint::load_q();
106+
#[tokio::test]
107+
async fn test_endpoints() {
108+
let database = Database::new().await.unwrap();
109+
let _ = Endpoint::load_codewhisperer(&database);
110+
let _ = Endpoint::load_q(&database);
115111

116112
let prod = &Endpoint::DEFAULT_ENDPOINT;
117113
Url::parse(prod.url()).unwrap();

crates/chat-cli/src/api_client/error.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ pub use aws_smithy_runtime_api::client::result::SdkError;
1313
use aws_smithy_types::event_stream::RawMessage;
1414
use thiserror::Error;
1515

16+
use crate::auth::AuthError;
1617
use crate::aws_common::SdkErrorDisplay;
1718

1819
#[derive(Debug, Error)]
@@ -61,6 +62,9 @@ pub enum ApiClientError {
6162

6263
#[error(transparent)]
6364
ListAvailableProfilesError(#[from] SdkError<ListAvailableProfilesError, HttpResponse>),
65+
66+
#[error(transparent)]
67+
AuthError(#[from] AuthError),
6468
}
6569

6670
#[cfg(test)]
Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,18 @@
11
use crate::api_client::Client;
22
use crate::api_client::endpoints::Endpoint;
3+
use crate::auth::AuthError;
34
use crate::database::Database;
45
use crate::database::state::AuthProfile;
56

6-
pub async fn list_available_profiles(database: &mut Database) -> Vec<AuthProfile> {
7+
pub async fn list_available_profiles(database: &mut Database) -> Result<Vec<AuthProfile>, AuthError> {
78
let mut profiles = vec![];
89
for endpoint in Endpoint::CODEWHISPERER_ENDPOINTS {
9-
let client = Client::new(database, Some(endpoint.clone())).await;
10+
let client = Client::new(database, Some(endpoint.clone())).await?;
1011
match client.list_available_profiles().await {
1112
Ok(mut p) => profiles.append(&mut p),
1213
Err(e) => tracing::error!("Failed to list profiles from endpoint {:?}: {:?}", endpoint, e),
1314
}
1415
}
1516

16-
profiles
17+
Ok(profiles)
1718
}

0 commit comments

Comments
 (0)