Skip to content

Commit 530662d

Browse files
test: add unit testing
1 parent 1ada663 commit 530662d

File tree

2 files changed

+298
-57
lines changed

2 files changed

+298
-57
lines changed

litellm/llms/bedrock/base_aws_llm.py

Lines changed: 108 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -189,23 +189,32 @@ def get_credentials(
189189
# Check if we're in IRSA and trying to assume the same role we already have
190190
current_role_arn = os.getenv("AWS_ROLE_ARN")
191191
web_identity_token_file = os.getenv("AWS_WEB_IDENTITY_TOKEN_FILE")
192-
192+
193193
# In IRSA environments, we should skip role assumption if we're already running as the target role
194194
# This is true when:
195195
# 1. We have AWS_ROLE_ARN set (current role)
196196
# 2. We have AWS_WEB_IDENTITY_TOKEN_FILE set (IRSA environment)
197197
# 3. The current role matches the requested role
198-
if (current_role_arn and web_identity_token_file and
199-
current_role_arn == aws_role_name):
200-
verbose_logger.debug("Using IRSA same-role optimization: calling _auth_with_env_vars")
198+
if (
199+
current_role_arn
200+
and web_identity_token_file
201+
and current_role_arn == aws_role_name
202+
):
203+
verbose_logger.debug(
204+
"Using IRSA same-role optimization: calling _auth_with_env_vars"
205+
)
201206
# We're already running as this role via IRSA, no need to assume it again
202207
# Use the default boto3 credentials (which will use the IRSA credentials)
203208
credentials, _cache_ttl = self._auth_with_env_vars()
204209
else:
205-
verbose_logger.debug("Using role assumption: calling _auth_with_aws_role")
210+
verbose_logger.debug(
211+
"Using role assumption: calling _auth_with_aws_role"
212+
)
206213
# If aws_session_name is not provided, generate a default one
207214
if aws_session_name is None:
208-
aws_session_name = f"litellm-session-{int(datetime.now().timestamp())}"
215+
aws_session_name = (
216+
f"litellm-session-{int(datetime.now().timestamp())}"
217+
)
209218
credentials, _cache_ttl = self._auth_with_aws_role(
210219
aws_access_key_id=aws_access_key_id,
211220
aws_secret_access_key=aws_secret_access_key,
@@ -479,55 +488,67 @@ def _auth_with_web_identity_token(
479488
iam_creds = session.get_credentials()
480489
return iam_creds, self._get_default_ttl_for_boto3_credentials()
481490

482-
def _handle_irsa_cross_account(self, irsa_role_arn: str, aws_role_name: str,
483-
aws_session_name: str, region: str, web_identity_token_file: str,
484-
aws_external_id: Optional[str] = None) -> dict:
491+
def _handle_irsa_cross_account(
492+
self,
493+
irsa_role_arn: str,
494+
aws_role_name: str,
495+
aws_session_name: str,
496+
region: str,
497+
web_identity_token_file: str,
498+
aws_external_id: Optional[str] = None,
499+
) -> dict:
485500
"""Handle cross-account role assumption for IRSA."""
486501
import boto3
487-
502+
488503
verbose_logger.debug("Cross-account role assumption detected")
489-
504+
490505
# Read the web identity token
491-
with open(web_identity_token_file, 'r') as f:
506+
with open(web_identity_token_file, "r") as f:
492507
web_identity_token = f.read().strip()
493-
508+
494509
# Create an STS client without credentials
495510
with tracer.trace("boto3.client(sts) for manual IRSA"):
496-
sts_client = boto3.client('sts', region_name=region)
497-
511+
sts_client = boto3.client("sts", region_name=region)
512+
498513
# Manually assume the IRSA role with the session name
499-
verbose_logger.debug(f"Manually assuming IRSA role {irsa_role_arn} with session {aws_session_name}")
514+
verbose_logger.debug(
515+
f"Manually assuming IRSA role {irsa_role_arn} with session {aws_session_name}"
516+
)
500517
irsa_response = sts_client.assume_role_with_web_identity(
501518
RoleArn=irsa_role_arn,
502519
RoleSessionName=aws_session_name,
503-
WebIdentityToken=web_identity_token
520+
WebIdentityToken=web_identity_token,
504521
)
505-
522+
506523
# Extract the credentials from the IRSA assumption
507524
irsa_creds = irsa_response["Credentials"]
508-
525+
509526
# Create a new STS client with the IRSA credentials
510527
with tracer.trace("boto3.client(sts) with manual IRSA credentials"):
511528
sts_client_with_creds = boto3.client(
512-
'sts',
529+
"sts",
513530
region_name=region,
514531
aws_access_key_id=irsa_creds["AccessKeyId"],
515532
aws_secret_access_key=irsa_creds["SecretAccessKey"],
516-
aws_session_token=irsa_creds["SessionToken"]
533+
aws_session_token=irsa_creds["SessionToken"],
517534
)
518-
535+
519536
# Get current caller identity for debugging
520537
try:
521538
caller_identity = sts_client_with_creds.get_caller_identity()
522-
verbose_logger.debug(f"Current identity after manual IRSA assumption: {caller_identity.get('Arn', 'unknown')}")
539+
verbose_logger.debug(
540+
f"Current identity after manual IRSA assumption: {caller_identity.get('Arn', 'unknown')}"
541+
)
523542
except Exception as e:
524543
verbose_logger.debug(f"Failed to get caller identity: {e}")
525-
544+
526545
# Now assume the target role
527-
verbose_logger.debug(f"Attempting to assume target role: {aws_role_name} with session: {aws_session_name}")
546+
verbose_logger.debug(
547+
f"Attempting to assume target role: {aws_role_name} with session: {aws_session_name}"
548+
)
528549
assume_role_params = {
529550
"RoleArn": aws_role_name,
530-
"RoleSessionName": aws_session_name
551+
"RoleSessionName": aws_session_name,
531552
}
532553

533554
# Add ExternalId parameter if provided
@@ -536,27 +557,36 @@ def _handle_irsa_cross_account(self, irsa_role_arn: str, aws_role_name: str,
536557

537558
return sts_client_with_creds.assume_role(**assume_role_params)
538559

539-
def _handle_irsa_same_account(self, aws_role_name: str, aws_session_name: str, region: str,
540-
aws_external_id: Optional[str] = None) -> dict:
560+
def _handle_irsa_same_account(
561+
self,
562+
aws_role_name: str,
563+
aws_session_name: str,
564+
region: str,
565+
aws_external_id: Optional[str] = None,
566+
) -> dict:
541567
"""Handle same-account role assumption for IRSA."""
542568
import boto3
543-
569+
544570
verbose_logger.debug("Same account role assumption, using automatic IRSA")
545571
with tracer.trace("boto3.client(sts) with automatic IRSA"):
546572
sts_client = boto3.client("sts", region_name=region)
547-
573+
548574
# Get current caller identity for debugging
549575
try:
550576
caller_identity = sts_client.get_caller_identity()
551-
verbose_logger.debug(f"Current IRSA identity: {caller_identity.get('Arn', 'unknown')}")
577+
verbose_logger.debug(
578+
f"Current IRSA identity: {caller_identity.get('Arn', 'unknown')}"
579+
)
552580
except Exception as e:
553581
verbose_logger.debug(f"Failed to get caller identity: {e}")
554-
582+
555583
# Assume the role
556-
verbose_logger.debug(f"Attempting to assume role: {aws_role_name} with session: {aws_session_name}")
584+
verbose_logger.debug(
585+
f"Attempting to assume role: {aws_role_name} with session: {aws_session_name}"
586+
)
557587
assume_role_params = {
558588
"RoleArn": aws_role_name,
559-
"RoleSessionName": aws_session_name
589+
"RoleSessionName": aws_session_name,
560590
}
561591

562592
# Add ExternalId parameter if provided
@@ -565,20 +595,24 @@ def _handle_irsa_same_account(self, aws_role_name: str, aws_session_name: str, r
565595

566596
return sts_client.assume_role(**assume_role_params)
567597

568-
def _extract_credentials_and_ttl(self, sts_response: dict) -> Tuple[Credentials, Optional[int]]:
598+
def _extract_credentials_and_ttl(
599+
self, sts_response: dict
600+
) -> Tuple[Credentials, Optional[int]]:
569601
"""Extract credentials and TTL from STS response."""
570602
from botocore.credentials import Credentials
571-
603+
572604
sts_credentials = sts_response["Credentials"]
573605
credentials = Credentials(
574606
access_key=sts_credentials["AccessKeyId"],
575607
secret_key=sts_credentials["SecretAccessKey"],
576608
token=sts_credentials["SessionToken"],
577609
)
578-
610+
579611
expiration_time = sts_credentials["Expiration"]
580-
ttl = int((expiration_time - datetime.now(expiration_time.tzinfo)).total_seconds())
581-
612+
ttl = int(
613+
(expiration_time - datetime.now(expiration_time.tzinfo)).total_seconds()
614+
)
615+
582616
return credentials, ttl
583617

584618
@tracer.wrap()
@@ -600,34 +634,51 @@ def _auth_with_aws_role(
600634
# Check if we're in an EKS/IRSA environment
601635
web_identity_token_file = os.getenv("AWS_WEB_IDENTITY_TOKEN_FILE")
602636
irsa_role_arn = os.getenv("AWS_ROLE_ARN")
603-
637+
604638
# If we have IRSA environment variables and no explicit credentials,
605639
# we need to use the web identity token flow
606-
if (web_identity_token_file and irsa_role_arn and
607-
aws_access_key_id is None and aws_secret_access_key is None):
640+
if (
641+
web_identity_token_file
642+
and irsa_role_arn
643+
and aws_access_key_id is None
644+
and aws_secret_access_key is None
645+
):
608646
# For cross-account role assumption with specific session names,
609647
# we need to manually assume the IRSA role first with the correct session name
610-
verbose_logger.debug(f"IRSA detected: using web identity token from {web_identity_token_file}")
611-
648+
verbose_logger.debug(
649+
f"IRSA detected: using web identity token from {web_identity_token_file}"
650+
)
651+
612652
try:
613653
# Get region from environment
614-
region = os.getenv("AWS_REGION") or os.getenv("AWS_DEFAULT_REGION") or "us-east-1"
615-
654+
region = (
655+
os.getenv("AWS_REGION")
656+
or os.getenv("AWS_DEFAULT_REGION")
657+
or "us-east-1"
658+
)
659+
616660
# Check if we need to do cross-account role assumption
617661
if aws_role_name != irsa_role_arn:
618662
sts_response = self._handle_irsa_cross_account(
619-
irsa_role_arn, aws_role_name, aws_session_name, region, web_identity_token_file, aws_external_id
663+
irsa_role_arn,
664+
aws_role_name,
665+
aws_session_name,
666+
region,
667+
web_identity_token_file,
668+
aws_external_id,
620669
)
621670
else:
622671
sts_response = self._handle_irsa_same_account(
623672
aws_role_name, aws_session_name, region, aws_external_id
624673
)
625-
674+
626675
return self._extract_credentials_and_ttl(sts_response)
627-
676+
628677
except Exception as e:
629678
verbose_logger.debug(f"Failed to assume role via IRSA: {e}")
630-
if "AccessDenied" in str(e) and "is not authorized to perform: sts:AssumeRole" in str(e):
679+
if "AccessDenied" in str(
680+
e
681+
) and "is not authorized to perform: sts:AssumeRole" in str(e):
631682
# Provide a more helpful error message for trust policy issues
632683
verbose_logger.error(
633684
f"Access denied when trying to assume role {aws_role_name}. "
@@ -636,7 +687,7 @@ def _auth_with_aws_role(
636687
)
637688
# Re-raise the exception instead of falling through
638689
raise
639-
690+
640691
# In EKS/IRSA environments, use ambient credentials (no explicit keys needed)
641692
# This allows the web identity token to work automatically
642693
if aws_access_key_id is None and aws_secret_access_key is None:
@@ -653,7 +704,7 @@ def _auth_with_aws_role(
653704

654705
assume_role_params = {
655706
"RoleArn": aws_role_name,
656-
"RoleSessionName": aws_session_name
707+
"RoleSessionName": aws_session_name,
657708
}
658709

659710
# Add ExternalId parameter if provided
@@ -782,14 +833,14 @@ def get_runtime_endpoint(
782833
)
783834

784835
# Determine proxy_endpoint_url
785-
if env_aws_bedrock_runtime_endpoint and isinstance(
786-
env_aws_bedrock_runtime_endpoint, str
787-
):
788-
proxy_endpoint_url = env_aws_bedrock_runtime_endpoint
789-
elif aws_bedrock_runtime_endpoint is not None and isinstance(
836+
if aws_bedrock_runtime_endpoint is not None and isinstance(
790837
aws_bedrock_runtime_endpoint, str
791838
):
792839
proxy_endpoint_url = aws_bedrock_runtime_endpoint
840+
elif env_aws_bedrock_runtime_endpoint and isinstance(
841+
env_aws_bedrock_runtime_endpoint, str
842+
):
843+
proxy_endpoint_url = env_aws_bedrock_runtime_endpoint
793844
else:
794845
proxy_endpoint_url = endpoint_url
795846

0 commit comments

Comments
 (0)