Skip to content

Commit 237b9dc

Browse files
Merge pull request #30 from keycardai/revert-29-feat/2-step-eks-workload-identity
Revert "feat(keycardai-mcp): app credential grant flow"
2 parents 9422ee2 + 5f6acf5 commit 237b9dc

File tree

7 files changed

+6
-381
lines changed

7 files changed

+6
-381
lines changed

packages/mcp-fastmcp/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ dependencies = [
1212
"httpx>=0.27.2",
1313
"keycardai-oauth>=0.5.0",
1414
"fastmcp==2.12.0",
15-
"keycardai-mcp>=0.10.0",
15+
"keycardai-mcp>=0.11.0",
1616
]
1717
keywords = ["fastmcp", "mcp", "model-context-protocol", "oauth", "token-exchange", "authentication", "keycard"]
1818
classifiers = [

packages/mcp/pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ dependencies = [
1313
"httpx>=0.27.2",
1414
"starlette>=0.47.3",
1515
"nanoid>=2.0.0",
16-
"pyjwt>=2.10.1",
1716
]
1817
keywords = ["mcp", "model-context-protocol", "authentication", "authorization", "ai", "llm"]
1918
classifiers = [

packages/mcp/src/keycardai/mcp/server/auth/_cache.py

Lines changed: 1 addition & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import threading
44
import time
55
from dataclasses import dataclass
6-
from typing import Any, Protocol
6+
from typing import Any
77

88

99
@dataclass
@@ -155,40 +155,3 @@ def cleanup_expired(self) -> int:
155155
del self._cache[cache_key]
156156

157157
return len(expired_keys)
158-
159-
class TokenCache(Protocol):
160-
"""Protocol for token cache implementations."""
161-
162-
def get(self, key: str) -> tuple[str, int] | None:
163-
"""Get a value from the cache if it exists and hasn't expired."""
164-
pass
165-
166-
def set(self, key: str, value: tuple[str, int]) -> None:
167-
"""Set a value in the cache with current timestamp."""
168-
pass
169-
170-
class InMemoryTokenCache(TokenCache):
171-
"""In-memory token cache implementation."""
172-
173-
def __init__(self, exp_leeway: int = 300):
174-
self.exp_leeway = exp_leeway
175-
self._cache: dict[str, tuple[str, int]] = {}
176-
177-
def get(self, key: str) -> tuple[str, int] | None:
178-
cached = self._cache.get(key)
179-
if not cached:
180-
return None
181-
182-
access_token, exp_time = cached
183-
# Check if token is expired with leeway (default 5 minutes before exp)
184-
# exp_time is epoch timestamp from JWT 'exp' claim
185-
current_time = int(time.time())
186-
if current_time >= (exp_time - self.exp_leeway):
187-
# Token expired or too close to expiration - force refresh
188-
self._cache.pop(key)
189-
return None
190-
191-
return cached
192-
193-
def set(self, key: str, value: tuple[str, int]) -> None:
194-
self._cache[key] = value

packages/mcp/src/keycardai/mcp/server/auth/application_credentials.py

Lines changed: 1 addition & 115 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,8 @@
1818

1919
import os
2020
import uuid
21-
from asyncio import Lock
2221
from typing import Protocol
2322

24-
from jwt import decode
25-
2623
from keycardai.oauth import (
2724
AsyncClient,
2825
AuthStrategy,
@@ -39,7 +36,6 @@
3936
EKSWorkloadIdentityConfigurationError,
4037
EKSWorkloadIdentityRuntimeError,
4138
)
42-
from ._cache import InMemoryTokenCache, TokenCache
4339
from .private_key import (
4440
FilePrivateKeyStorage,
4541
PrivateKeyManager,
@@ -432,7 +428,6 @@ def __init__(
432428
self,
433429
token_file_path: str | None = None,
434430
env_var_name: str = "AWS_CONTAINER_AUTHORIZATION_TOKEN_FILE",
435-
cache: TokenCache | None = None
436431
):
437432
"""Initialize EKS workload identity provider.
438433
@@ -441,16 +436,10 @@ def __init__(
441436
reads from the environment variable specified by env_var_name.
442437
env_var_name: Name of the environment variable containing the token file path.
443438
Defaults to AWS_CONTAINER_AUTHORIZATION_TOKEN_FILE.
444-
cache_leeway: Number of seconds before JWT expiration to consider it expired.
445-
Defaults to 300 seconds (5 minutes) for safety margin.
446439
447440
Raises:
448441
EKSWorkloadIdentityConfigurationError: If token file cannot be read or is empty.
449442
"""
450-
if cache is None:
451-
cache = InMemoryTokenCache()
452-
self._application_credentials: TokenCache = cache # jti -> (access_token, exp)
453-
self._lock = Lock()
454443
self.env_var_name = env_var_name
455444

456445
if token_file_path is not None:
@@ -570,107 +559,6 @@ def set_client_config(
570559
"""
571560
return config
572561

573-
def _get_application_credential(self, client_assertion: str) -> tuple[str, int] | None:
574-
"""Get cached application credential if valid and not expired.
575-
576-
Uses double-checked locking pattern - this is the fast path that checks
577-
cache WITHOUT acquiring lock for maximum performance.
578-
579-
Args:
580-
client_assertion: The EKS workload identity JWT token
581-
582-
Returns:
583-
Tuple of (access_token, exp) if cached and valid, None if expired or not cached
584-
"""
585-
try:
586-
# Decode without verification - we just need the jti claim for cache lookup
587-
# The Keycard backend validates the actual assertion signature
588-
decoded_assertion = decode(client_assertion, options={"verify_signature": False})
589-
except Exception:
590-
return None
591-
592-
assertion_jti = decoded_assertion.get("jti")
593-
if not assertion_jti:
594-
return None
595-
596-
return self._application_credentials.get(assertion_jti)
597-
598-
def _set_application_credential(self, client_assertion: str, access_token: str) -> None:
599-
"""Cache the application credential with expiration from JWT.
600-
601-
Extracts the 'exp' claim from the access token JWT and stores it
602-
for expiration checking. This method is called inside the lock.
603-
604-
Args:
605-
client_assertion: The EKS workload identity JWT (used for cache key)
606-
access_token: The access token JWT (contains 'exp' claim)
607-
"""
608-
try:
609-
decoded_assertion = decode(client_assertion, options={"verify_signature": False})
610-
except Exception:
611-
return
612-
613-
assertion_jti = decoded_assertion.get("jti")
614-
if not assertion_jti:
615-
return
616-
617-
try:
618-
# Decode access token to get exp claim for expiration checking
619-
decoded_token = decode(access_token, options={"verify_signature": False})
620-
except Exception:
621-
return
622-
623-
exp_time = decoded_token.get("exp")
624-
if not exp_time:
625-
return
626-
627-
# Cache with epoch expiration time from JWT
628-
self._application_credentials.set(assertion_jti, (access_token, exp_time))
629-
630-
async def get_application_credential(self, client: AsyncClient, client_assertion: str) -> str:
631-
"""Get application credential.
632-
633-
Args:
634-
client: OAuth client for token exchange
635-
client_assertion: The EKS workload identity JWT token
636-
637-
Returns:
638-
The access token for the application credential
639-
640-
Raises:
641-
EKSWorkloadIdentityRuntimeError: If token exchange fails
642-
"""
643-
cached = self._get_application_credential(client_assertion)
644-
if cached:
645-
access_token, _ = cached
646-
return access_token
647-
648-
async with self._lock:
649-
# DOUBLE CHECK: Another coroutine may have just populated cache while we waited
650-
cached = self._get_application_credential(client_assertion)
651-
if cached:
652-
access_token, _ = cached
653-
return access_token
654-
655-
request = TokenExchangeRequest(
656-
grant_type=GrantType.CLIENT_CREDENTIALS,
657-
client_assertion_type=GrantType.JWT_BEARER_CLIENT_ASSERTION,
658-
client_assertion=client_assertion
659-
)
660-
661-
try:
662-
response = await client.exchange_token(request)
663-
except Exception as e:
664-
raise EKSWorkloadIdentityRuntimeError(
665-
token_file_path=self.token_file_path,
666-
env_var_name=self.env_var_name,
667-
error_details=f"Error getting application credential: {str(e)}",
668-
) from e
669-
670-
self._set_application_credential(client_assertion, response.access_token)
671-
return response.access_token
672-
673-
674562
async def prepare_token_exchange_request(
675563
self,
676564
client: AsyncClient,
@@ -699,13 +587,11 @@ async def prepare_token_exchange_request(
699587
# Read the token from the filesystem
700588
eks_token = self._read_token()
701589

702-
application_credential = await self.get_application_credential(client, eks_token)
703-
704590
return TokenExchangeRequest(
705591
subject_token=subject_token,
706592
resource=resource,
707593
subject_token_type="urn:ietf:params:oauth:token-type:access_token",
708594
client_assertion_type=GrantType.JWT_BEARER_CLIENT_ASSERTION,
709-
client_assertion=application_credential,
595+
client_assertion=eks_token,
710596
)
711597

0 commit comments

Comments
 (0)