Skip to content

Commit 42b1676

Browse files
ManojtuguruManoj Tuguruomursalojoshrutkowskichaynabors
authored
feat: Enabling Sigv4 Auth Support for Q (#207)
* Initial Commit for Q CLI changes * Add SigV4 authentication support to login command * Adding with the Auth Strategy * Remove references for AuthStrategy * Added back the Q Dev client and Q_USE_SENDMESSAGE env variable * fix: Add origin to SendMessage request and change env variable for Sigv4 * code-cleanup:removed sigv4 auth references * chore: cleanup AuthStrategy definitions * fix/chore: skip login when using sigv4 and cleanup unused code * Fix/mcp disable (#257) * Update crates/chat-cli/src/api_client/clients/streaming_client.rs Co-authored-by: Chay Nabors <[email protected]> * Update crates/chat-cli/src/api_client/clients/streaming_client.rs Co-authored-by: Chay Nabors <[email protected]> * Update crates/chat-cli/src/auth/builder_id.rs Co-authored-by: Chay Nabors <[email protected]> * Chore:Addressing comments and Clean up * Refactored the builder ID condition check to AMAZON_Q_SIGV4 * removing unused imports * fixes --------- Co-authored-by: Manoj Tuguru <[email protected]> Co-authored-by: Olena Mursalova <[email protected]> Co-authored-by: Josh Rutkowski <[email protected]> Co-authored-by: Chay Nabors <[email protected]> Co-authored-by: Chay Nabors <[email protected]>
1 parent d8ea18f commit 42b1676

File tree

10 files changed

+229
-12
lines changed

10 files changed

+229
-12
lines changed

crates/chat-cli/src/api_client/clients/shared.rs

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,11 @@ use aws_credential_types::provider::ProvideCredentials;
88
use aws_types::SdkConfig;
99
use aws_types::sdk_config::StalledStreamProtectionConfig;
1010

11-
use crate::api_client::Endpoint;
11+
use crate::api_client::credentials::CredentialsChain;
12+
use crate::api_client::{
13+
ApiClientError,
14+
Endpoint,
15+
};
1216
use crate::aws_common::behavior_version;
1317
use crate::database::Database;
1418
use crate::database::settings::Setting;
@@ -55,3 +59,13 @@ pub async fn bearer_sdk_config(database: &Database, endpoint: &Endpoint) -> SdkC
5559
let credentials = Credentials::new("xxx", "xxx", None, None, "xxx");
5660
base_sdk_config(database, endpoint.region().clone(), credentials).await
5761
}
62+
63+
pub async fn sigv4_sdk_config(database: &Database, endpoint: &Endpoint) -> Result<SdkConfig, ApiClientError> {
64+
let credentials_chain = CredentialsChain::new().await;
65+
66+
if let Err(err) = credentials_chain.provide_credentials().await {
67+
return Err(ApiClientError::Credentials(err));
68+
};
69+
70+
Ok(base_sdk_config(database, endpoint.region().clone(), credentials_chain).await)
71+
}

crates/chat-cli/src/api_client/clients/streaming_client.rs

Lines changed: 81 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ use std::sync::{
44
};
55

66
use amzn_codewhisperer_streaming_client::Client as CodewhispererStreamingClient;
7+
use amzn_qdeveloper_streaming_client::Client as QDeveloperStreamingClient;
8+
use amzn_qdeveloper_streaming_client::types::Origin;
79
use aws_types::request_id::RequestId;
810
use tracing::{
911
debug,
@@ -12,6 +14,7 @@ use tracing::{
1214

1315
use super::shared::{
1416
bearer_sdk_config,
17+
sigv4_sdk_config,
1518
stalled_stream_protection_config,
1619
};
1720
use crate::api_client::interceptor::opt_out::OptOutInterceptor;
@@ -40,12 +43,14 @@ mod inner {
4043
};
4144

4245
use amzn_codewhisperer_streaming_client::Client as CodewhispererStreamingClient;
46+
use amzn_qdeveloper_streaming_client::Client as QDeveloperStreamingClient;
4347

4448
use crate::api_client::model::ChatResponseStream;
4549

4650
#[derive(Clone, Debug)]
4751
pub enum Inner {
4852
Codewhisperer(CodewhispererStreamingClient),
53+
QDeveloper(QDeveloperStreamingClient),
4954
Mock(Arc<Mutex<std::vec::IntoIter<Vec<ChatResponseStream>>>>),
5055
}
5156
}
@@ -58,7 +63,13 @@ pub struct StreamingClient {
5863

5964
impl StreamingClient {
6065
pub async fn new(database: &mut Database) -> Result<Self, ApiClientError> {
61-
Self::new_codewhisperer_client(database, &Endpoint::load_codewhisperer(database)).await
66+
// If SIGV4_AUTH_ENABLED is true, use Q developer client
67+
if std::env::var("AMAZON_Q_SIGV4").is_ok_and(|v| !v.is_empty()) {
68+
Self::new_qdeveloper_client(database, &Endpoint::load_q(database)).await
69+
} else {
70+
// Default to CodeWhisperer client
71+
Self::new_codewhisperer_client(database, &Endpoint::load_codewhisperer(database)).await
72+
}
6273
}
6374

6475
pub fn mock(events: Vec<Vec<ChatResponseStream>>) -> Self {
@@ -96,6 +107,25 @@ impl StreamingClient {
96107
Ok(Self { inner, profile })
97108
}
98109

110+
// Add SigV4 client creation method
111+
pub async fn new_qdeveloper_client(database: &Database, endpoint: &Endpoint) -> Result<Self, ApiClientError> {
112+
let conf_builder: amzn_qdeveloper_streaming_client::config::Builder =
113+
(&sigv4_sdk_config(database, endpoint).await?).into();
114+
let conf = conf_builder
115+
.http_client(crate::aws_common::http_client::client())
116+
.interceptor(OptOutInterceptor::new(database))
117+
.interceptor(UserAgentOverrideInterceptor::new())
118+
.app_name(app_name())
119+
.endpoint_url(endpoint.url())
120+
.stalled_stream_protection(stalled_stream_protection_config())
121+
.build();
122+
let client = QDeveloperStreamingClient::from_conf(conf);
123+
Ok(Self {
124+
inner: inner::Inner::QDeveloper(client),
125+
profile: None,
126+
})
127+
}
128+
99129
pub async fn send_message(&self, conversation: ConversationState) -> Result<SendMessageOutput, ApiClientError> {
100130
debug!("Sending conversation: {:#?}", conversation);
101131
let ConversationState {
@@ -180,6 +210,51 @@ impl StreamingClient {
180210
},
181211
}
182212
},
213+
inner::Inner::QDeveloper(client) => {
214+
let conversation_state = amzn_qdeveloper_streaming_client::types::ConversationState::builder()
215+
.set_conversation_id(conversation_id)
216+
.current_message(amzn_qdeveloper_streaming_client::types::ChatMessage::UserInputMessage(
217+
user_input_message.into(),
218+
))
219+
.chat_trigger_type(amzn_qdeveloper_streaming_client::types::ChatTriggerType::Manual)
220+
.set_history(
221+
history
222+
.map(|v| v.into_iter().map(|i| i.try_into()).collect::<Result<Vec<_>, _>>())
223+
.transpose()?,
224+
)
225+
.build()
226+
.expect("building conversation_state should not fail");
227+
228+
let response = client
229+
.send_message()
230+
.conversation_state(conversation_state)
231+
.set_source(Some(Origin::from("CLI")))
232+
.send()
233+
.await;
234+
235+
match response {
236+
Ok(resp) => Ok(SendMessageOutput::QDeveloper(resp)),
237+
Err(e) => {
238+
let status_code = e.raw_response().map(|res| res.status().as_u16());
239+
let is_quota_breach = e.raw_response().is_some_and(|resp| resp.status().as_u16() == 429);
240+
let is_context_window_overflow = e.as_service_error().is_some_and(|err| {
241+
matches!(err, err if err.meta().code() == Some("ValidationException")
242+
&& err.meta().message() == Some("Input is too long."))
243+
});
244+
245+
if is_quota_breach {
246+
Err(ApiClientError::QuotaBreach {
247+
message: "quota has reached its limit",
248+
status_code,
249+
})
250+
} else if is_context_window_overflow {
251+
Err(ApiClientError::ContextWindowOverflow { status_code })
252+
} else {
253+
Err(e.into())
254+
}
255+
},
256+
}
257+
},
183258
inner::Inner::Mock(events) => {
184259
let mut new_events = events.lock().unwrap().next().unwrap_or_default().clone();
185260
new_events.reverse();
@@ -194,13 +269,15 @@ pub enum SendMessageOutput {
194269
Codewhisperer(
195270
amzn_codewhisperer_streaming_client::operation::generate_assistant_response::GenerateAssistantResponseOutput,
196271
),
272+
QDeveloper(amzn_qdeveloper_streaming_client::operation::send_message::SendMessageOutput),
197273
Mock(Vec<ChatResponseStream>),
198274
}
199275

200276
impl SendMessageOutput {
201277
pub fn request_id(&self) -> Option<&str> {
202278
match self {
203279
SendMessageOutput::Codewhisperer(output) => output.request_id(),
280+
SendMessageOutput::QDeveloper(output) => output.request_id(),
204281
SendMessageOutput::Mock(_) => None,
205282
}
206283
}
@@ -212,6 +289,7 @@ impl SendMessageOutput {
212289
.recv()
213290
.await?
214291
.map(|s| s.into())),
292+
SendMessageOutput::QDeveloper(output) => Ok(output.send_message_response.recv().await?.map(|s| s.into())),
215293
SendMessageOutput::Mock(vec) => Ok(vec.pop()),
216294
}
217295
}
@@ -221,6 +299,7 @@ impl RequestId for SendMessageOutput {
221299
fn request_id(&self) -> Option<&str> {
222300
match self {
223301
SendMessageOutput::Codewhisperer(output) => output.request_id(),
302+
SendMessageOutput::QDeveloper(output) => output.request_id(),
224303
SendMessageOutput::Mock(_) => Some("<mock-request-id>"),
225304
}
226305
}
@@ -242,6 +321,7 @@ mod tests {
242321

243322
let _ = StreamingClient::new(&mut database).await;
244323
let _ = StreamingClient::new_codewhisperer_client(&mut database, &endpoint).await;
324+
let _ = StreamingClient::new_qdeveloper_client(&database, &endpoint).await;
245325
}
246326

247327
#[tokio::test]

crates/chat-cli/src/api_client/consts.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ use aws_config::Region;
33
// Endpoint constants
44
pub const PROD_CODEWHISPERER_ENDPOINT_URL: &str = "https://codewhisperer.us-east-1.amazonaws.com";
55
pub const PROD_CODEWHISPERER_ENDPOINT_REGION: Region = Region::from_static("us-east-1");
6+
pub const PROD_Q_ENDPOINT_URL: &str = "https://q.us-east-1.amazonaws.com";
7+
pub const PROD_Q_ENDPOINT_REGION: Region = Region::from_static("us-east-1");
68

79
// FRA endpoint constants
810
pub const PROD_CODEWHISPERER_FRA_ENDPOINT_URL: &str = "https://q.eu-central-1.amazonaws.com/";
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
use aws_config::default_provider::region::DefaultRegionChain;
2+
use aws_config::ecs::EcsCredentialsProvider;
3+
use aws_config::environment::credentials::EnvironmentVariableCredentialsProvider;
4+
use aws_config::imds::credentials::ImdsCredentialsProvider;
5+
use aws_config::meta::credentials::CredentialsProviderChain;
6+
use aws_config::profile::ProfileFileCredentialsProvider;
7+
use aws_config::provider_config::ProviderConfig;
8+
use aws_config::web_identity_token::WebIdentityTokenCredentialsProvider;
9+
use aws_credential_types::Credentials;
10+
use aws_credential_types::provider::{
11+
self,
12+
ProvideCredentials,
13+
future,
14+
};
15+
use tracing::Instrument;
16+
17+
#[derive(Debug)]
18+
pub struct CredentialsChain {
19+
provider_chain: CredentialsProviderChain,
20+
}
21+
22+
impl CredentialsChain {
23+
/// Based on code the code for
24+
/// [aws_config::default_provider::credentials::DefaultCredentialsChain]
25+
pub async fn new() -> Self {
26+
let region = DefaultRegionChain::builder().build().region().await;
27+
let config = ProviderConfig::default().with_region(region.clone());
28+
29+
let env_provider = EnvironmentVariableCredentialsProvider::new();
30+
let profile_provider = ProfileFileCredentialsProvider::builder().configure(&config).build();
31+
let web_identity_token_provider = WebIdentityTokenCredentialsProvider::builder()
32+
.configure(&config)
33+
.build();
34+
let imds_provider = ImdsCredentialsProvider::builder().configure(&config).build();
35+
let ecs_provider = EcsCredentialsProvider::builder().configure(&config).build();
36+
37+
let mut provider_chain = CredentialsProviderChain::first_try("Environment", env_provider);
38+
39+
provider_chain = provider_chain
40+
.or_else("Profile", profile_provider)
41+
.or_else("WebIdentityToken", web_identity_token_provider)
42+
.or_else("EcsContainer", ecs_provider)
43+
.or_else("Ec2InstanceMetadata", imds_provider);
44+
45+
CredentialsChain { provider_chain }
46+
}
47+
48+
async fn credentials(&self) -> provider::Result {
49+
self.provider_chain
50+
.provide_credentials()
51+
.instrument(tracing::debug_span!("provide_credentials", provider = %"default_chain"))
52+
.await
53+
}
54+
}
55+
56+
impl ProvideCredentials for CredentialsChain {
57+
fn provide_credentials<'a>(&'a self) -> future::ProvideCredentials<'a>
58+
where
59+
Self: 'a,
60+
{
61+
future::ProvideCredentials::new(self.credentials())
62+
}
63+
64+
fn fallback_on_interrupt(&self) -> Option<Credentials> {
65+
self.provider_chain.fallback_on_interrupt()
66+
}
67+
}
68+
69+
#[cfg(test)]
70+
mod tests {
71+
use super::*;
72+
73+
#[tokio::test]
74+
async fn test_credentials_chain() {
75+
let credentials_chain = CredentialsChain::new().await;
76+
let credentials_res = credentials_chain.provide_credentials().await;
77+
let fallback_on_interrupt_res = credentials_chain.fallback_on_interrupt();
78+
println!("credentials_res: {credentials_res:?}, fallback_on_interrupt_res: {fallback_on_interrupt_res:?}");
79+
}
80+
}

crates/chat-cli/src/api_client/endpoints.rs

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ use crate::api_client::consts::{
99
PROD_CODEWHISPERER_ENDPOINT_URL,
1010
PROD_CODEWHISPERER_FRA_ENDPOINT_REGION,
1111
PROD_CODEWHISPERER_FRA_ENDPOINT_URL,
12+
PROD_Q_ENDPOINT_REGION,
13+
PROD_Q_ENDPOINT_URL,
1214
};
1315
use crate::database::Database;
1416
use crate::database::settings::Setting;
@@ -28,6 +30,10 @@ impl Endpoint {
2830
url: Cow::Borrowed(PROD_CODEWHISPERER_ENDPOINT_URL),
2931
region: PROD_CODEWHISPERER_ENDPOINT_REGION,
3032
};
33+
pub const Q_ENDPOINT: Self = Self {
34+
url: Cow::Borrowed(PROD_Q_ENDPOINT_URL),
35+
region: PROD_Q_ENDPOINT_REGION,
36+
};
3137

3238
pub fn load_codewhisperer(database: &Database) -> Self {
3339
let (endpoint, region) = if let Some(Value::Object(o)) = database.settings.get(Setting::ApiCodeWhispererService)
@@ -63,6 +69,24 @@ impl Endpoint {
6369
}
6470
}
6571

72+
pub fn load_q(database: &Database) -> Self {
73+
match database.settings.get(Setting::ApiQService) {
74+
Some(Value::Object(o)) => {
75+
let endpoint = o.get("endpoint").and_then(|v| v.as_str());
76+
let region = o.get("region").and_then(|v| v.as_str());
77+
78+
match (endpoint, region) {
79+
(Some(endpoint), Some(region)) => Self {
80+
url: endpoint.to_owned().into(),
81+
region: Region::new(region.to_owned()),
82+
},
83+
_ => Endpoint::Q_ENDPOINT,
84+
}
85+
},
86+
_ => Endpoint::Q_ENDPOINT,
87+
}
88+
}
89+
6690
pub(crate) fn url(&self) -> &str {
6791
&self.url
6892
}
@@ -82,6 +106,7 @@ mod tests {
82106
async fn test_endpoints() {
83107
let database = Database::new().await.unwrap();
84108
let _ = Endpoint::load_codewhisperer(&database);
109+
let _ = Endpoint::load_q(&database);
85110

86111
let prod = &Endpoint::DEFAULT_ENDPOINT;
87112
Url::parse(prod.url()).unwrap();

crates/chat-cli/src/api_client/error.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ use amzn_consolas_client::operation::generate_recommendations::GenerateRecommend
99
use amzn_consolas_client::operation::list_customizations::ListCustomizationsError;
1010
use amzn_qdeveloper_streaming_client::operation::send_message::SendMessageError as QDeveloperSendMessageError;
1111
use amzn_qdeveloper_streaming_client::types::error::ChatResponseStreamError as QDeveloperChatResponseStreamError;
12+
use aws_credential_types::provider::error::CredentialsError;
1213
use aws_sdk_ssooidc::error::ProvideErrorMetadata;
1314
use aws_smithy_runtime_api::client::orchestrator::HttpResponse;
1415
pub use aws_smithy_runtime_api::client::result::SdkError;
@@ -88,6 +89,10 @@ pub enum ApiClientError {
8889
request_id: Option<String>,
8990
status_code: Option<u16>,
9091
},
92+
93+
// Credential errors
94+
#[error("failed to load credentials: {}", .0)]
95+
Credentials(CredentialsError),
9196
}
9297

9398
impl ApiClientError {
@@ -110,6 +115,7 @@ impl ApiClientError {
110115
ApiClientError::AuthError(_) => None,
111116
ApiClientError::ModelOverloadedError { status_code, .. } => *status_code,
112117
ApiClientError::MonthlyLimitReached { status_code } => *status_code,
118+
ApiClientError::Credentials(_e) => None,
113119
}
114120
}
115121
}
@@ -134,6 +140,7 @@ impl ReasonCode for ApiClientError {
134140
ApiClientError::AuthError(_) => "AuthError".to_string(),
135141
ApiClientError::ModelOverloadedError { .. } => "ModelOverloadedError".to_string(),
136142
ApiClientError::MonthlyLimitReached { .. } => "MonthlyLimitReached".to_string(),
143+
ApiClientError::Credentials(_) => "CredentialsError".to_string(),
137144
}
138145
}
139146
}
@@ -171,6 +178,7 @@ mod tests {
171178

172179
fn all_errors() -> Vec<ApiClientError> {
173180
vec![
181+
ApiClientError::Credentials(CredentialsError::unhandled("<unhandled>")),
174182
ApiClientError::GenerateCompletions(SdkError::service_error(
175183
GenerateCompletionsError::unhandled("<unhandled>"),
176184
response(),

crates/chat-cli/src/api_client/mod.rs

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,14 @@
11
pub mod clients;
22
pub(crate) mod consts;
3+
pub(crate) mod credentials;
34
pub mod customization;
45
mod endpoints;
56
mod error;
67
pub(crate) mod interceptor;
78
pub mod model;
89
pub mod profile;
910

10-
pub use clients::{
11-
Client,
12-
StreamingClient,
13-
};
11+
pub use clients::Client;
1412
pub use endpoints::Endpoint;
1513
pub use error::ApiClientError;
1614
pub use profile::list_available_profiles;

0 commit comments

Comments
 (0)