@@ -395,387 +395,3 @@ async fn bind_allowed_port(ports: &[u16]) -> Result<TcpListener, AuthError> {
395395fn get_auth_portal_url ( ) -> String {
396396 env:: var ( "KIRO_AUTH_PORTAL_URL" ) . unwrap_or_else ( |_| DEFAULT_AUTH_PORTAL_URL . to_string ( ) )
397397}
398- //! Unified auth portal integration for streamlined authentication
399- //! Handles callbacks from https://app.kiro.dev/signin
400-
401- use std:: time:: Duration ;
402-
403- use bytes:: Bytes ;
404- use http_body_util:: Full ;
405- use hyper:: body:: Incoming ;
406- use hyper:: server:: conn:: http1;
407- use hyper:: service:: Service ;
408- use hyper:: {
409- Request ,
410- Response ,
411- } ;
412- use hyper_util:: rt:: TokioIo ;
413- use rand:: Rng ;
414- use tokio:: net:: TcpListener ;
415- use tracing:: {
416- debug,
417- error,
418- info,
419- warn,
420- } ;
421-
422- use crate :: auth:: AuthError ;
423- use crate :: auth:: pkce:: {
424- generate_code_challenge,
425- generate_code_verifier,
426- } ;
427- use crate :: auth:: social:: {
428- CALLBACK_PORTS ,
429- SocialProvider ,
430- SocialToken ,
431- } ;
432- use crate :: database:: Database ;
433- use crate :: util:: system_info:: is_mwinit_available;
434-
435- const AUTH_PORTAL_URL : & str = "https://app.kiro.dev/signin" ;
436- const DEFAULT_AUTHORIZATION_TIMEOUT : Duration = Duration :: from_secs ( 600 ) ;
437-
438- #[ derive( Debug , Clone ) ]
439- struct AuthPortalCallback {
440- login_option : String ,
441- code : Option < String > ,
442- issuer_url : Option < String > ,
443- sso_region : Option < String > ,
444- state : String ,
445- path : String ,
446- error : Option < String > ,
447- error_description : Option < String > ,
448- }
449-
450- pub enum PortalResult {
451- Social ( SocialProvider ) ,
452- BuilderId {
453- issuer_url : String ,
454- idc_region : String ,
455- } ,
456- AwsIdc {
457- issuer_url : String ,
458- idc_region : String ,
459- } ,
460- /// Internal amazon user
461- Internal {
462- issuer_url : String ,
463- idc_region : String ,
464- } ,
465- }
466-
467- /// Local-only: open unified portal and handle single callback
468- pub async fn start_unified_auth ( db : & mut Database ) -> Result < PortalResult , AuthError > {
469- info ! ( "Starting unified auth portal flow" ) ;
470-
471- // PKCE params for portal + social token exchange
472- let verifier = generate_code_verifier ( ) ;
473- let challenge = generate_code_challenge ( & verifier) ;
474- let state = rand:: rng ( )
475- . sample_iter ( rand:: distr:: Alphanumeric )
476- . take ( 10 )
477- . collect :: < Vec < _ > > ( ) ;
478- let state = String :: from_utf8 ( state) . unwrap_or ( "state" . to_string ( ) ) ;
479-
480- let listener = bind_allowed_port ( CALLBACK_PORTS ) . await ?;
481- let port = listener. local_addr ( ) ?. port ( ) ;
482-
483- let redirect_base = format ! ( "http://localhost:{port}" ) ;
484- info ! ( %port, %redirect_base, "Unified auth portal listening for callback" ) ;
485-
486- let auth_url = build_auth_url ( & redirect_base, & state, & challenge) ;
487-
488- crate :: util:: open:: open_url_async ( & auth_url)
489- . await
490- . map_err ( |e| AuthError :: OAuthCustomError ( format ! ( "Failed to open browser: {e}" ) ) ) ?;
491-
492- let callback = wait_for_auth_callback ( listener, state. clone ( ) ) . await ?;
493-
494- if let Some ( error) = & callback. error {
495- let friendly_msg =
496- format_user_friendly_error ( error, callback. error_description . as_deref ( ) , & callback. login_option ) ;
497-
498- warn ! (
499- "OAuth error for {}: {} - {}" ,
500- callback. login_option, error, friendly_msg
501- ) ;
502-
503- return Err ( match callback. login_option . as_str ( ) {
504- "google" | "github" => AuthError :: SocialAuthProviderFailure ( friendly_msg) ,
505- _ => AuthError :: OAuthCustomError ( friendly_msg) ,
506- } ) ;
507- }
508-
509- process_portal_callback ( db, callback, port, & verifier) . await
510- }
511-
512- fn format_user_friendly_error ( error_code : & str , description : Option < & str > , provider : & str ) -> String {
513- let cleaned_description = description. map ( |d| {
514- let first_part = d. split ( ';' ) . next ( ) . unwrap_or ( d) ;
515- // Replace + with spaces (URL encoding)
516- first_part. replace ( '+' , " " ) . trim ( ) . to_string ( )
517- } ) ;
518-
519- match error_code {
520- "access_denied" => {
521- format ! ( "{provider} denied access to Kiro. Please ensure you grant all required permissions." )
522- } ,
523- "invalid_request" => "Authentication failed due to an invalid request. Please try again." . to_string ( ) ,
524- "unauthorized_client" => "The application is not authorized. Please contact support." . to_string ( ) ,
525- "server_error" => {
526- format ! ( "{provider} login is temporarily unavailable. Please try again later." )
527- } ,
528- "invalid_scope" => "The requested permissions are invalid. Please contact support." . to_string ( ) ,
529- _ => {
530- // For unknown errors, use cleaned description or a generic message
531- cleaned_description. unwrap_or_else ( || format ! ( "Authentication failed: {error_code}. Please try again." ) )
532- } ,
533- }
534- }
535-
536- /// Build the authorization URL with all required parameters
537- fn build_auth_url ( redirect_base : & str , state : & str , challenge : & str ) -> String {
538- let is_internal = is_mwinit_available ( ) ;
539- let internal_param = if is_internal { "&from_amazon_internal=true" } else { "" } ;
540-
541- format ! (
542- "{}?state={}&code_challenge={}&code_challenge_method=S256&redirect_uri={}{}&redirect_from=kirocli" ,
543- AUTH_PORTAL_URL ,
544- state,
545- challenge,
546- urlencoding:: encode( redirect_base) ,
547- internal_param
548- )
549- }
550-
551- async fn process_portal_callback (
552- db : & mut Database ,
553- callback : AuthPortalCallback ,
554- port : u16 ,
555- verifier : & str ,
556- ) -> Result < PortalResult , AuthError > {
557- match callback. login_option . as_str ( ) {
558- "google" | "github" => handle_social_callback ( db, callback, port, verifier) . await ,
559- "internal" => {
560- let ( issuer_url, sso_region) = extract_sso_params ( & callback, "internal" ) ?;
561- Ok ( PortalResult :: Internal {
562- issuer_url,
563- idc_region : sso_region,
564- } )
565- } ,
566- "awsidc" => {
567- let ( issuer_url, sso_region) = extract_sso_params ( & callback, "awsIdc" ) ?;
568- Ok ( PortalResult :: AwsIdc {
569- issuer_url,
570- idc_region : sso_region,
571- } )
572- } ,
573- "builderid" => {
574- let ( issuer_url, sso_region) = extract_sso_params ( & callback, "builderId" ) ?;
575- Ok ( PortalResult :: BuilderId {
576- issuer_url,
577- idc_region : sso_region,
578- } )
579- } ,
580- other => Err ( AuthError :: OAuthCustomError ( format ! ( "Unknown login_option: {other}" ) ) ) ,
581- }
582- }
583-
584- /// Handle social provider callback (Google/GitHub)
585- async fn handle_social_callback (
586- db : & mut Database ,
587- callback : AuthPortalCallback ,
588- port : u16 ,
589- verifier : & str ,
590- ) -> Result < PortalResult , AuthError > {
591- let provider = match callback. login_option . as_str ( ) {
592- "google" => SocialProvider :: Google ,
593- "github" => SocialProvider :: Github ,
594- _ => unreachable ! ( ) ,
595- } ;
596-
597- let code = callback. code . ok_or ( AuthError :: OAuthMissingCode ) ?;
598- let redirect_uri = format ! (
599- "http://localhost:{}{}?login_option={}" ,
600- port,
601- callback. path,
602- urlencoding:: encode( & callback. login_option)
603- ) ;
604-
605- SocialToken :: exchange_social_token ( db, provider, verifier, & code, & redirect_uri) . await ?;
606- Ok ( PortalResult :: Social ( provider) )
607- }
608-
609- /// Extract issuer_url and sso_region from callback, returning descriptive error if missing
610- fn extract_sso_params ( callback : & AuthPortalCallback , auth_type : & str ) -> Result < ( String , String ) , AuthError > {
611- let issuer_url = callback
612- . issuer_url
613- . clone ( )
614- . ok_or_else ( || AuthError :: OAuthCustomError ( format ! ( "Missing issuer_url for {auth_type} auth" ) ) ) ?;
615-
616- let sso_region = callback
617- . sso_region
618- . clone ( )
619- . ok_or_else ( || AuthError :: OAuthCustomError ( format ! ( "Missing sso_region for {auth_type} auth" ) ) ) ?;
620-
621- Ok ( ( issuer_url, sso_region) )
622- }
623-
624- async fn wait_for_auth_callback (
625- listener : TcpListener ,
626- expected_state : String ,
627- ) -> Result < AuthPortalCallback , AuthError > {
628- let ( tx, mut rx) = tokio:: sync:: mpsc:: channel :: < AuthPortalCallback > ( 1 ) ;
629-
630- let server_handle = tokio:: spawn ( async move {
631- const MAX_CONNECTIONS : usize = 3 ;
632- let mut count = 0 ;
633-
634- loop {
635- if count >= MAX_CONNECTIONS {
636- warn ! ( "Reached max connections ({})" , MAX_CONNECTIONS ) ;
637- break ;
638- }
639-
640- match listener. accept ( ) . await {
641- Ok ( ( stream, _) ) => {
642- count += 1 ;
643- debug ! ( "Connection {}/{}" , count, MAX_CONNECTIONS ) ;
644-
645- let io = TokioIo :: new ( stream) ;
646- let service = AuthCallbackService { tx : tx. clone ( ) } ;
647-
648- tokio:: spawn ( async move {
649- let _ = http1:: Builder :: new ( ) . serve_connection ( io, service) . await ;
650- } ) ;
651- } ,
652- Err ( e) => {
653- error ! ( "Accept failed: {}" , e) ;
654- break ;
655- } ,
656- }
657- }
658- } ) ;
659-
660- let callback = tokio:: select! {
661- result = rx. recv( ) => {
662- result. ok_or( AuthError :: OAuthCustomError ( "Failed to receive callback" . into( ) ) ) ?
663- } ,
664- _ = tokio:: time:: sleep( DEFAULT_AUTHORIZATION_TIMEOUT ) => {
665- return Err ( AuthError :: OAuthTimeout ) ;
666- }
667- } ;
668-
669- server_handle. abort ( ) ;
670-
671- if callback. state != expected_state {
672- return Err ( AuthError :: OAuthStateMismatch {
673- actual : callback. state ,
674- expected : expected_state,
675- } ) ;
676- }
677-
678- Ok ( callback)
679- }
680-
681- #[ derive( Clone ) ]
682- struct AuthCallbackService {
683- tx : tokio:: sync:: mpsc:: Sender < AuthPortalCallback > ,
684- }
685-
686- impl Service < Request < Incoming > > for AuthCallbackService {
687- type Error = AuthError ;
688- type Future = std:: pin:: Pin < Box < dyn std:: future:: Future < Output = Result < Self :: Response , Self :: Error > > + Send > > ;
689- type Response = Response < Full < Bytes > > ;
690-
691- fn call ( & self , req : Request < Incoming > ) -> Self :: Future {
692- let tx = self . tx . clone ( ) ;
693-
694- Box :: pin ( async move {
695- let uri = req. uri ( ) ;
696- let path = uri. path ( ) ;
697-
698- if path == "/oauth/callback" || path == "/signin/callback" {
699- handle_valid_callback ( uri, path, tx) . await
700- } else {
701- handle_invalid_callback ( path) . await
702- }
703- } )
704- }
705- }
706-
707- /// Handle valid callback paths
708- async fn handle_valid_callback (
709- uri : & hyper:: Uri ,
710- path : & str ,
711- tx : tokio:: sync:: mpsc:: Sender < AuthPortalCallback > ,
712- ) -> Result < Response < Full < Bytes > > , AuthError > {
713- let query_params = uri
714- . query ( )
715- . map ( |query| {
716- query
717- . split ( '&' )
718- . filter_map ( |kv| {
719- kv. split_once ( '=' )
720- . map ( |( k, v) | ( k. to_string ( ) , urlencoding:: decode ( v) . unwrap_or_default ( ) . to_string ( ) ) )
721- } )
722- . collect :: < std:: collections:: HashMap < String , String > > ( ) //
723- } )
724- . ok_or ( AuthError :: OAuthCustomError ( "query parameters are missing" . into ( ) ) ) ?;
725-
726- let callback = AuthPortalCallback {
727- login_option : query_params. get ( "login_option" ) . cloned ( ) . unwrap_or_default ( ) ,
728- code : query_params. get ( "code" ) . cloned ( ) ,
729- issuer_url : query_params. get ( "issuer_url" ) . cloned ( ) ,
730- sso_region : query_params. get ( "idc_region" ) . cloned ( ) ,
731- state : query_params. get ( "state" ) . cloned ( ) . unwrap_or_default ( ) ,
732- path : path. to_string ( ) ,
733- error : query_params. get ( "error" ) . cloned ( ) ,
734- error_description : query_params. get ( "error_description" ) . cloned ( ) ,
735- } ;
736-
737- let _ = tx. send ( callback. clone ( ) ) . await ;
738-
739- if let Some ( error) = & callback. error {
740- let error_msg = callback. error_description . as_deref ( ) . unwrap_or ( error. as_str ( ) ) ;
741- build_redirect_response ( "error" , Some ( error_msg) )
742- } else {
743- build_redirect_response ( "success" , None )
744- }
745- }
746-
747- async fn handle_invalid_callback ( path : & str ) -> Result < Response < Full < Bytes > > , AuthError > {
748- info ! ( %path, "Invalid callback path: {}, redirecting to portal" , path) ;
749- build_redirect_response ( "error" , Some ( "Invalid callback path" ) )
750- }
751-
752- /// Build a redirect response to the auth portal
753- fn build_redirect_response ( status : & str , error_message : Option < & str > ) -> Result < Response < Full < Bytes > > , AuthError > {
754- let mut redirect_url = format ! ( "{AUTH_PORTAL_URL}?auth_status={status}&redirect_from=kirocli" ) ;
755-
756- if let Some ( msg) = error_message {
757- redirect_url. push_str ( & format ! ( "&error_message={}" , urlencoding:: encode( msg) ) ) ;
758- }
759-
760- Ok ( Response :: builder ( )
761- . status ( 302 )
762- . header ( "Location" , redirect_url)
763- . header ( "Cache-Control" , "no-store" )
764- . body ( Full :: new ( Bytes :: from ( "" ) ) )
765- . expect ( "valid response" ) )
766- }
767-
768- async fn bind_allowed_port ( ports : & [ u16 ] ) -> Result < TcpListener , AuthError > {
769- for port in ports {
770- match TcpListener :: bind ( ( "127.0.0.1" , * port) ) . await {
771- Ok ( listener) => return Ok ( listener) ,
772- Err ( e) => {
773- debug ! ( "Failed to bind to port {}: {}" , port, e) ;
774- } ,
775- }
776- }
777-
778- Err ( AuthError :: OAuthCustomError (
779- "All callback ports are in use. Please close some applications and try again." . into ( ) ,
780- ) )
781- }
0 commit comments