@@ -395,3 +395,387 @@ async fn bind_allowed_port(ports: &[u16]) -> Result<TcpListener, AuthError> {
395395fn get_auth_portal_url ( ) -> String {
396396 env:: var ( "KIRO_AUTH_PORTAL_URL" ) . unwrap_or_else ( |_| DEFAULT_AUTH_PORTAL_URL . to_string ( ) )
397397}
398+ //! Unified auth portal integration for streamlined authentication
399+ //! Handles callbacks from https://app.kiro.dev/signin
400+
401+ use std:: time:: Duration ;
402+
403+ use bytes:: Bytes ;
404+ use http_body_util:: Full ;
405+ use hyper:: body:: Incoming ;
406+ use hyper:: server:: conn:: http1;
407+ use hyper:: service:: Service ;
408+ use hyper:: {
409+ Request ,
410+ Response ,
411+ } ;
412+ use hyper_util:: rt:: TokioIo ;
413+ use rand:: Rng ;
414+ use tokio:: net:: TcpListener ;
415+ use tracing:: {
416+ debug,
417+ error,
418+ info,
419+ warn,
420+ } ;
421+
422+ use crate :: auth:: AuthError ;
423+ use crate :: auth:: pkce:: {
424+ generate_code_challenge,
425+ generate_code_verifier,
426+ } ;
427+ use crate :: auth:: social:: {
428+ CALLBACK_PORTS ,
429+ SocialProvider ,
430+ SocialToken ,
431+ } ;
432+ use crate :: database:: Database ;
433+ use crate :: util:: system_info:: is_mwinit_available;
434+
435+ const AUTH_PORTAL_URL : & str = "https://app.kiro.dev/signin" ;
436+ const DEFAULT_AUTHORIZATION_TIMEOUT : Duration = Duration :: from_secs ( 600 ) ;
437+
438+ #[ derive( Debug , Clone ) ]
439+ struct AuthPortalCallback {
440+ login_option : String ,
441+ code : Option < String > ,
442+ issuer_url : Option < String > ,
443+ sso_region : Option < String > ,
444+ state : String ,
445+ path : String ,
446+ error : Option < String > ,
447+ error_description : Option < String > ,
448+ }
449+
450+ pub enum PortalResult {
451+ Social ( SocialProvider ) ,
452+ BuilderId {
453+ issuer_url : String ,
454+ idc_region : String ,
455+ } ,
456+ AwsIdc {
457+ issuer_url : String ,
458+ idc_region : String ,
459+ } ,
460+ /// Internal amazon user
461+ Internal {
462+ issuer_url : String ,
463+ idc_region : String ,
464+ } ,
465+ }
466+
467+ /// Local-only: open unified portal and handle single callback
468+ pub async fn start_unified_auth ( db : & mut Database ) -> Result < PortalResult , AuthError > {
469+ info ! ( "Starting unified auth portal flow" ) ;
470+
471+ // PKCE params for portal + social token exchange
472+ let verifier = generate_code_verifier ( ) ;
473+ let challenge = generate_code_challenge ( & verifier) ;
474+ let state = rand:: rng ( )
475+ . sample_iter ( rand:: distr:: Alphanumeric )
476+ . take ( 10 )
477+ . collect :: < Vec < _ > > ( ) ;
478+ let state = String :: from_utf8 ( state) . unwrap_or ( "state" . to_string ( ) ) ;
479+
480+ let listener = bind_allowed_port ( CALLBACK_PORTS ) . await ?;
481+ let port = listener. local_addr ( ) ?. port ( ) ;
482+
483+ let redirect_base = format ! ( "http://localhost:{port}" ) ;
484+ info ! ( %port, %redirect_base, "Unified auth portal listening for callback" ) ;
485+
486+ let auth_url = build_auth_url ( & redirect_base, & state, & challenge) ;
487+
488+ crate :: util:: open:: open_url_async ( & auth_url)
489+ . await
490+ . map_err ( |e| AuthError :: OAuthCustomError ( format ! ( "Failed to open browser: {e}" ) ) ) ?;
491+
492+ let callback = wait_for_auth_callback ( listener, state. clone ( ) ) . await ?;
493+
494+ if let Some ( error) = & callback. error {
495+ let friendly_msg =
496+ format_user_friendly_error ( error, callback. error_description . as_deref ( ) , & callback. login_option ) ;
497+
498+ warn ! (
499+ "OAuth error for {}: {} - {}" ,
500+ callback. login_option, error, friendly_msg
501+ ) ;
502+
503+ return Err ( match callback. login_option . as_str ( ) {
504+ "google" | "github" => AuthError :: SocialAuthProviderFailure ( friendly_msg) ,
505+ _ => AuthError :: OAuthCustomError ( friendly_msg) ,
506+ } ) ;
507+ }
508+
509+ process_portal_callback ( db, callback, port, & verifier) . await
510+ }
511+
512+ fn format_user_friendly_error ( error_code : & str , description : Option < & str > , provider : & str ) -> String {
513+ let cleaned_description = description. map ( |d| {
514+ let first_part = d. split ( ';' ) . next ( ) . unwrap_or ( d) ;
515+ // Replace + with spaces (URL encoding)
516+ first_part. replace ( '+' , " " ) . trim ( ) . to_string ( )
517+ } ) ;
518+
519+ match error_code {
520+ "access_denied" => {
521+ format ! ( "{provider} denied access to Kiro. Please ensure you grant all required permissions." )
522+ } ,
523+ "invalid_request" => "Authentication failed due to an invalid request. Please try again." . to_string ( ) ,
524+ "unauthorized_client" => "The application is not authorized. Please contact support." . to_string ( ) ,
525+ "server_error" => {
526+ format ! ( "{provider} login is temporarily unavailable. Please try again later." )
527+ } ,
528+ "invalid_scope" => "The requested permissions are invalid. Please contact support." . to_string ( ) ,
529+ _ => {
530+ // For unknown errors, use cleaned description or a generic message
531+ cleaned_description. unwrap_or_else ( || format ! ( "Authentication failed: {error_code}. Please try again." ) )
532+ } ,
533+ }
534+ }
535+
536+ /// Build the authorization URL with all required parameters
537+ fn build_auth_url ( redirect_base : & str , state : & str , challenge : & str ) -> String {
538+ let is_internal = is_mwinit_available ( ) ;
539+ let internal_param = if is_internal { "&from_amazon_internal=true" } else { "" } ;
540+
541+ format ! (
542+ "{}?state={}&code_challenge={}&code_challenge_method=S256&redirect_uri={}{}&redirect_from=kirocli" ,
543+ AUTH_PORTAL_URL ,
544+ state,
545+ challenge,
546+ urlencoding:: encode( redirect_base) ,
547+ internal_param
548+ )
549+ }
550+
551+ async fn process_portal_callback (
552+ db : & mut Database ,
553+ callback : AuthPortalCallback ,
554+ port : u16 ,
555+ verifier : & str ,
556+ ) -> Result < PortalResult , AuthError > {
557+ match callback. login_option . as_str ( ) {
558+ "google" | "github" => handle_social_callback ( db, callback, port, verifier) . await ,
559+ "internal" => {
560+ let ( issuer_url, sso_region) = extract_sso_params ( & callback, "internal" ) ?;
561+ Ok ( PortalResult :: Internal {
562+ issuer_url,
563+ idc_region : sso_region,
564+ } )
565+ } ,
566+ "awsidc" => {
567+ let ( issuer_url, sso_region) = extract_sso_params ( & callback, "awsIdc" ) ?;
568+ Ok ( PortalResult :: AwsIdc {
569+ issuer_url,
570+ idc_region : sso_region,
571+ } )
572+ } ,
573+ "builderid" => {
574+ let ( issuer_url, sso_region) = extract_sso_params ( & callback, "builderId" ) ?;
575+ Ok ( PortalResult :: BuilderId {
576+ issuer_url,
577+ idc_region : sso_region,
578+ } )
579+ } ,
580+ other => Err ( AuthError :: OAuthCustomError ( format ! ( "Unknown login_option: {other}" ) ) ) ,
581+ }
582+ }
583+
584+ /// Handle social provider callback (Google/GitHub)
585+ async fn handle_social_callback (
586+ db : & mut Database ,
587+ callback : AuthPortalCallback ,
588+ port : u16 ,
589+ verifier : & str ,
590+ ) -> Result < PortalResult , AuthError > {
591+ let provider = match callback. login_option . as_str ( ) {
592+ "google" => SocialProvider :: Google ,
593+ "github" => SocialProvider :: Github ,
594+ _ => unreachable ! ( ) ,
595+ } ;
596+
597+ let code = callback. code . ok_or ( AuthError :: OAuthMissingCode ) ?;
598+ let redirect_uri = format ! (
599+ "http://localhost:{}{}?login_option={}" ,
600+ port,
601+ callback. path,
602+ urlencoding:: encode( & callback. login_option)
603+ ) ;
604+
605+ SocialToken :: exchange_social_token ( db, provider, verifier, & code, & redirect_uri) . await ?;
606+ Ok ( PortalResult :: Social ( provider) )
607+ }
608+
609+ /// Extract issuer_url and sso_region from callback, returning descriptive error if missing
610+ fn extract_sso_params ( callback : & AuthPortalCallback , auth_type : & str ) -> Result < ( String , String ) , AuthError > {
611+ let issuer_url = callback
612+ . issuer_url
613+ . clone ( )
614+ . ok_or_else ( || AuthError :: OAuthCustomError ( format ! ( "Missing issuer_url for {auth_type} auth" ) ) ) ?;
615+
616+ let sso_region = callback
617+ . sso_region
618+ . clone ( )
619+ . ok_or_else ( || AuthError :: OAuthCustomError ( format ! ( "Missing sso_region for {auth_type} auth" ) ) ) ?;
620+
621+ Ok ( ( issuer_url, sso_region) )
622+ }
623+
624+ async fn wait_for_auth_callback (
625+ listener : TcpListener ,
626+ expected_state : String ,
627+ ) -> Result < AuthPortalCallback , AuthError > {
628+ let ( tx, mut rx) = tokio:: sync:: mpsc:: channel :: < AuthPortalCallback > ( 1 ) ;
629+
630+ let server_handle = tokio:: spawn ( async move {
631+ const MAX_CONNECTIONS : usize = 3 ;
632+ let mut count = 0 ;
633+
634+ loop {
635+ if count >= MAX_CONNECTIONS {
636+ warn ! ( "Reached max connections ({})" , MAX_CONNECTIONS ) ;
637+ break ;
638+ }
639+
640+ match listener. accept ( ) . await {
641+ Ok ( ( stream, _) ) => {
642+ count += 1 ;
643+ debug ! ( "Connection {}/{}" , count, MAX_CONNECTIONS ) ;
644+
645+ let io = TokioIo :: new ( stream) ;
646+ let service = AuthCallbackService { tx : tx. clone ( ) } ;
647+
648+ tokio:: spawn ( async move {
649+ let _ = http1:: Builder :: new ( ) . serve_connection ( io, service) . await ;
650+ } ) ;
651+ } ,
652+ Err ( e) => {
653+ error ! ( "Accept failed: {}" , e) ;
654+ break ;
655+ } ,
656+ }
657+ }
658+ } ) ;
659+
660+ let callback = tokio:: select! {
661+ result = rx. recv( ) => {
662+ result. ok_or( AuthError :: OAuthCustomError ( "Failed to receive callback" . into( ) ) ) ?
663+ } ,
664+ _ = tokio:: time:: sleep( DEFAULT_AUTHORIZATION_TIMEOUT ) => {
665+ return Err ( AuthError :: OAuthTimeout ) ;
666+ }
667+ } ;
668+
669+ server_handle. abort ( ) ;
670+
671+ if callback. state != expected_state {
672+ return Err ( AuthError :: OAuthStateMismatch {
673+ actual : callback. state ,
674+ expected : expected_state,
675+ } ) ;
676+ }
677+
678+ Ok ( callback)
679+ }
680+
681+ #[ derive( Clone ) ]
682+ struct AuthCallbackService {
683+ tx : tokio:: sync:: mpsc:: Sender < AuthPortalCallback > ,
684+ }
685+
686+ impl Service < Request < Incoming > > for AuthCallbackService {
687+ type Error = AuthError ;
688+ type Future = std:: pin:: Pin < Box < dyn std:: future:: Future < Output = Result < Self :: Response , Self :: Error > > + Send > > ;
689+ type Response = Response < Full < Bytes > > ;
690+
691+ fn call ( & self , req : Request < Incoming > ) -> Self :: Future {
692+ let tx = self . tx . clone ( ) ;
693+
694+ Box :: pin ( async move {
695+ let uri = req. uri ( ) ;
696+ let path = uri. path ( ) ;
697+
698+ if path == "/oauth/callback" || path == "/signin/callback" {
699+ handle_valid_callback ( uri, path, tx) . await
700+ } else {
701+ handle_invalid_callback ( path) . await
702+ }
703+ } )
704+ }
705+ }
706+
707+ /// Handle valid callback paths
708+ async fn handle_valid_callback (
709+ uri : & hyper:: Uri ,
710+ path : & str ,
711+ tx : tokio:: sync:: mpsc:: Sender < AuthPortalCallback > ,
712+ ) -> Result < Response < Full < Bytes > > , AuthError > {
713+ let query_params = uri
714+ . query ( )
715+ . map ( |query| {
716+ query
717+ . split ( '&' )
718+ . filter_map ( |kv| {
719+ kv. split_once ( '=' )
720+ . map ( |( k, v) | ( k. to_string ( ) , urlencoding:: decode ( v) . unwrap_or_default ( ) . to_string ( ) ) )
721+ } )
722+ . collect :: < std:: collections:: HashMap < String , String > > ( ) //
723+ } )
724+ . ok_or ( AuthError :: OAuthCustomError ( "query parameters are missing" . into ( ) ) ) ?;
725+
726+ let callback = AuthPortalCallback {
727+ login_option : query_params. get ( "login_option" ) . cloned ( ) . unwrap_or_default ( ) ,
728+ code : query_params. get ( "code" ) . cloned ( ) ,
729+ issuer_url : query_params. get ( "issuer_url" ) . cloned ( ) ,
730+ sso_region : query_params. get ( "idc_region" ) . cloned ( ) ,
731+ state : query_params. get ( "state" ) . cloned ( ) . unwrap_or_default ( ) ,
732+ path : path. to_string ( ) ,
733+ error : query_params. get ( "error" ) . cloned ( ) ,
734+ error_description : query_params. get ( "error_description" ) . cloned ( ) ,
735+ } ;
736+
737+ let _ = tx. send ( callback. clone ( ) ) . await ;
738+
739+ if let Some ( error) = & callback. error {
740+ let error_msg = callback. error_description . as_deref ( ) . unwrap_or ( error. as_str ( ) ) ;
741+ build_redirect_response ( "error" , Some ( error_msg) )
742+ } else {
743+ build_redirect_response ( "success" , None )
744+ }
745+ }
746+
747+ async fn handle_invalid_callback ( path : & str ) -> Result < Response < Full < Bytes > > , AuthError > {
748+ info ! ( %path, "Invalid callback path: {}, redirecting to portal" , path) ;
749+ build_redirect_response ( "error" , Some ( "Invalid callback path" ) )
750+ }
751+
752+ /// Build a redirect response to the auth portal
753+ fn build_redirect_response ( status : & str , error_message : Option < & str > ) -> Result < Response < Full < Bytes > > , AuthError > {
754+ let mut redirect_url = format ! ( "{AUTH_PORTAL_URL}?auth_status={status}&redirect_from=kirocli" ) ;
755+
756+ if let Some ( msg) = error_message {
757+ redirect_url. push_str ( & format ! ( "&error_message={}" , urlencoding:: encode( msg) ) ) ;
758+ }
759+
760+ Ok ( Response :: builder ( )
761+ . status ( 302 )
762+ . header ( "Location" , redirect_url)
763+ . header ( "Cache-Control" , "no-store" )
764+ . body ( Full :: new ( Bytes :: from ( "" ) ) )
765+ . expect ( "valid response" ) )
766+ }
767+
768+ async fn bind_allowed_port ( ports : & [ u16 ] ) -> Result < TcpListener , AuthError > {
769+ for port in ports {
770+ match TcpListener :: bind ( ( "127.0.0.1" , * port) ) . await {
771+ Ok ( listener) => return Ok ( listener) ,
772+ Err ( e) => {
773+ debug ! ( "Failed to bind to port {}: {}" , port, e) ;
774+ } ,
775+ }
776+ }
777+
778+ Err ( AuthError :: OAuthCustomError (
779+ "All callback ports are in use. Please close some applications and try again." . into ( ) ,
780+ ) )
781+ }
0 commit comments