Skip to content

Commit a50abba

Browse files
committed
add flexible env for auth url
2 parents 000c0ff + d8fef18 commit a50abba

File tree

3 files changed

+388
-5
lines changed

3 files changed

+388
-5
lines changed

crates/chat-cli/src/auth/portal.rs

Lines changed: 384 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -395,3 +395,387 @@ async fn bind_allowed_port(ports: &[u16]) -> Result<TcpListener, AuthError> {
395395
fn get_auth_portal_url() -> String {
396396
env::var("KIRO_AUTH_PORTAL_URL").unwrap_or_else(|_| DEFAULT_AUTH_PORTAL_URL.to_string())
397397
}
398+
//! Unified auth portal integration for streamlined authentication
399+
//! Handles callbacks from https://app.kiro.dev/signin
400+
401+
use std::time::Duration;
402+
403+
use bytes::Bytes;
404+
use http_body_util::Full;
405+
use hyper::body::Incoming;
406+
use hyper::server::conn::http1;
407+
use hyper::service::Service;
408+
use hyper::{
409+
Request,
410+
Response,
411+
};
412+
use hyper_util::rt::TokioIo;
413+
use rand::Rng;
414+
use tokio::net::TcpListener;
415+
use tracing::{
416+
debug,
417+
error,
418+
info,
419+
warn,
420+
};
421+
422+
use crate::auth::AuthError;
423+
use crate::auth::pkce::{
424+
generate_code_challenge,
425+
generate_code_verifier,
426+
};
427+
use crate::auth::social::{
428+
CALLBACK_PORTS,
429+
SocialProvider,
430+
SocialToken,
431+
};
432+
use crate::database::Database;
433+
use crate::util::system_info::is_mwinit_available;
434+
435+
const AUTH_PORTAL_URL: &str = "https://app.kiro.dev/signin";
436+
const DEFAULT_AUTHORIZATION_TIMEOUT: Duration = Duration::from_secs(600);
437+
438+
#[derive(Debug, Clone)]
439+
struct AuthPortalCallback {
440+
login_option: String,
441+
code: Option<String>,
442+
issuer_url: Option<String>,
443+
sso_region: Option<String>,
444+
state: String,
445+
path: String,
446+
error: Option<String>,
447+
error_description: Option<String>,
448+
}
449+
450+
pub enum PortalResult {
451+
Social(SocialProvider),
452+
BuilderId {
453+
issuer_url: String,
454+
idc_region: String,
455+
},
456+
AwsIdc {
457+
issuer_url: String,
458+
idc_region: String,
459+
},
460+
/// Internal amazon user
461+
Internal {
462+
issuer_url: String,
463+
idc_region: String,
464+
},
465+
}
466+
467+
/// Local-only: open unified portal and handle single callback
468+
pub async fn start_unified_auth(db: &mut Database) -> Result<PortalResult, AuthError> {
469+
info!("Starting unified auth portal flow");
470+
471+
// PKCE params for portal + social token exchange
472+
let verifier = generate_code_verifier();
473+
let challenge = generate_code_challenge(&verifier);
474+
let state = rand::rng()
475+
.sample_iter(rand::distr::Alphanumeric)
476+
.take(10)
477+
.collect::<Vec<_>>();
478+
let state = String::from_utf8(state).unwrap_or("state".to_string());
479+
480+
let listener = bind_allowed_port(CALLBACK_PORTS).await?;
481+
let port = listener.local_addr()?.port();
482+
483+
let redirect_base = format!("http://localhost:{port}");
484+
info!(%port, %redirect_base, "Unified auth portal listening for callback");
485+
486+
let auth_url = build_auth_url(&redirect_base, &state, &challenge);
487+
488+
crate::util::open::open_url_async(&auth_url)
489+
.await
490+
.map_err(|e| AuthError::OAuthCustomError(format!("Failed to open browser: {e}")))?;
491+
492+
let callback = wait_for_auth_callback(listener, state.clone()).await?;
493+
494+
if let Some(error) = &callback.error {
495+
let friendly_msg =
496+
format_user_friendly_error(error, callback.error_description.as_deref(), &callback.login_option);
497+
498+
warn!(
499+
"OAuth error for {}: {} - {}",
500+
callback.login_option, error, friendly_msg
501+
);
502+
503+
return Err(match callback.login_option.as_str() {
504+
"google" | "github" => AuthError::SocialAuthProviderFailure(friendly_msg),
505+
_ => AuthError::OAuthCustomError(friendly_msg),
506+
});
507+
}
508+
509+
process_portal_callback(db, callback, port, &verifier).await
510+
}
511+
512+
fn format_user_friendly_error(error_code: &str, description: Option<&str>, provider: &str) -> String {
513+
let cleaned_description = description.map(|d| {
514+
let first_part = d.split(';').next().unwrap_or(d);
515+
// Replace + with spaces (URL encoding)
516+
first_part.replace('+', " ").trim().to_string()
517+
});
518+
519+
match error_code {
520+
"access_denied" => {
521+
format!("{provider} denied access to Kiro. Please ensure you grant all required permissions.")
522+
},
523+
"invalid_request" => "Authentication failed due to an invalid request. Please try again.".to_string(),
524+
"unauthorized_client" => "The application is not authorized. Please contact support.".to_string(),
525+
"server_error" => {
526+
format!("{provider} login is temporarily unavailable. Please try again later.")
527+
},
528+
"invalid_scope" => "The requested permissions are invalid. Please contact support.".to_string(),
529+
_ => {
530+
// For unknown errors, use cleaned description or a generic message
531+
cleaned_description.unwrap_or_else(|| format!("Authentication failed: {error_code}. Please try again."))
532+
},
533+
}
534+
}
535+
536+
/// Build the authorization URL with all required parameters
537+
fn build_auth_url(redirect_base: &str, state: &str, challenge: &str) -> String {
538+
let is_internal = is_mwinit_available();
539+
let internal_param = if is_internal { "&from_amazon_internal=true" } else { "" };
540+
541+
format!(
542+
"{}?state={}&code_challenge={}&code_challenge_method=S256&redirect_uri={}{}&redirect_from=kirocli",
543+
AUTH_PORTAL_URL,
544+
state,
545+
challenge,
546+
urlencoding::encode(redirect_base),
547+
internal_param
548+
)
549+
}
550+
551+
async fn process_portal_callback(
552+
db: &mut Database,
553+
callback: AuthPortalCallback,
554+
port: u16,
555+
verifier: &str,
556+
) -> Result<PortalResult, AuthError> {
557+
match callback.login_option.as_str() {
558+
"google" | "github" => handle_social_callback(db, callback, port, verifier).await,
559+
"internal" => {
560+
let (issuer_url, sso_region) = extract_sso_params(&callback, "internal")?;
561+
Ok(PortalResult::Internal {
562+
issuer_url,
563+
idc_region: sso_region,
564+
})
565+
},
566+
"awsidc" => {
567+
let (issuer_url, sso_region) = extract_sso_params(&callback, "awsIdc")?;
568+
Ok(PortalResult::AwsIdc {
569+
issuer_url,
570+
idc_region: sso_region,
571+
})
572+
},
573+
"builderid" => {
574+
let (issuer_url, sso_region) = extract_sso_params(&callback, "builderId")?;
575+
Ok(PortalResult::BuilderId {
576+
issuer_url,
577+
idc_region: sso_region,
578+
})
579+
},
580+
other => Err(AuthError::OAuthCustomError(format!("Unknown login_option: {other}"))),
581+
}
582+
}
583+
584+
/// Handle social provider callback (Google/GitHub)
585+
async fn handle_social_callback(
586+
db: &mut Database,
587+
callback: AuthPortalCallback,
588+
port: u16,
589+
verifier: &str,
590+
) -> Result<PortalResult, AuthError> {
591+
let provider = match callback.login_option.as_str() {
592+
"google" => SocialProvider::Google,
593+
"github" => SocialProvider::Github,
594+
_ => unreachable!(),
595+
};
596+
597+
let code = callback.code.ok_or(AuthError::OAuthMissingCode)?;
598+
let redirect_uri = format!(
599+
"http://localhost:{}{}?login_option={}",
600+
port,
601+
callback.path,
602+
urlencoding::encode(&callback.login_option)
603+
);
604+
605+
SocialToken::exchange_social_token(db, provider, verifier, &code, &redirect_uri).await?;
606+
Ok(PortalResult::Social(provider))
607+
}
608+
609+
/// Extract issuer_url and sso_region from callback, returning descriptive error if missing
610+
fn extract_sso_params(callback: &AuthPortalCallback, auth_type: &str) -> Result<(String, String), AuthError> {
611+
let issuer_url = callback
612+
.issuer_url
613+
.clone()
614+
.ok_or_else(|| AuthError::OAuthCustomError(format!("Missing issuer_url for {auth_type} auth")))?;
615+
616+
let sso_region = callback
617+
.sso_region
618+
.clone()
619+
.ok_or_else(|| AuthError::OAuthCustomError(format!("Missing sso_region for {auth_type} auth")))?;
620+
621+
Ok((issuer_url, sso_region))
622+
}
623+
624+
async fn wait_for_auth_callback(
625+
listener: TcpListener,
626+
expected_state: String,
627+
) -> Result<AuthPortalCallback, AuthError> {
628+
let (tx, mut rx) = tokio::sync::mpsc::channel::<AuthPortalCallback>(1);
629+
630+
let server_handle = tokio::spawn(async move {
631+
const MAX_CONNECTIONS: usize = 3;
632+
let mut count = 0;
633+
634+
loop {
635+
if count >= MAX_CONNECTIONS {
636+
warn!("Reached max connections ({})", MAX_CONNECTIONS);
637+
break;
638+
}
639+
640+
match listener.accept().await {
641+
Ok((stream, _)) => {
642+
count += 1;
643+
debug!("Connection {}/{}", count, MAX_CONNECTIONS);
644+
645+
let io = TokioIo::new(stream);
646+
let service = AuthCallbackService { tx: tx.clone() };
647+
648+
tokio::spawn(async move {
649+
let _ = http1::Builder::new().serve_connection(io, service).await;
650+
});
651+
},
652+
Err(e) => {
653+
error!("Accept failed: {}", e);
654+
break;
655+
},
656+
}
657+
}
658+
});
659+
660+
let callback = tokio::select! {
661+
result = rx.recv() => {
662+
result.ok_or(AuthError::OAuthCustomError("Failed to receive callback".into()))?
663+
},
664+
_ = tokio::time::sleep(DEFAULT_AUTHORIZATION_TIMEOUT) => {
665+
return Err(AuthError::OAuthTimeout);
666+
}
667+
};
668+
669+
server_handle.abort();
670+
671+
if callback.state != expected_state {
672+
return Err(AuthError::OAuthStateMismatch {
673+
actual: callback.state,
674+
expected: expected_state,
675+
});
676+
}
677+
678+
Ok(callback)
679+
}
680+
681+
#[derive(Clone)]
682+
struct AuthCallbackService {
683+
tx: tokio::sync::mpsc::Sender<AuthPortalCallback>,
684+
}
685+
686+
impl Service<Request<Incoming>> for AuthCallbackService {
687+
type Error = AuthError;
688+
type Future = std::pin::Pin<Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>>;
689+
type Response = Response<Full<Bytes>>;
690+
691+
fn call(&self, req: Request<Incoming>) -> Self::Future {
692+
let tx = self.tx.clone();
693+
694+
Box::pin(async move {
695+
let uri = req.uri();
696+
let path = uri.path();
697+
698+
if path == "/oauth/callback" || path == "/signin/callback" {
699+
handle_valid_callback(uri, path, tx).await
700+
} else {
701+
handle_invalid_callback(path).await
702+
}
703+
})
704+
}
705+
}
706+
707+
/// Handle valid callback paths
708+
async fn handle_valid_callback(
709+
uri: &hyper::Uri,
710+
path: &str,
711+
tx: tokio::sync::mpsc::Sender<AuthPortalCallback>,
712+
) -> Result<Response<Full<Bytes>>, AuthError> {
713+
let query_params = uri
714+
.query()
715+
.map(|query| {
716+
query
717+
.split('&')
718+
.filter_map(|kv| {
719+
kv.split_once('=')
720+
.map(|(k, v)| (k.to_string(), urlencoding::decode(v).unwrap_or_default().to_string()))
721+
})
722+
.collect::<std::collections::HashMap<String, String>>() //
723+
})
724+
.ok_or(AuthError::OAuthCustomError("query parameters are missing".into()))?;
725+
726+
let callback = AuthPortalCallback {
727+
login_option: query_params.get("login_option").cloned().unwrap_or_default(),
728+
code: query_params.get("code").cloned(),
729+
issuer_url: query_params.get("issuer_url").cloned(),
730+
sso_region: query_params.get("idc_region").cloned(),
731+
state: query_params.get("state").cloned().unwrap_or_default(),
732+
path: path.to_string(),
733+
error: query_params.get("error").cloned(),
734+
error_description: query_params.get("error_description").cloned(),
735+
};
736+
737+
let _ = tx.send(callback.clone()).await;
738+
739+
if let Some(error) = &callback.error {
740+
let error_msg = callback.error_description.as_deref().unwrap_or(error.as_str());
741+
build_redirect_response("error", Some(error_msg))
742+
} else {
743+
build_redirect_response("success", None)
744+
}
745+
}
746+
747+
async fn handle_invalid_callback(path: &str) -> Result<Response<Full<Bytes>>, AuthError> {
748+
info!(%path, "Invalid callback path: {}, redirecting to portal", path);
749+
build_redirect_response("error", Some("Invalid callback path"))
750+
}
751+
752+
/// Build a redirect response to the auth portal
753+
fn build_redirect_response(status: &str, error_message: Option<&str>) -> Result<Response<Full<Bytes>>, AuthError> {
754+
let mut redirect_url = format!("{AUTH_PORTAL_URL}?auth_status={status}&redirect_from=kirocli");
755+
756+
if let Some(msg) = error_message {
757+
redirect_url.push_str(&format!("&error_message={}", urlencoding::encode(msg)));
758+
}
759+
760+
Ok(Response::builder()
761+
.status(302)
762+
.header("Location", redirect_url)
763+
.header("Cache-Control", "no-store")
764+
.body(Full::new(Bytes::from("")))
765+
.expect("valid response"))
766+
}
767+
768+
async fn bind_allowed_port(ports: &[u16]) -> Result<TcpListener, AuthError> {
769+
for port in ports {
770+
match TcpListener::bind(("127.0.0.1", *port)).await {
771+
Ok(listener) => return Ok(listener),
772+
Err(e) => {
773+
debug!("Failed to bind to port {}: {}", port, e);
774+
},
775+
}
776+
}
777+
778+
Err(AuthError::OAuthCustomError(
779+
"All callback ports are in use. Please close some applications and try again.".into(),
780+
))
781+
}

0 commit comments

Comments
 (0)