@@ -52,49 +52,19 @@ use crate::auth::AuthError;
5252use crate :: auth:: consts:: * ;
5353use crate :: auth:: scope:: is_scopes;
5454use crate :: aws_common:: app_name;
55- use crate :: database:: Database ;
56- use crate :: database :: secret_store :: {
55+ use crate :: database:: {
56+ Database ,
5757 Secret ,
58- SecretStore ,
5958} ;
6059
61- #[ derive( Debug , Copy , Clone , PartialEq , Eq , serde:: Deserialize ) ]
60+ #[ derive( Debug , Copy , Clone , PartialEq , Eq , serde:: Serialize , serde :: Deserialize ) ]
6261pub enum OAuthFlow {
6362 DeviceCode ,
6463 // This must remain backwards compatible
6564 #[ serde( alias = "PKCE" ) ]
6665 Pkce ,
6766}
6867
69- // Implement Serialize manually to ensure proper serialization
70- impl serde:: Serialize for OAuthFlow {
71- fn serialize < S > ( & self , serializer : S ) -> Result < S :: Ok , S :: Error >
72- where
73- S : serde:: Serializer ,
74- {
75- match * self {
76- OAuthFlow :: DeviceCode => serializer. serialize_str ( "DeviceCode" ) ,
77- OAuthFlow :: Pkce => serialize_pkce ( serializer) ,
78- }
79- }
80- }
81-
82- fn serialize_pkce < S > ( serializer : S ) -> Result < S :: Ok , S :: Error >
83- where
84- S : serde:: Serializer ,
85- {
86- serializer. serialize_str ( "PKCE" )
87- }
88-
89- impl std:: fmt:: Display for OAuthFlow {
90- fn fmt ( & self , f : & mut std:: fmt:: Formatter < ' _ > ) -> std:: fmt:: Result {
91- match * self {
92- OAuthFlow :: DeviceCode => write ! ( f, "DeviceCode" ) ,
93- OAuthFlow :: Pkce => write ! ( f, "PKCE" ) ,
94- }
95- }
96- }
97-
9868/// Indicates if an expiration time has passed, there is a small 1 min window that is removed
9969/// so the token will not expire in transit
10070fn is_expired ( expiration_time : & OffsetDateTime ) -> bool {
@@ -152,8 +122,8 @@ impl DeviceRegistration {
152122 }
153123
154124 /// Loads the OIDC registered client from the secret store, deleting it if it is expired.
155- async fn load_from_secret_store ( secret_store : & SecretStore , region : & Region ) -> Result < Option < Self > , AuthError > {
156- let device_registration = secret_store . get ( Self :: SECRET_KEY ) . await ?;
125+ async fn load_from_secret_store ( database : & Database , region : & Region ) -> Result < Option < Self > , AuthError > {
126+ let device_registration = database . get_secret ( Self :: SECRET_KEY ) . await ?;
157127
158128 if let Some ( device_registration) = device_registration {
159129 // check that the data is not expired, assume it is invalid if not present
@@ -167,7 +137,7 @@ impl DeviceRegistration {
167137 }
168138
169139 // delete the data if its expired or invalid
170- if let Err ( err) = secret_store . delete ( Self :: SECRET_KEY ) . await {
140+ if let Err ( err) = database . delete_secret ( Self :: SECRET_KEY ) . await {
171141 error ! ( ?err, "Failed to delete device registration from keychain" ) ;
172142 }
173143
@@ -181,7 +151,7 @@ impl DeviceRegistration {
181151 client : & Client ,
182152 region : & Region ,
183153 ) -> Result < Self , AuthError > {
184- match Self :: load_from_secret_store ( & database. secret_store , region) . await {
154+ match Self :: load_from_secret_store ( database, region) . await {
185155 Ok ( Some ( registration) ) if registration. oauth_flow == OAuthFlow :: DeviceCode => match & registration. scopes {
186156 Some ( scopes) if is_scopes ( scopes) => return Ok ( registration) ,
187157 _ => warn ! ( "Invalid scopes in device registration, ignoring" ) ,
@@ -210,17 +180,17 @@ impl DeviceRegistration {
210180 SCOPES . iter ( ) . map ( |s| ( * s) . to_owned ( ) ) . collect ( ) ,
211181 ) ;
212182
213- if let Err ( err) = device_registration. save ( & database. secret_store ) . await {
183+ if let Err ( err) = device_registration. save ( database) . await {
214184 error ! ( ?err, "Failed to write device registration to keychain" ) ;
215185 }
216186
217187 Ok ( device_registration)
218188 }
219189
220190 /// Saves to the passed secret store.
221- pub async fn save ( & self , secret_store : & SecretStore ) -> Result < ( ) , AuthError > {
191+ pub async fn save ( & self , secret_store : & Database ) -> Result < ( ) , AuthError > {
222192 secret_store
223- . set ( Self :: SECRET_KEY , & serde_json:: to_string ( & self ) ?)
193+ . set_secret ( Self :: SECRET_KEY , & serde_json:: to_string ( & self ) ?)
224194 . await ?;
225195 Ok ( ( ) )
226196 }
@@ -314,8 +284,8 @@ impl BuilderIdToken {
314284 }
315285
316286 /// Load the token from the keychain, refresh the token if it is expired and return it
317- pub async fn load ( database : & mut Database ) -> Result < Option < Self > , AuthError > {
318- match database. secret_store . get ( Self :: SECRET_KEY ) . await {
287+ pub async fn load ( database : & Database ) -> Result < Option < Self > , AuthError > {
288+ match database. get_secret ( Self :: SECRET_KEY ) . await {
319289 Ok ( Some ( secret) ) => {
320290 let token: Option < Self > = serde_json:: from_str ( & secret. 0 ) ?;
321291 match token {
@@ -325,7 +295,7 @@ impl BuilderIdToken {
325295 let client = client ( region. clone ( ) ) ;
326296 // if token is expired try to refresh
327297 if token. is_expired ( ) {
328- token. refresh_token ( & client, & database. secret_store , & region) . await
298+ token. refresh_token ( & client, database, & region) . await
329299 } else {
330300 Ok ( Some ( token) )
331301 }
@@ -345,19 +315,19 @@ impl BuilderIdToken {
345315 pub async fn refresh_token (
346316 & self ,
347317 client : & Client ,
348- secret_store : & SecretStore ,
318+ database : & Database ,
349319 region : & Region ,
350320 ) -> Result < Option < Self > , AuthError > {
351321 let Some ( refresh_token) = & self . refresh_token else {
352322 // if the token is expired and has no refresh token, delete it
353- if let Err ( err) = self . delete ( secret_store ) . await {
323+ if let Err ( err) = self . delete ( database ) . await {
354324 error ! ( ?err, "Failed to delete builder id token" ) ;
355325 }
356326
357327 return Ok ( None ) ;
358328 } ;
359329
360- let registration = match DeviceRegistration :: load_from_secret_store ( secret_store , region) . await ? {
330+ let registration = match DeviceRegistration :: load_from_secret_store ( database , region) . await ? {
361331 Some ( registration) if registration. oauth_flow == self . oauth_flow => registration,
362332 // If the OIDC client registration is for a different oauth flow or doesn't exist, then
363333 // we can't refresh the token.
@@ -394,7 +364,7 @@ impl BuilderIdToken {
394364 ) ;
395365 debug ! ( "Refreshed access token, new token: {:?}" , token) ;
396366
397- if let Err ( err) = token. save ( secret_store ) . await {
367+ if let Err ( err) = token. save ( database ) . await {
398368 error ! ( ?err, "Failed to store builder id access token" ) ;
399369 } ;
400370
@@ -407,7 +377,7 @@ impl BuilderIdToken {
407377 // if the error is the client's fault, clear the token
408378 if let SdkError :: ServiceError ( service_err) = & err {
409379 if !service_err. err ( ) . is_slow_down_exception ( ) {
410- if let Err ( err) = self . delete ( secret_store ) . await {
380+ if let Err ( err) = self . delete ( database ) . await {
411381 error ! ( ?err, "Failed to delete builder id token" ) ;
412382 }
413383 }
@@ -427,16 +397,16 @@ impl BuilderIdToken {
427397 }
428398
429399 /// Save the token to the keychain
430- pub async fn save ( & self , secret_store : & SecretStore ) -> Result < ( ) , AuthError > {
431- secret_store
432- . set ( Self :: SECRET_KEY , & serde_json:: to_string ( self ) ?)
400+ pub async fn save ( & self , database : & Database ) -> Result < ( ) , AuthError > {
401+ database
402+ . set_secret ( Self :: SECRET_KEY , & serde_json:: to_string ( self ) ?)
433403 . await ?;
434404 Ok ( ( ) )
435405 }
436406
437407 /// Delete the token from the keychain
438- pub async fn delete ( & self , secret_store : & SecretStore ) -> Result < ( ) , AuthError > {
439- secret_store . delete ( Self :: SECRET_KEY ) . await ?;
408+ pub async fn delete ( & self , database : & Database ) -> Result < ( ) , AuthError > {
409+ database . delete_secret ( Self :: SECRET_KEY ) . await ?;
440410 Ok ( ( ) )
441411 }
442412
@@ -508,7 +478,7 @@ pub async fn poll_create_token(
508478 let token: BuilderIdToken =
509479 BuilderIdToken :: from_output ( output, region, start_url, OAuthFlow :: DeviceCode , scopes) ;
510480
511- if let Err ( err) = token. save ( & database. secret_store ) . await {
481+ if let Err ( err) = token. save ( database) . await {
512482 error ! ( ?err, "Failed to store builder id token" ) ;
513483 } ;
514484
@@ -529,13 +499,13 @@ pub async fn is_logged_in(database: &mut Database) -> bool {
529499}
530500
531501pub async fn logout ( database : & mut Database ) -> Result < ( ) , AuthError > {
532- let Ok ( secret_store) = SecretStore :: new ( ) . await else {
502+ let Ok ( secret_store) = Database :: new ( ) . await else {
533503 return Ok ( ( ) ) ;
534504 } ;
535505
536506 let ( builder_res, device_res) = tokio:: join!(
537- secret_store. delete ( BuilderIdToken :: SECRET_KEY ) ,
538- secret_store. delete ( DeviceRegistration :: SECRET_KEY ) ,
507+ secret_store. delete_secret ( BuilderIdToken :: SECRET_KEY ) ,
508+ secret_store. delete_secret ( DeviceRegistration :: SECRET_KEY ) ,
539509 ) ;
540510
541511 let profile_res = database. unset_auth_profile ( ) ;
@@ -585,20 +555,10 @@ mod tests {
585555 const US_EAST_1 : Region = Region :: from_static ( "us-east-1" ) ;
586556 const US_WEST_2 : Region = Region :: from_static ( "us-west-2" ) ;
587557
588- macro_rules! test_ser_deser {
589- ( $ty: ident, $variant: expr, $text: expr) => {
590- let quoted = format!( "\" {}\" " , $text) ;
591- assert_eq!( quoted, serde_json:: to_string( & $variant) . unwrap( ) ) ;
592- assert_eq!( $variant, serde_json:: from_str( & quoted) . unwrap( ) ) ;
593-
594- assert_eq!( $text, format!( "{}" , $variant) ) ;
595- } ;
596- }
597-
598558 #[ test]
599- fn test_oauth_flow_ser_deser ( ) {
600- test_ser_deser ! ( OAuthFlow , OAuthFlow :: DeviceCode , "DeviceCode" ) ;
601- test_ser_deser ! ( OAuthFlow , OAuthFlow :: Pkce , "PKCE" ) ;
559+ fn test_oauth_flow_deser ( ) {
560+ assert_eq ! ( OAuthFlow :: Pkce , serde_json :: from_str ( " \" PKCE \" " ) . unwrap ( ) ) ;
561+ assert_eq ! ( OAuthFlow :: Pkce , serde_json :: from_str ( " \" Pkce \" " ) . unwrap ( ) ) ;
602562 }
603563
604564 #[ tokio:: test]
0 commit comments