Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ Example (sanitized):
{
"access_config": "s3://<bucket>/access.json",
"token_provider_url": "https://<token-ui.example>",
"token_public_key_url": "https://<token-api.example>/public-key",
"token_public_keys_url": "https://<token-api.example>/token/public-keys",
"kafka_bootstrap_server": "broker1:9092,broker2:9092",
"event_bus_arn": "arn:aws:events:region:acct:event-bus/your-bus"
}
Expand Down Expand Up @@ -137,7 +137,7 @@ Use when Kafka access needs Kerberos / SASL_SSL or custom `librdkafka` build.
| Code coverage | [Code Coverage](./DEVELOPER.md#code-coverage) |

## Security & Authorization
- JWT tokens must be RS256 signed; the public key is fetched at cold start from `token_public_key_url` (DER base64 inside JSON `{ "key": "..." }`).
- JWT tokens must be RS256 signed; current and previous public keys are fetched at cold start from `token_public_keys_url` as DER base64 values (list `keys[*].key`, with single-key fallback `{ "key": "..." }`).
- Subject claim (`sub`) is matched against `ACCESS[topicName]`.
- Authorization header forms accepted:
- `Authorization: Bearer <token>` (preferred)
Expand Down
2 changes: 1 addition & 1 deletion conf/config.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"access_config": "s3://<redacted>/access.json",
"token_provider_url": "https://<redacted>",
"token_public_key_url": "https://<redacted>",
"token_public_keys_url": "https://<redacted>",
"kafka_bootstrap_server": "localhost:9092",
"event_bus_arn": "arn:aws:events:<redacted>"
}
47 changes: 39 additions & 8 deletions src/event_gate_lambda.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import urllib3
from cryptography.exceptions import UnsupportedAlgorithm
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey
from jsonschema import validate
from jsonschema.exceptions import ValidationError

Expand Down Expand Up @@ -80,17 +81,33 @@
ACCESS = json.load(file)
logger.debug("Loaded ACCESS definitions")

TOKEN_PROVIDER_URL = CONFIG["token_provider_url"]
# Add timeout to avoid hanging requests; wrap in robust error handling so failures are explicit
# Initialize token public keys
TOKEN_PROVIDER_URL = CONFIG.get("token_provider_url")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

put the token handling to separate module pls

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was change has a bigger impact, please see it in commit: d23819c. The project is now better readable. Need a thorough revision please.

TOKEN_PUBLIC_KEYS_URL = CONFIG.get("token_public_keys_url") or CONFIG.get("token_public_key_url")

try:
response_json = requests.get(CONFIG["token_public_key_url"], verify=False, timeout=5).json() # nosec external
token_public_key_encoded = response_json["key"]
TOKEN_PUBLIC_KEY: Any = serialization.load_der_public_key(base64.b64decode(token_public_key_encoded))
logger.debug("Loaded TOKEN_PUBLIC_KEY")
response_json = requests.get(TOKEN_PUBLIC_KEYS_URL, verify=False, timeout=5).json()
raw_keys: list[str] = []
if isinstance(response_json, dict):
if "keys" in response_json and isinstance(response_json["keys"], list):
for item in response_json["keys"]:
if "key" in item:
raw_keys.append(item["key"].strip())
elif "key" in response_json:
raw_keys.append(response_json["key"].strip())

if not raw_keys:
raise KeyError(f"No public keys found in {TOKEN_PUBLIC_KEYS_URL} endpoint response")

TOKEN_PUBLIC_KEYS: list[RSAPublicKey] = [
serialization.load_der_public_key(base64.b64decode(raw_key)) for raw_key in raw_keys
]
logger.debug("Loaded %d TOKEN_PUBLIC_KEYS", len(TOKEN_PUBLIC_KEYS))
except (requests.RequestException, ValueError, KeyError, UnsupportedAlgorithm) as exc:
logger.exception("Failed to fetch or deserialize token public key from %s", CONFIG.get("token_public_key_url"))
logger.exception("Failed to fetch or deserialize token public key from %s", TOKEN_PUBLIC_KEYS_URL)
raise RuntimeError("Token public key initialization failed") from exc

