1818
1919import os
2020import uuid
21- from asyncio import Lock
2221from typing import Protocol
2322
24- from jwt import decode
25-
2623from keycardai .oauth import (
2724 AsyncClient ,
2825 AuthStrategy ,
3936 EKSWorkloadIdentityConfigurationError ,
4037 EKSWorkloadIdentityRuntimeError ,
4138)
42- from ._cache import InMemoryTokenCache , TokenCache
4339from .private_key import (
4440 FilePrivateKeyStorage ,
4541 PrivateKeyManager ,
@@ -432,7 +428,6 @@ def __init__(
432428 self ,
433429 token_file_path : str | None = None ,
434430 env_var_name : str = "AWS_CONTAINER_AUTHORIZATION_TOKEN_FILE" ,
435- cache : TokenCache | None = None
436431 ):
437432 """Initialize EKS workload identity provider.
438433
@@ -441,16 +436,10 @@ def __init__(
441436 reads from the environment variable specified by env_var_name.
442437 env_var_name: Name of the environment variable containing the token file path.
443438 Defaults to AWS_CONTAINER_AUTHORIZATION_TOKEN_FILE.
444- cache_leeway: Number of seconds before JWT expiration to consider it expired.
445- Defaults to 300 seconds (5 minutes) for safety margin.
446439
447440 Raises:
448441 EKSWorkloadIdentityConfigurationError: If token file cannot be read or is empty.
449442 """
450- if cache is None :
451- cache = InMemoryTokenCache ()
452- self ._application_credentials : TokenCache = cache # jti -> (access_token, exp)
453- self ._lock = Lock ()
454443 self .env_var_name = env_var_name
455444
456445 if token_file_path is not None :
@@ -570,107 +559,6 @@ def set_client_config(
570559 """
571560 return config
572561
573- def _get_application_credential (self , client_assertion : str ) -> tuple [str , int ] | None :
574- """Get cached application credential if valid and not expired.
575-
576- Uses double-checked locking pattern - this is the fast path that checks
577- cache WITHOUT acquiring lock for maximum performance.
578-
579- Args:
580- client_assertion: The EKS workload identity JWT token
581-
582- Returns:
583- Tuple of (access_token, exp) if cached and valid, None if expired or not cached
584- """
585- try :
586- # Decode without verification - we just need the jti claim for cache lookup
587- # The Keycard backend validates the actual assertion signature
588- decoded_assertion = decode (client_assertion , options = {"verify_signature" : False })
589- except Exception :
590- return None
591-
592- assertion_jti = decoded_assertion .get ("jti" )
593- if not assertion_jti :
594- return None
595-
596- return self ._application_credentials .get (assertion_jti )
597-
598- def _set_application_credential (self , client_assertion : str , access_token : str ) -> None :
599- """Cache the application credential with expiration from JWT.
600-
601- Extracts the 'exp' claim from the access token JWT and stores it
602- for expiration checking. This method is called inside the lock.
603-
604- Args:
605- client_assertion: The EKS workload identity JWT (used for cache key)
606- access_token: The access token JWT (contains 'exp' claim)
607- """
608- try :
609- decoded_assertion = decode (client_assertion , options = {"verify_signature" : False })
610- except Exception :
611- return
612-
613- assertion_jti = decoded_assertion .get ("jti" )
614- if not assertion_jti :
615- return
616-
617- try :
618- # Decode access token to get exp claim for expiration checking
619- decoded_token = decode (access_token , options = {"verify_signature" : False })
620- except Exception :
621- return
622-
623- exp_time = decoded_token .get ("exp" )
624- if not exp_time :
625- return
626-
627- # Cache with epoch expiration time from JWT
628- self ._application_credentials .set (assertion_jti , (access_token , exp_time ))
629-
630- async def get_application_credential (self , client : AsyncClient , client_assertion : str ) -> str :
631- """Get application credential.
632-
633- Args:
634- client: OAuth client for token exchange
635- client_assertion: The EKS workload identity JWT token
636-
637- Returns:
638- The access token for the application credential
639-
640- Raises:
641- EKSWorkloadIdentityRuntimeError: If token exchange fails
642- """
643- cached = self ._get_application_credential (client_assertion )
644- if cached :
645- access_token , _ = cached
646- return access_token
647-
648- async with self ._lock :
649- # DOUBLE CHECK: Another coroutine may have just populated cache while we waited
650- cached = self ._get_application_credential (client_assertion )
651- if cached :
652- access_token , _ = cached
653- return access_token
654-
655- request = TokenExchangeRequest (
656- grant_type = GrantType .CLIENT_CREDENTIALS ,
657- client_assertion_type = GrantType .JWT_BEARER_CLIENT_ASSERTION ,
658- client_assertion = client_assertion
659- )
660-
661- try :
662- response = await client .exchange_token (request )
663- except Exception as e :
664- raise EKSWorkloadIdentityRuntimeError (
665- token_file_path = self .token_file_path ,
666- env_var_name = self .env_var_name ,
667- error_details = f"Error getting application credential: { str (e )} " ,
668- ) from e
669-
670- self ._set_application_credential (client_assertion , response .access_token )
671- return response .access_token
672-
673-
674562 async def prepare_token_exchange_request (
675563 self ,
676564 client : AsyncClient ,
@@ -699,13 +587,11 @@ async def prepare_token_exchange_request(
699587 # Read the token from the filesystem
700588 eks_token = self ._read_token ()
701589
702- application_credential = await self .get_application_credential (client , eks_token )
703-
704590 return TokenExchangeRequest (
705591 subject_token = subject_token ,
706592 resource = resource ,
707593 subject_token_type = "urn:ietf:params:oauth:token-type:access_token" ,
708594 client_assertion_type = GrantType .JWT_BEARER_CLIENT_ASSERTION ,
709- client_assertion = application_credential ,
595+ client_assertion = eks_token ,
710596 )
711597
0 commit comments