@@ -45,10 +45,16 @@ use aws_types::request_id::RequestId;
4545use aws_types:: sdk_config:: StalledStreamProtectionConfig ;
4646use fig_aws_common:: app_name;
4747use fig_telemetry_core:: {
48+ AuthErrorType ,
4849 Event ,
4950 EventType ,
5051 TelemetryResult ,
5152} ;
53+ use fig_util:: auth:: {
54+ OAuthFlow ,
55+ START_URL ,
56+ TokenType ,
57+ } ;
5258use time:: OffsetDateTime ;
5359use tracing:: {
5460 debug,
@@ -69,22 +75,6 @@ use crate::{
6975 Result ,
7076} ;
7177
72- #[ derive( Debug , Copy , Clone , PartialEq , Eq , serde:: Serialize , serde:: Deserialize ) ]
73- pub enum OAuthFlow {
74- DeviceCode ,
75- #[ serde( alias = "Pkce" ) ]
76- PKCE ,
77- }
78-
79- impl std:: fmt:: Display for OAuthFlow {
80- fn fmt ( & self , f : & mut std:: fmt:: Formatter < ' _ > ) -> std:: fmt:: Result {
81- match * self {
82- OAuthFlow :: DeviceCode => write ! ( f, "DeviceCode" ) ,
83- OAuthFlow :: PKCE => write ! ( f, "PKCE" ) ,
84- }
85- }
86- }
87-
8878/// Indicates if an expiration time has passed, there is a small 1 min window that is removed
8979/// so the token will not expire in transit
9080fn is_expired ( expiration_time : & OffsetDateTime ) -> bool {
@@ -254,30 +244,42 @@ pub async fn start_device_authorization(
254244 ..
255245 } = DeviceRegistration :: init_device_code_registration ( & client, secret_store, & region) . await ?;
256246
247+ let start_url = start_url. as_deref ( ) . unwrap_or ( START_URL ) ;
248+
257249 let output = client
258250 . start_device_authorization ( )
259251 . client_id ( & client_id)
260252 . client_secret ( & client_secret. 0 )
261- . start_url ( start_url. as_deref ( ) . unwrap_or ( START_URL ) )
253+ . start_url ( start_url)
262254 . send ( )
263- . await ?;
264-
265- Ok ( StartDeviceAuthorizationResponse {
266- device_code : output. device_code . unwrap_or_default ( ) ,
267- user_code : output. user_code . unwrap_or_default ( ) ,
268- verification_uri : output. verification_uri . unwrap_or_default ( ) ,
269- verification_uri_complete : output. verification_uri_complete . unwrap_or_default ( ) ,
270- expires_in : output. expires_in ,
271- interval : output. interval ,
272- region : region. to_string ( ) ,
273- start_url : start_url. unwrap_or_else ( || START_URL . to_owned ( ) ) ,
274- } )
275- }
276-
277- #[ derive( Debug , Clone , PartialEq , Eq ) ]
278- pub enum TokenType {
279- BuilderId ,
280- IamIdentityCenter ,
255+ . await ;
256+
257+ match output {
258+ Ok ( output) => Ok ( StartDeviceAuthorizationResponse {
259+ device_code : output. device_code . unwrap_or_default ( ) ,
260+ user_code : output. user_code . unwrap_or_default ( ) ,
261+ verification_uri : output. verification_uri . unwrap_or_default ( ) ,
262+ verification_uri_complete : output. verification_uri_complete . unwrap_or_default ( ) ,
263+ expires_in : output. expires_in ,
264+ interval : output. interval ,
265+ region : region. to_string ( ) ,
266+ start_url : start_url. to_string ( ) ,
267+ } ) ,
268+ Err ( err) => {
269+ let err: Error = err. into ( ) ;
270+ fig_telemetry_core:: send_event (
271+ Event :: new ( EventType :: AuthFailed {
272+ auth_method : TokenType :: from_start_url ( Some ( start_url) ) ,
273+ oauth_flow : OAuthFlow :: DeviceCode ,
274+ error_type : AuthErrorType :: NewLogin ,
275+ error_code : err. service_error_code ( ) ,
276+ } )
277+ . with_credential_start_url ( start_url. to_string ( ) ) ,
278+ )
279+ . await ;
280+ Err ( err)
281+ } ,
282+ }
281283}
282284
283285#[ derive( Debug , Clone , serde:: Serialize , serde:: Deserialize ) ]
@@ -512,6 +514,27 @@ pub async fn poll_create_token(
512514 device_code : String ,
513515 start_url : Option < String > ,
514516 region : Option < String > ,
517+ ) -> PollCreateToken {
518+ match poll_create_token_impl ( secret_store, device_code, start_url. clone ( ) , region) . await {
519+ PollCreateToken :: Error ( err) => {
520+ fig_telemetry_core:: send_event ( Event :: new ( EventType :: AuthFailed {
521+ auth_method : TokenType :: from_start_url ( start_url. as_deref ( ) ) ,
522+ oauth_flow : OAuthFlow :: DeviceCode ,
523+ error_type : AuthErrorType :: NewLogin ,
524+ error_code : err. service_error_code ( ) ,
525+ } ) )
526+ . await ;
527+ PollCreateToken :: Error ( err)
528+ } ,
529+ other => other,
530+ }
531+ }
532+
533+ async fn poll_create_token_impl (
534+ secret_store : & SecretStore ,
535+ device_code : String ,
536+ start_url : Option < String > ,
537+ region : Option < String > ,
515538) -> PollCreateToken {
516539 let region = region. clone ( ) . map_or ( OIDC_BUILDER_ID_REGION , Region :: new) ;
517540 let client = client ( region. clone ( ) ) ;
@@ -634,23 +657,6 @@ mod tests {
634657 const US_EAST_1 : Region = Region :: from_static ( "us-east-1" ) ;
635658 const US_WEST_2 : Region = Region :: from_static ( "us-west-2" ) ;
636659
637- macro_rules! test_ser_deser {
638- ( $ty: ident, $variant: expr, $text: expr) => {
639- let quoted = format!( "\" {}\" " , $text) ;
640- assert_eq!( quoted, serde_json:: to_string( & $variant) . unwrap( ) ) ;
641- assert_eq!( $variant, serde_json:: from_str( & quoted) . unwrap( ) ) ;
642-
643- assert_eq!( $text, format!( "{}" , $variant) ) ;
644- } ;
645- }
646-
647- #[ test]
648- fn test_oauth_flow_ser_deser ( ) {
649- test_ser_deser ! ( OAuthFlow , OAuthFlow :: DeviceCode , "DeviceCode" ) ;
650- test_ser_deser ! ( OAuthFlow , OAuthFlow :: PKCE , "PKCE" ) ;
651- assert_eq ! ( OAuthFlow :: PKCE , serde_json:: from_str( "\" Pkce\" " ) . unwrap( ) ) ;
652- }
653-
654660 #[ test]
655661 fn test_client ( ) {
656662 println ! ( "{:?}" , client( US_EAST_1 ) ) ;
0 commit comments