# Initialize EventGate writers
writer_eventbridge.init(logger, CONFIG)
writer_kafka.init(logger, CONFIG)
writer_postgres.init(logger)
Expand Down Expand Up @@ -163,7 +180,7 @@ def post_topic_message(topic_name: str, topic_message: Dict[str, Any], token_enc
"""
logger.debug("Handling POST %s", topic_name)
try:
token = jwt.decode(token_encoded, TOKEN_PUBLIC_KEY, algorithms=["RS256"]) # type: ignore[arg-type]
token = decode_jwt_all(token_encoded)
except jwt.PyJWTError: # type: ignore[attr-defined]
return _error_response(401, "auth", "Invalid or missing token")

Expand Down Expand Up @@ -205,6 +222,20 @@ def post_topic_message(topic_name: str, topic_message: Dict[str, Any], token_enc
}


def decode_jwt_all(token_encoded: str) -> Dict[str, Any]:
"""Decode JWT using any of the loaded public keys.

Args:
token_encoded: Encoded bearer JWT token string.
"""
for public_key in TOKEN_PUBLIC_KEYS:
try:
return jwt.decode(token_encoded, public_key, algorithms=["RS256"])
except jwt.PyJWTError:
continue
raise jwt.PyJWTError("Verification failed for all public keys")


def extract_token(event_headers: Dict[str, str]) -> str:
"""Extract bearer token from headers (case-insensitive).

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@
from glob import glob
import pytest

CONF_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "..", "conf")
CONF_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "conf")

REQUIRED_CONFIG_KEYS = {
"access_config",
"token_provider_url",
"token_public_key_url",
"token_public_keys_url",
"kafka_bootstrap_server",
"event_bus_arn",
}
Expand Down
51 changes: 51 additions & 0 deletions tests/test_event_gate_lambda.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,3 +352,54 @@ def test_post_invalid_json_body(event_gate_module, make_event):
assert resp["statusCode"] == 500
body = json.loads(resp["body"])
assert any(e["type"] == "internal" for e in body["errors"]) # internal error path


def test_post_expired_token(event_gate_module, make_event, valid_payload):
"""Expired JWT should yield 401 auth error."""
with patch.object(
event_gate_module.jwt,
"decode",
side_effect=event_gate_module.jwt.ExpiredSignatureError("expired"),
create=True,
):
event = make_event(
"/topics/{topic_name}",
method="POST",
topic="public.cps.za.test",
body=valid_payload,
headers={"Authorization": "Bearer expiredtoken"},
)
resp = event_gate_module.lambda_handler(event, None)
assert resp["statusCode"] == 401
body = json.loads(resp["body"])
assert any(e["type"] == "auth" for e in body["errors"])


def test_decode_jwt_all_second_key_succeeds(event_gate_module):
"""First key fails signature, second key succeeds; claims returned from second key."""
first_key = object()
second_key = object()
event_gate_module.TOKEN_PUBLIC_KEYS = [first_key, second_key]

def decode_side_effect(token, key, algorithms):
if key is first_key:
raise event_gate_module.jwt.PyJWTError("signature mismatch")
return {"sub": "TestUser"}

with patch.object(event_gate_module.jwt, "decode", side_effect=decode_side_effect, create=True):
claims = event_gate_module.decode_jwt_all("dummy-token")
assert claims["sub"] == "TestUser"


def test_decode_jwt_all_all_keys_fail(event_gate_module):
"""All keys fail; final PyJWTError with aggregate message is raised."""
bad_keys = [object(), object()]
event_gate_module.TOKEN_PUBLIC_KEYS = bad_keys

def always_fail(token, key, algorithms):
raise event_gate_module.jwt.PyJWTError("bad signature")

with patch.object(event_gate_module.jwt, "decode", side_effect=always_fail, create=True):
with pytest.raises(event_gate_module.jwt.PyJWTError) as exc:
event_gate_module.decode_jwt_all("dummy-token")
assert "Verification failed for all public keys" in str(exc.value)
Loading