Skip to content

Commit 81fb827

Browse files
committed
Merge branch 'main' into background-server-load
2 parents 9d3d968 + aa146bf commit 81fb827

File tree

9 files changed

+107
-344
lines changed

9 files changed

+107
-344
lines changed

crates/chat-cli/src/auth/builder_id.rs

Lines changed: 30 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -52,49 +52,19 @@ use crate::auth::AuthError;
5252
use crate::auth::consts::*;
5353
use crate::auth::scope::is_scopes;
5454
use crate::aws_common::app_name;
55-
use crate::database::Database;
56-
use crate::database::secret_store::{
55+
use crate::database::{
56+
Database,
5757
Secret,
58-
SecretStore,
5958
};
6059

61-
#[derive(Debug, Copy, Clone, PartialEq, Eq, serde::Deserialize)]
60+
#[derive(Debug, Copy, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
6261
pub enum OAuthFlow {
6362
DeviceCode,
6463
// This must remain backwards compatible
6564
#[serde(alias = "PKCE")]
6665
Pkce,
6766
}
6867

69-
// Implement Serialize manually to ensure proper serialization
70-
impl serde::Serialize for OAuthFlow {
71-
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
72-
where
73-
S: serde::Serializer,
74-
{
75-
match *self {
76-
OAuthFlow::DeviceCode => serializer.serialize_str("DeviceCode"),
77-
OAuthFlow::Pkce => serialize_pkce(serializer),
78-
}
79-
}
80-
}
81-
82-
fn serialize_pkce<S>(serializer: S) -> Result<S::Ok, S::Error>
83-
where
84-
S: serde::Serializer,
85-
{
86-
serializer.serialize_str("PKCE")
87-
}
88-
89-
impl std::fmt::Display for OAuthFlow {
90-
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
91-
match *self {
92-
OAuthFlow::DeviceCode => write!(f, "DeviceCode"),
93-
OAuthFlow::Pkce => write!(f, "PKCE"),
94-
}
95-
}
96-
}
97-
9868
/// Indicates if an expiration time has passed, there is a small 1 min window that is removed
9969
/// so the token will not expire in transit
10070
fn is_expired(expiration_time: &OffsetDateTime) -> bool {
@@ -152,8 +122,8 @@ impl DeviceRegistration {
152122
}
153123

154124
/// Loads the OIDC registered client from the secret store, deleting it if it is expired.
155-
async fn load_from_secret_store(secret_store: &SecretStore, region: &Region) -> Result<Option<Self>, AuthError> {
156-
let device_registration = secret_store.get(Self::SECRET_KEY).await?;
125+
async fn load_from_secret_store(database: &Database, region: &Region) -> Result<Option<Self>, AuthError> {
126+
let device_registration = database.get_secret(Self::SECRET_KEY).await?;
157127

158128
if let Some(device_registration) = device_registration {
159129
// check that the data is not expired, assume it is invalid if not present
@@ -167,7 +137,7 @@ impl DeviceRegistration {
167137
}
168138

169139
// delete the data if its expired or invalid
170-
if let Err(err) = secret_store.delete(Self::SECRET_KEY).await {
140+
if let Err(err) = database.delete_secret(Self::SECRET_KEY).await {
171141
error!(?err, "Failed to delete device registration from keychain");
172142
}
173143

@@ -181,7 +151,7 @@ impl DeviceRegistration {
181151
client: &Client,
182152
region: &Region,
183153
) -> Result<Self, AuthError> {
184-
match Self::load_from_secret_store(&database.secret_store, region).await {
154+
match Self::load_from_secret_store(database, region).await {
185155
Ok(Some(registration)) if registration.oauth_flow == OAuthFlow::DeviceCode => match &registration.scopes {
186156
Some(scopes) if is_scopes(scopes) => return Ok(registration),
187157
_ => warn!("Invalid scopes in device registration, ignoring"),
@@ -210,17 +180,17 @@ impl DeviceRegistration {
210180
SCOPES.iter().map(|s| (*s).to_owned()).collect(),
211181
);
212182

213-
if let Err(err) = device_registration.save(&database.secret_store).await {
183+
if let Err(err) = device_registration.save(database).await {
214184
error!(?err, "Failed to write device registration to keychain");
215185
}
216186

217187
Ok(device_registration)
218188
}
219189

220190
/// Saves to the passed secret store.
221-
pub async fn save(&self, secret_store: &SecretStore) -> Result<(), AuthError> {
191+
pub async fn save(&self, secret_store: &Database) -> Result<(), AuthError> {
222192
secret_store
223-
.set(Self::SECRET_KEY, &serde_json::to_string(&self)?)
193+
.set_secret(Self::SECRET_KEY, &serde_json::to_string(&self)?)
224194
.await?;
225195
Ok(())
226196
}
@@ -314,8 +284,8 @@ impl BuilderIdToken {
314284
}
315285

316286
/// Load the token from the keychain, refresh the token if it is expired and return it
317-
pub async fn load(database: &mut Database) -> Result<Option<Self>, AuthError> {
318-
match database.secret_store.get(Self::SECRET_KEY).await {
287+
pub async fn load(database: &Database) -> Result<Option<Self>, AuthError> {
288+
match database.get_secret(Self::SECRET_KEY).await {
319289
Ok(Some(secret)) => {
320290
let token: Option<Self> = serde_json::from_str(&secret.0)?;
321291
match token {
@@ -325,7 +295,7 @@ impl BuilderIdToken {
325295
let client = client(region.clone());
326296
// if token is expired try to refresh
327297
if token.is_expired() {
328-
token.refresh_token(&client, &database.secret_store, &region).await
298+
token.refresh_token(&client, database, &region).await
329299
} else {
330300
Ok(Some(token))
331301
}
@@ -345,19 +315,19 @@ impl BuilderIdToken {
345315
pub async fn refresh_token(
346316
&self,
347317
client: &Client,
348-
secret_store: &SecretStore,
318+
database: &Database,
349319
region: &Region,
350320
) -> Result<Option<Self>, AuthError> {
351321
let Some(refresh_token) = &self.refresh_token else {
352322
// if the token is expired and has no refresh token, delete it
353-
if let Err(err) = self.delete(secret_store).await {
323+
if let Err(err) = self.delete(database).await {
354324
error!(?err, "Failed to delete builder id token");
355325
}
356326

357327
return Ok(None);
358328
};
359329

360-
let registration = match DeviceRegistration::load_from_secret_store(secret_store, region).await? {
330+
let registration = match DeviceRegistration::load_from_secret_store(database, region).await? {
361331
Some(registration) if registration.oauth_flow == self.oauth_flow => registration,
362332
// If the OIDC client registration is for a different oauth flow or doesn't exist, then
363333
// we can't refresh the token.
@@ -394,7 +364,7 @@ impl BuilderIdToken {
394364
);
395365
debug!("Refreshed access token, new token: {:?}", token);
396366

397-
if let Err(err) = token.save(secret_store).await {
367+
if let Err(err) = token.save(database).await {
398368
error!(?err, "Failed to store builder id access token");
399369
};
400370

@@ -407,7 +377,7 @@ impl BuilderIdToken {
407377
// if the error is the client's fault, clear the token
408378
if let SdkError::ServiceError(service_err) = &err {
409379
if !service_err.err().is_slow_down_exception() {
410-
if let Err(err) = self.delete(secret_store).await {
380+
if let Err(err) = self.delete(database).await {
411381
error!(?err, "Failed to delete builder id token");
412382
}
413383
}
@@ -427,16 +397,16 @@ impl BuilderIdToken {
427397
}
428398

429399
/// Save the token to the keychain
430-
pub async fn save(&self, secret_store: &SecretStore) -> Result<(), AuthError> {
431-
secret_store
432-
.set(Self::SECRET_KEY, &serde_json::to_string(self)?)
400+
pub async fn save(&self, database: &Database) -> Result<(), AuthError> {
401+
database
402+
.set_secret(Self::SECRET_KEY, &serde_json::to_string(self)?)
433403
.await?;
434404
Ok(())
435405
}
436406

437407
/// Delete the token from the keychain
438-
pub async fn delete(&self, secret_store: &SecretStore) -> Result<(), AuthError> {
439-
secret_store.delete(Self::SECRET_KEY).await?;
408+
pub async fn delete(&self, database: &Database) -> Result<(), AuthError> {
409+
database.delete_secret(Self::SECRET_KEY).await?;
440410
Ok(())
441411
}
442412

@@ -508,7 +478,7 @@ pub async fn poll_create_token(
508478
let token: BuilderIdToken =
509479
BuilderIdToken::from_output(output, region, start_url, OAuthFlow::DeviceCode, scopes);
510480

511-
if let Err(err) = token.save(&database.secret_store).await {
481+
if let Err(err) = token.save(database).await {
512482
error!(?err, "Failed to store builder id token");
513483
};
514484

@@ -529,13 +499,13 @@ pub async fn is_logged_in(database: &mut Database) -> bool {
529499
}
530500

531501
pub async fn logout(database: &mut Database) -> Result<(), AuthError> {
532-
let Ok(secret_store) = SecretStore::new().await else {
502+
let Ok(secret_store) = Database::new().await else {
533503
return Ok(());
534504
};
535505

536506
let (builder_res, device_res) = tokio::join!(
537-
secret_store.delete(BuilderIdToken::SECRET_KEY),
538-
secret_store.delete(DeviceRegistration::SECRET_KEY),
507+
secret_store.delete_secret(BuilderIdToken::SECRET_KEY),
508+
secret_store.delete_secret(DeviceRegistration::SECRET_KEY),
539509
);
540510

541511
let profile_res = database.unset_auth_profile();
@@ -585,20 +555,10 @@ mod tests {
585555
const US_EAST_1: Region = Region::from_static("us-east-1");
586556
const US_WEST_2: Region = Region::from_static("us-west-2");
587557

588-
macro_rules! test_ser_deser {
589-
($ty:ident, $variant:expr, $text:expr) => {
590-
let quoted = format!("\"{}\"", $text);
591-
assert_eq!(quoted, serde_json::to_string(&$variant).unwrap());
592-
assert_eq!($variant, serde_json::from_str(&quoted).unwrap());
593-
594-
assert_eq!($text, format!("{}", $variant));
595-
};
596-
}
597-
598558
#[test]
599-
fn test_oauth_flow_ser_deser() {
600-
test_ser_deser!(OAuthFlow, OAuthFlow::DeviceCode, "DeviceCode");
601-
test_ser_deser!(OAuthFlow, OAuthFlow::Pkce, "PKCE");
559+
fn test_oauth_flow_deser() {
560+
assert_eq!(OAuthFlow::Pkce, serde_json::from_str("\"PKCE\"").unwrap());
561+
assert_eq!(OAuthFlow::Pkce, serde_json::from_str("\"Pkce\"").unwrap());
602562
}
603563

604564
#[tokio::test]

crates/chat-cli/src/auth/pkce.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ impl PkceRegistration {
231231
/// then the access and refresh tokens will be saved.
232232
///
233233
/// Only the first connection will be served.
234-
pub async fn finish<C: PkceClient>(self, client: &C, database: Option<&Database>) -> Result<(), AuthError> {
234+
pub async fn finish<C: PkceClient>(self, client: &C, database: Option<&mut Database>) -> Result<(), AuthError> {
235235
let code = tokio::select! {
236236
code = Self::recv_code(self.listener, self.state) => {
237237
code?
@@ -270,11 +270,11 @@ impl PkceRegistration {
270270
);
271271

272272
if let Some(database) = database {
273-
if let Err(err) = device_registration.save(&database.secret_store).await {
273+
if let Err(err) = device_registration.save(database).await {
274274
error!(?err, "Failed to store pkce registration to secret store");
275275
}
276276

277-
if let Err(err) = token.save(&database.secret_store).await {
277+
if let Err(err) = token.save(database).await {
278278
error!(?err, "Failed to store builder id token");
279279
};
280280
}

crates/chat-cli/src/cli/debug.rs

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,4 @@
1-
use clap::{
2-
Subcommand,
3-
ValueEnum,
4-
};
1+
use clap::ValueEnum;
52

63
#[derive(Debug, ValueEnum, Clone, PartialEq, Eq)]
74
pub enum Build {
@@ -65,7 +62,7 @@ pub enum TISAction {
6562
use std::path::PathBuf;
6663

6764
#[cfg(target_os = "macos")]
68-
#[derive(Debug, Subcommand, Clone, PartialEq, Eq)]
65+
#[derive(Debug, clap::Subcommand, Clone, PartialEq, Eq)]
6966
pub enum InputMethodDebugAction {
7067
Install {
7168
bundle_path: Option<PathBuf>,

0 commit comments

Comments
 (0)