@@ -189,23 +189,32 @@ def get_credentials(
189
189
# Check if we're in IRSA and trying to assume the same role we already have
190
190
current_role_arn = os .getenv ("AWS_ROLE_ARN" )
191
191
web_identity_token_file = os .getenv ("AWS_WEB_IDENTITY_TOKEN_FILE" )
192
-
192
+
193
193
# In IRSA environments, we should skip role assumption if we're already running as the target role
194
194
# This is true when:
195
195
# 1. We have AWS_ROLE_ARN set (current role)
196
196
# 2. We have AWS_WEB_IDENTITY_TOKEN_FILE set (IRSA environment)
197
197
# 3. The current role matches the requested role
198
- if (current_role_arn and web_identity_token_file and
199
- current_role_arn == aws_role_name ):
200
- verbose_logger .debug ("Using IRSA same-role optimization: calling _auth_with_env_vars" )
198
+ if (
199
+ current_role_arn
200
+ and web_identity_token_file
201
+ and current_role_arn == aws_role_name
202
+ ):
203
+ verbose_logger .debug (
204
+ "Using IRSA same-role optimization: calling _auth_with_env_vars"
205
+ )
201
206
# We're already running as this role via IRSA, no need to assume it again
202
207
# Use the default boto3 credentials (which will use the IRSA credentials)
203
208
credentials , _cache_ttl = self ._auth_with_env_vars ()
204
209
else :
205
- verbose_logger .debug ("Using role assumption: calling _auth_with_aws_role" )
210
+ verbose_logger .debug (
211
+ "Using role assumption: calling _auth_with_aws_role"
212
+ )
206
213
# If aws_session_name is not provided, generate a default one
207
214
if aws_session_name is None :
208
- aws_session_name = f"litellm-session-{ int (datetime .now ().timestamp ())} "
215
+ aws_session_name = (
216
+ f"litellm-session-{ int (datetime .now ().timestamp ())} "
217
+ )
209
218
credentials , _cache_ttl = self ._auth_with_aws_role (
210
219
aws_access_key_id = aws_access_key_id ,
211
220
aws_secret_access_key = aws_secret_access_key ,
@@ -479,55 +488,67 @@ def _auth_with_web_identity_token(
479
488
iam_creds = session .get_credentials ()
480
489
return iam_creds , self ._get_default_ttl_for_boto3_credentials ()
481
490
482
- def _handle_irsa_cross_account (self , irsa_role_arn : str , aws_role_name : str ,
483
- aws_session_name : str , region : str , web_identity_token_file : str ,
484
- aws_external_id : Optional [str ] = None ) -> dict :
491
+ def _handle_irsa_cross_account (
492
+ self ,
493
+ irsa_role_arn : str ,
494
+ aws_role_name : str ,
495
+ aws_session_name : str ,
496
+ region : str ,
497
+ web_identity_token_file : str ,
498
+ aws_external_id : Optional [str ] = None ,
499
+ ) -> dict :
485
500
"""Handle cross-account role assumption for IRSA."""
486
501
import boto3
487
-
502
+
488
503
verbose_logger .debug ("Cross-account role assumption detected" )
489
-
504
+
490
505
# Read the web identity token
491
- with open (web_identity_token_file , 'r' ) as f :
506
+ with open (web_identity_token_file , "r" ) as f :
492
507
web_identity_token = f .read ().strip ()
493
-
508
+
494
509
# Create an STS client without credentials
495
510
with tracer .trace ("boto3.client(sts) for manual IRSA" ):
496
- sts_client = boto3 .client (' sts' , region_name = region )
497
-
511
+ sts_client = boto3 .client (" sts" , region_name = region )
512
+
498
513
# Manually assume the IRSA role with the session name
499
- verbose_logger .debug (f"Manually assuming IRSA role { irsa_role_arn } with session { aws_session_name } " )
514
+ verbose_logger .debug (
515
+ f"Manually assuming IRSA role { irsa_role_arn } with session { aws_session_name } "
516
+ )
500
517
irsa_response = sts_client .assume_role_with_web_identity (
501
518
RoleArn = irsa_role_arn ,
502
519
RoleSessionName = aws_session_name ,
503
- WebIdentityToken = web_identity_token
520
+ WebIdentityToken = web_identity_token ,
504
521
)
505
-
522
+
506
523
# Extract the credentials from the IRSA assumption
507
524
irsa_creds = irsa_response ["Credentials" ]
508
-
525
+
509
526
# Create a new STS client with the IRSA credentials
510
527
with tracer .trace ("boto3.client(sts) with manual IRSA credentials" ):
511
528
sts_client_with_creds = boto3 .client (
512
- ' sts' ,
529
+ " sts" ,
513
530
region_name = region ,
514
531
aws_access_key_id = irsa_creds ["AccessKeyId" ],
515
532
aws_secret_access_key = irsa_creds ["SecretAccessKey" ],
516
- aws_session_token = irsa_creds ["SessionToken" ]
533
+ aws_session_token = irsa_creds ["SessionToken" ],
517
534
)
518
-
535
+
519
536
# Get current caller identity for debugging
520
537
try :
521
538
caller_identity = sts_client_with_creds .get_caller_identity ()
522
- verbose_logger .debug (f"Current identity after manual IRSA assumption: { caller_identity .get ('Arn' , 'unknown' )} " )
539
+ verbose_logger .debug (
540
+ f"Current identity after manual IRSA assumption: { caller_identity .get ('Arn' , 'unknown' )} "
541
+ )
523
542
except Exception as e :
524
543
verbose_logger .debug (f"Failed to get caller identity: { e } " )
525
-
544
+
526
545
# Now assume the target role
527
- verbose_logger .debug (f"Attempting to assume target role: { aws_role_name } with session: { aws_session_name } " )
546
+ verbose_logger .debug (
547
+ f"Attempting to assume target role: { aws_role_name } with session: { aws_session_name } "
548
+ )
528
549
assume_role_params = {
529
550
"RoleArn" : aws_role_name ,
530
- "RoleSessionName" : aws_session_name
551
+ "RoleSessionName" : aws_session_name ,
531
552
}
532
553
533
554
# Add ExternalId parameter if provided
@@ -536,27 +557,36 @@ def _handle_irsa_cross_account(self, irsa_role_arn: str, aws_role_name: str,
536
557
537
558
return sts_client_with_creds .assume_role (** assume_role_params )
538
559
539
- def _handle_irsa_same_account (self , aws_role_name : str , aws_session_name : str , region : str ,
540
- aws_external_id : Optional [str ] = None ) -> dict :
560
+ def _handle_irsa_same_account (
561
+ self ,
562
+ aws_role_name : str ,
563
+ aws_session_name : str ,
564
+ region : str ,
565
+ aws_external_id : Optional [str ] = None ,
566
+ ) -> dict :
541
567
"""Handle same-account role assumption for IRSA."""
542
568
import boto3
543
-
569
+
544
570
verbose_logger .debug ("Same account role assumption, using automatic IRSA" )
545
571
with tracer .trace ("boto3.client(sts) with automatic IRSA" ):
546
572
sts_client = boto3 .client ("sts" , region_name = region )
547
-
573
+
548
574
# Get current caller identity for debugging
549
575
try :
550
576
caller_identity = sts_client .get_caller_identity ()
551
- verbose_logger .debug (f"Current IRSA identity: { caller_identity .get ('Arn' , 'unknown' )} " )
577
+ verbose_logger .debug (
578
+ f"Current IRSA identity: { caller_identity .get ('Arn' , 'unknown' )} "
579
+ )
552
580
except Exception as e :
553
581
verbose_logger .debug (f"Failed to get caller identity: { e } " )
554
-
582
+
555
583
# Assume the role
556
- verbose_logger .debug (f"Attempting to assume role: { aws_role_name } with session: { aws_session_name } " )
584
+ verbose_logger .debug (
585
+ f"Attempting to assume role: { aws_role_name } with session: { aws_session_name } "
586
+ )
557
587
assume_role_params = {
558
588
"RoleArn" : aws_role_name ,
559
- "RoleSessionName" : aws_session_name
589
+ "RoleSessionName" : aws_session_name ,
560
590
}
561
591
562
592
# Add ExternalId parameter if provided
@@ -565,20 +595,24 @@ def _handle_irsa_same_account(self, aws_role_name: str, aws_session_name: str, r
565
595
566
596
return sts_client .assume_role (** assume_role_params )
567
597
568
- def _extract_credentials_and_ttl (self , sts_response : dict ) -> Tuple [Credentials , Optional [int ]]:
598
+ def _extract_credentials_and_ttl (
599
+ self , sts_response : dict
600
+ ) -> Tuple [Credentials , Optional [int ]]:
569
601
"""Extract credentials and TTL from STS response."""
570
602
from botocore .credentials import Credentials
571
-
603
+
572
604
sts_credentials = sts_response ["Credentials" ]
573
605
credentials = Credentials (
574
606
access_key = sts_credentials ["AccessKeyId" ],
575
607
secret_key = sts_credentials ["SecretAccessKey" ],
576
608
token = sts_credentials ["SessionToken" ],
577
609
)
578
-
610
+
579
611
expiration_time = sts_credentials ["Expiration" ]
580
- ttl = int ((expiration_time - datetime .now (expiration_time .tzinfo )).total_seconds ())
581
-
612
+ ttl = int (
613
+ (expiration_time - datetime .now (expiration_time .tzinfo )).total_seconds ()
614
+ )
615
+
582
616
return credentials , ttl
583
617
584
618
@tracer .wrap ()
@@ -600,34 +634,51 @@ def _auth_with_aws_role(
600
634
# Check if we're in an EKS/IRSA environment
601
635
web_identity_token_file = os .getenv ("AWS_WEB_IDENTITY_TOKEN_FILE" )
602
636
irsa_role_arn = os .getenv ("AWS_ROLE_ARN" )
603
-
637
+
604
638
# If we have IRSA environment variables and no explicit credentials,
605
639
# we need to use the web identity token flow
606
- if (web_identity_token_file and irsa_role_arn and
607
- aws_access_key_id is None and aws_secret_access_key is None ):
640
+ if (
641
+ web_identity_token_file
642
+ and irsa_role_arn
643
+ and aws_access_key_id is None
644
+ and aws_secret_access_key is None
645
+ ):
608
646
# For cross-account role assumption with specific session names,
609
647
# we need to manually assume the IRSA role first with the correct session name
610
- verbose_logger .debug (f"IRSA detected: using web identity token from { web_identity_token_file } " )
611
-
648
+ verbose_logger .debug (
649
+ f"IRSA detected: using web identity token from { web_identity_token_file } "
650
+ )
651
+
612
652
try :
613
653
# Get region from environment
614
- region = os .getenv ("AWS_REGION" ) or os .getenv ("AWS_DEFAULT_REGION" ) or "us-east-1"
615
-
654
+ region = (
655
+ os .getenv ("AWS_REGION" )
656
+ or os .getenv ("AWS_DEFAULT_REGION" )
657
+ or "us-east-1"
658
+ )
659
+
616
660
# Check if we need to do cross-account role assumption
617
661
if aws_role_name != irsa_role_arn :
618
662
sts_response = self ._handle_irsa_cross_account (
619
- irsa_role_arn , aws_role_name , aws_session_name , region , web_identity_token_file , aws_external_id
663
+ irsa_role_arn ,
664
+ aws_role_name ,
665
+ aws_session_name ,
666
+ region ,
667
+ web_identity_token_file ,
668
+ aws_external_id ,
620
669
)
621
670
else :
622
671
sts_response = self ._handle_irsa_same_account (
623
672
aws_role_name , aws_session_name , region , aws_external_id
624
673
)
625
-
674
+
626
675
return self ._extract_credentials_and_ttl (sts_response )
627
-
676
+
628
677
except Exception as e :
629
678
verbose_logger .debug (f"Failed to assume role via IRSA: { e } " )
630
- if "AccessDenied" in str (e ) and "is not authorized to perform: sts:AssumeRole" in str (e ):
679
+ if "AccessDenied" in str (
680
+ e
681
+ ) and "is not authorized to perform: sts:AssumeRole" in str (e ):
631
682
# Provide a more helpful error message for trust policy issues
632
683
verbose_logger .error (
633
684
f"Access denied when trying to assume role { aws_role_name } . "
@@ -636,7 +687,7 @@ def _auth_with_aws_role(
636
687
)
637
688
# Re-raise the exception instead of falling through
638
689
raise
639
-
690
+
640
691
# In EKS/IRSA environments, use ambient credentials (no explicit keys needed)
641
692
# This allows the web identity token to work automatically
642
693
if aws_access_key_id is None and aws_secret_access_key is None :
@@ -653,7 +704,7 @@ def _auth_with_aws_role(
653
704
654
705
assume_role_params = {
655
706
"RoleArn" : aws_role_name ,
656
- "RoleSessionName" : aws_session_name
707
+ "RoleSessionName" : aws_session_name ,
657
708
}
658
709
659
710
# Add ExternalId parameter if provided
@@ -782,14 +833,14 @@ def get_runtime_endpoint(
782
833
)
783
834
784
835
# Determine proxy_endpoint_url
785
- if env_aws_bedrock_runtime_endpoint and isinstance (
786
- env_aws_bedrock_runtime_endpoint , str
787
- ):
788
- proxy_endpoint_url = env_aws_bedrock_runtime_endpoint
789
- elif aws_bedrock_runtime_endpoint is not None and isinstance (
836
+ if aws_bedrock_runtime_endpoint is not None and isinstance (
790
837
aws_bedrock_runtime_endpoint , str
791
838
):
792
839
proxy_endpoint_url = aws_bedrock_runtime_endpoint
840
+ elif env_aws_bedrock_runtime_endpoint and isinstance (
841
+ env_aws_bedrock_runtime_endpoint , str
842
+ ):
843
+ proxy_endpoint_url = env_aws_bedrock_runtime_endpoint
793
844
else :
794
845
proxy_endpoint_url = endpoint_url
795
846
0 commit comments