Skip to content

Commit 7d3d435

Browse files
authored
feat: q builder id pro subscription flow (#143)
1 parent d09f482 commit 7d3d435

File tree

8 files changed

+433
-2
lines changed

8 files changed

+433
-2
lines changed

crates/chat-cli/src/api_client/clients/client.rs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
use amzn_codewhisperer_client::Client as CodewhispererClient;
2+
use amzn_codewhisperer_client::operation::create_subscription_token::CreateSubscriptionTokenOutput;
23
use amzn_codewhisperer_client::types::{
34
OptOutPreference,
5+
SubscriptionStatus,
46
TelemetryEvent,
57
UserContext,
68
};
@@ -123,6 +125,21 @@ impl Client {
123125
]),
124126
}
125127
}
128+
129+
pub async fn create_subscription_token(&self) -> Result<CreateSubscriptionTokenOutput, ApiClientError> {
130+
match &self.inner {
131+
inner::Inner::Codewhisperer(client) => client
132+
.create_subscription_token()
133+
.send()
134+
.await
135+
.map_err(ApiClientError::CreateSubscriptionToken),
136+
inner::Inner::Mock => Ok(CreateSubscriptionTokenOutput::builder()
137+
.set_encoded_verification_url(Some("test/url".to_string()))
138+
.set_status(Some(SubscriptionStatus::Inactive))
139+
.set_token(Some("test-token".to_string()))
140+
.build()?),
141+
}
142+
}
126143
}
127144

128145
#[cfg(test)]

crates/chat-cli/src/api_client/clients/streaming_client.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,15 @@ impl StreamingClient {
147147
err.meta().message()
148148
== Some("Encountered unexpectedly high load when processing the request, please try again.")
149149
});
150+
let is_monthly_limit_err = e
151+
.raw_response()
152+
.and_then(|resp| resp.body().bytes())
153+
.and_then(|bytes| match String::from_utf8(bytes.to_vec()) {
154+
Ok(s) => Some(s.contains("MONTHLY_REQUEST_COUNT")),
155+
Err(_) => None,
156+
})
157+
.unwrap_or(false);
158+
150159
if is_quota_breach {
151160
Err(ApiClientError::QuotaBreach("quota has reached its limit"))
152161
} else if is_context_window_overflow {
@@ -157,6 +166,8 @@ impl StreamingClient {
157166
.and_then(|err| err.meta().request_id())
158167
.map(|s| s.to_string());
159168
Err(ApiClientError::ModelOverloadedError { request_id })
169+
} else if is_monthly_limit_err {
170+
Err(ApiClientError::MonthlyLimitReached)
160171
} else {
161172
Err(e.into())
162173
}

crates/chat-cli/src/api_client/error.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use amzn_codewhisperer_client::operation::create_subscription_token::CreateSubscriptionTokenError;
12
use amzn_codewhisperer_client::operation::generate_completions::GenerateCompletionsError;
23
use amzn_codewhisperer_client::operation::list_available_customizations::ListAvailableCustomizationsError;
34
use amzn_codewhisperer_client::operation::list_available_profiles::ListAvailableProfilesError;
@@ -52,6 +53,13 @@ pub enum ApiClientError {
5253
#[error("quota has reached its limit")]
5354
QuotaBreach(&'static str),
5455

56+
// Separate from quota breach (somehow)
57+
#[error("monthly query limit reached")]
58+
MonthlyLimitReached,
59+
60+
#[error("{}", SdkErrorDisplay(.0))]
61+
CreateSubscriptionToken(#[from] SdkError<CreateSubscriptionTokenError, HttpResponse>),
62+
5563
/// Returned from the backend when the user input is too large to fit within the model context
5664
/// window.
5765
///
@@ -88,11 +96,13 @@ impl ReasonCode for ApiClientError {
8896
ApiClientError::QDeveloperChatResponseStream(e) => sdk_error_code(e),
8997
ApiClientError::ListAvailableProfilesError(e) => sdk_error_code(e),
9098
ApiClientError::SendTelemetryEvent(e) => sdk_error_code(e),
99+
ApiClientError::CreateSubscriptionToken(e) => sdk_error_code(e),
91100
ApiClientError::QuotaBreach(_) => "QuotaBreachError".to_string(),
92101
ApiClientError::ContextWindowOverflow => "ContextWindowOverflow".to_string(),
93102
ApiClientError::SmithyBuild(_) => "SmithyBuildError".to_string(),
94103
ApiClientError::AuthError(_) => "AuthError".to_string(),
95104
ApiClientError::ModelOverloadedError { .. } => "ModelOverloadedError".to_string(),
105+
ApiClientError::MonthlyLimitReached => "MonthlyLimitReached".to_string(),
96106
}
97107
}
98108
}
@@ -150,6 +160,10 @@ mod tests {
150160
QDeveloperSendMessageError::unhandled("<unhandled>"),
151161
response(),
152162
)),
163+
ApiClientError::CreateSubscriptionToken(SdkError::service_error(
164+
CreateSubscriptionTokenError::unhandled("<unhandled>"),
165+
response(),
166+
)),
153167
ApiClientError::CodewhispererChatResponseStream(SdkError::service_error(
154168
CodewhispererChatResponseStreamError::unhandled("<unhandled>"),
155169
raw_message(),

crates/chat-cli/src/auth/builder_id.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,10 @@ use aws_smithy_runtime_api::client::identity::{
4141
};
4242
use aws_smithy_types::error::display::DisplayErrorContext;
4343
use aws_types::region::Region;
44+
use eyre::{
45+
Result,
46+
eyre,
47+
};
4448
use time::OffsetDateTime;
4549
use tracing::{
4650
debug,
@@ -568,6 +572,17 @@ impl ResolveIdentity for BearerResolver {
568572
}
569573
}
570574

575+
pub async fn is_idc_user(database: &Database) -> Result<bool> {
576+
if cfg!(test) {
577+
return Ok(false);
578+
}
579+
if let Ok(Some(token)) = BuilderIdToken::load(database).await {
580+
Ok(token.token_type() == TokenType::IamIdentityCenter)
581+
} else {
582+
Err(eyre!("No auth token found - is the user signed in?"))
583+
}
584+
}
585+
571586
#[cfg(test)]
572587
mod tests {
573588
use super::*;

crates/chat-cli/src/cli/chat/command.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,9 @@ pub enum Command {
6060
},
6161
Mcp,
6262
Model,
63+
Subscribe {
64+
manage: bool,
65+
},
6366
}
6467

6568
#[derive(Debug, Clone, PartialEq, Eq)]
@@ -841,6 +844,10 @@ impl Command {
841844
},
842845
"mcp" => Self::Mcp,
843846
"model" => Self::Model,
847+
"subscribe" => {
848+
let manage = parts.contains(&"--manage");
849+
Self::Subscribe { manage }
850+
},
844851
unknown_command => {
845852
let looks_like_path = {
846853
let after_slash_command_str = parts[1..].join(" ");

0 commit comments

Comments
 (0)