Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ Example (sanitized):
{
"access_config": "s3://<bucket>/access.json",
"token_provider_url": "https://<token-ui.example>",
"token_public_key_url": "https://<token-api.example>/public-key",
"token_public_keys_url": "https://<token-api.example>/token/public-keys",
"kafka_bootstrap_server": "broker1:9092,broker2:9092",
"event_bus_arn": "arn:aws:events:region:acct:event-bus/your-bus"
}
Expand Down Expand Up @@ -137,7 +137,7 @@ Use when Kafka access needs Kerberos / SASL_SSL or custom `librdkafka` build.
| Code coverage | [Code Coverage](./DEVELOPER.md#code-coverage) |

## Security & Authorization
- JWT tokens must be RS256 signed; the public key is fetched at cold start from `token_public_key_url` (DER base64 inside JSON `{ "key": "..." }`).
- JWT tokens must be RS256 signed; current and previous public keys are fetched at cold start from `token_public_keys_url` as DER base64 values (list `keys[*].key`, with single-key fallback `{ "key": "..." }`).
- Subject claim (`sub`) is matched against `ACCESS[topicName]`.
- Authorization header forms accepted:
- `Authorization: Bearer <token>` (preferred)
Expand Down
2 changes: 1 addition & 1 deletion conf/config.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"access_config": "s3://<redacted>/access.json",
"token_provider_url": "https://<redacted>",
"token_public_key_url": "https://<redacted>",
"token_public_keys_url": "https://<redacted>",
"kafka_bootstrap_server": "localhost:9092",
"event_bus_arn": "arn:aws:events:<redacted>"
}
79 changes: 14 additions & 65 deletions src/event_gate_lambda.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
#

"""Event Gate Lambda function implementation."""
import base64
import json
import logging
import os
Expand All @@ -24,13 +23,11 @@

import boto3
import jwt
import requests
import urllib3
from cryptography.exceptions import UnsupportedAlgorithm
from cryptography.hazmat.primitives import serialization
from jsonschema import validate
from jsonschema.exceptions import ValidationError

from src.handlers.handler_token import HandlerToken
from src.writers import writer_eventbridge, writer_kafka, writer_postgres
from src.utils.conf_path import CONF_DIR, INVALID_CONF_ENV

Expand Down Expand Up @@ -64,35 +61,28 @@
logger.debug("Loaded TOPICS")

with open(os.path.join(_CONF_DIR, "config.json"), "r", encoding="utf-8") as file:
CONFIG = json.load(file)
config = json.load(file)
logger.debug("Loaded main CONFIG")

aws_s3 = boto3.Session().resource("s3", verify=False) # nosec Boto verify disabled intentionally
logger.debug("Initialized AWS S3 Client")

if CONFIG["access_config"].startswith("s3://"):
name_parts = CONFIG["access_config"].split("/")
if config["access_config"].startswith("s3://"):
name_parts = config["access_config"].split("/")
BUCKET_NAME = name_parts[2]
BUCKET_OBJECT_KEY = "/".join(name_parts[3:])
ACCESS = json.loads(aws_s3.Bucket(BUCKET_NAME).Object(BUCKET_OBJECT_KEY).get()["Body"].read().decode("utf-8"))
else:
with open(CONFIG["access_config"], "r", encoding="utf-8") as file:
with open(config["access_config"], "r", encoding="utf-8") as file:
ACCESS = json.load(file)
logger.debug("Loaded ACCESS definitions")

TOKEN_PROVIDER_URL = CONFIG["token_provider_url"]
# Add timeout to avoid hanging requests; wrap in robust error handling so failures are explicit
try:
response_json = requests.get(CONFIG["token_public_key_url"], verify=False, timeout=5).json() # nosec external
token_public_key_encoded = response_json["key"]
TOKEN_PUBLIC_KEY: Any = serialization.load_der_public_key(base64.b64decode(token_public_key_encoded))
logger.debug("Loaded TOKEN_PUBLIC_KEY")
except (requests.RequestException, ValueError, KeyError, UnsupportedAlgorithm) as exc:
logger.exception("Failed to fetch or deserialize token public key from %s", CONFIG.get("token_public_key_url"))
raise RuntimeError("Token public key initialization failed") from exc

writer_eventbridge.init(logger, CONFIG)
writer_kafka.init(logger, CONFIG)
# Initialize token handler and load token public keys
handler_token = HandlerToken(config).load_public_keys()

# Initialize EventGate writers
writer_eventbridge.init(logger, config)
writer_kafka.init(logger, config)
writer_postgres.init(logger)


Expand Down Expand Up @@ -124,12 +114,6 @@ def get_api() -> Dict[str, Any]:
return {"statusCode": 200, "body": API}


def get_token() -> Dict[str, Any]:
"""Return 303 redirect to token provider endpoint."""
logger.debug("Handling GET Token")
return {"statusCode": 303, "headers": {"Location": TOKEN_PROVIDER_URL}}


def get_topics() -> Dict[str, Any]:
"""Return list of available topic names."""
logger.debug("Handling GET Topics")
Expand Down Expand Up @@ -163,7 +147,7 @@ def post_topic_message(topic_name: str, topic_message: Dict[str, Any], token_enc
"""
logger.debug("Handling POST %s", topic_name)
try:
token = jwt.decode(token_encoded, TOKEN_PUBLIC_KEY, algorithms=["RS256"]) # type: ignore[arg-type]
token: Dict[str, Any] = handler_token.decode_jwt(token_encoded)
except jwt.PyJWTError: # type: ignore[attr-defined]
return _error_response(401, "auth", "Invalid or missing token")

Expand Down Expand Up @@ -205,41 +189,6 @@ def post_topic_message(topic_name: str, topic_message: Dict[str, Any], token_enc
}


def extract_token(event_headers: Dict[str, str]) -> str:
"""Extract bearer token from headers (case-insensitive).

Supports:
- Custom 'bearer' header (any casing) whose value is the raw token
- Standard 'Authorization: Bearer <token>' header (case-insensitive scheme & key)
Returns empty string if token not found or malformed.
"""
if not event_headers:
return ""

# Normalize keys to lowercase for case-insensitive lookup
lowered = {str(k).lower(): v for k, v in event_headers.items()}

# Direct bearer header (raw token)
if "bearer" in lowered and isinstance(lowered["bearer"], str):
token_candidate = lowered["bearer"].strip()
if token_candidate:
return token_candidate

# Authorization header with Bearer scheme
auth_val = lowered.get("authorization", "")
if not isinstance(auth_val, str): # defensive
return ""
auth_val = auth_val.strip()
if not auth_val:
return ""

# Case-insensitive match for 'Bearer ' prefix
if not auth_val.lower().startswith("bearer "):
return ""
token_part = auth_val[7:].strip() # len('Bearer ')==7
return token_part


def lambda_handler(event: Dict[str, Any], context: Any): # pylint: disable=unused-argument,too-many-return-statements
"""AWS Lambda entry point.

Expand All @@ -250,7 +199,7 @@ def lambda_handler(event: Dict[str, Any], context: Any): # pylint: disable=unus
if resource == "/api":
return get_api()
if resource == "/token":
return get_token()
return handler_token.get_token_provider_info()
if resource == "/topics":
return get_topics()
if resource == "/topics/{topic_name}":
Expand All @@ -261,7 +210,7 @@ def lambda_handler(event: Dict[str, Any], context: Any): # pylint: disable=unus
return post_topic_message(
event["pathParameters"]["topic_name"].lower(),
json.loads(event["body"]),
extract_token(event.get("headers", {})),
handler_token.extract_token(event.get("headers", {})),
)
if resource == "/terminate":
sys.exit("TERMINATING") # pragma: no cover - deliberate termination path
Expand Down
15 changes: 15 additions & 0 deletions src/handlers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#
# Copyright 2025 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
167 changes: 167 additions & 0 deletions src/handlers/handler_token.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
#
# Copyright 2025 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
This module provides the HandlerToken class for managing the token related operations.
"""

import base64
import logging
import os
from datetime import datetime, timedelta, timezone
from typing import Dict, Any, cast

import jwt
import requests
from cryptography.exceptions import UnsupportedAlgorithm
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey

from src.utils.constants import TOKEN_PROVIDER_URL_KEY, TOKEN_PUBLIC_KEYS_URL_KEY, TOKEN_PUBLIC_KEY_URL_KEY

logger = logging.getLogger(__name__)
log_level = os.environ.get("LOG_LEVEL", "INFO")
logger.setLevel(log_level)


class HandlerToken:
"""
HandlerToken manages token provider URL and public keys for JWT verification.
"""

_REFRESH_INTERVAL = timedelta(minutes=28)

def __init__(self, config):
self.provider_url: str = config.get(TOKEN_PROVIDER_URL_KEY, "")
self.public_keys_url: str = config.get(TOKEN_PUBLIC_KEYS_URL_KEY) or config.get(TOKEN_PUBLIC_KEY_URL_KEY)
self.public_keys: list[RSAPublicKey] = []
self._last_loaded_at: datetime | None = None

def _refresh_keys_if_needed(self) -> None:
"""
Refresh the public keys if the refresh interval has passed.
"""
logger.debug("Checking if the token public keys need refresh")

if self._last_loaded_at is None:
return
now = datetime.now(timezone.utc)
if now - self._last_loaded_at < self._REFRESH_INTERVAL:
logger.debug("Token public keys are up to date, no refresh needed")
return
try:
logger.debug("Token public keys are stale, refreshing now")
self.load_public_keys()
except RuntimeError:
logger.warning("Token public key refresh failed, using existing keys")

def load_public_keys(self) -> "HandlerToken":
"""
Load token public keys from the configured URL.
Returns:
HandlerToken: The current instance with loaded public keys.
Raises:
RuntimeError: If fetching or deserializing the public keys fails.
"""
logger.debug("Loading token public keys from %s", self.public_keys_url)

try:
response_json = requests.get(self.public_keys_url, verify=False, timeout=5).json()
Comment on lines +81 to +82
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Security concern: SSL certificate verification is disabled.

Using verify=False disables SSL certificate verification, making this request vulnerable to man-in-the-middle attacks. This is particularly concerning when fetching cryptographic public keys, as an attacker could inject malicious keys.

Consider making SSL verification configurable or defaulting to verify=True:

-            response_json = requests.get(self.public_keys_url, verify=False, timeout=5).json()
+            response_json = requests.get(self.public_keys_url, timeout=5).json()

If there's a legitimate need to disable verification in certain environments (e.g., development with self-signed certs), consider making it configurable via an environment variable or config parameter rather than unconditionally disabling it.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
try:
response_json = requests.get(self.public_keys_url, verify=False, timeout=5).json()
try:
response_json = requests.get(self.public_keys_url, timeout=5).json()
🧰 Tools
🪛 Ruff (0.14.5)

60-60: Probable use of requests call with verify=False disabling SSL certificate checks

(S501)

🤖 Prompt for AI Agents
In src/handlers/handler_token.py around lines 59-60, the requests.get call
disables SSL certificate verification (verify=False) which is insecure; change
it to use a configurable verification flag (e.g., read from an environment
variable or config with default True) and pass that flag to requests.get
(verify=<config_flag>) instead of False, ensure the default is True, add a log
warning if verification is explicitly disabled, and do not leave verify=False
hard-coded in the repository.

raw_keys: list[str] = []

if isinstance(response_json, dict):
if "keys" in response_json and isinstance(response_json["keys"], list):
for item in response_json["keys"]:
if "key" in item:
raw_keys.append(item["key"].strip())
elif "key" in response_json:
raw_keys.append(response_json["key"].strip())

if not raw_keys:
raise KeyError(f"No public keys found in {self.public_keys_url} endpoint response")

self.public_keys = [
cast(RSAPublicKey, serialization.load_der_public_key(base64.b64decode(raw_key))) for raw_key in raw_keys
]
logger.debug("Loaded %d token public keys", len(self.public_keys))
self._last_loaded_at = datetime.now(timezone.utc)

return self
except (requests.RequestException, ValueError, KeyError, UnsupportedAlgorithm) as exc:
logger.exception("Failed to fetch or deserialize token public key from %s", self.public_keys_url)
raise RuntimeError("Token public key initialization failed") from exc

def decode_jwt(self, token_encoded: str) -> Dict[str, Any]:
"""
Decode and verify a JWT using the loaded public keys.
Args:
token_encoded (str): The encoded JWT token.
Returns:
Dict[str, Any]: The decoded JWT payload.
Raises:
jwt.PyJWTError: If verification fails for all public keys.
"""
self._refresh_keys_if_needed()

logger.debug("Decoding JWT")
for public_key in self.public_keys:
try:
return jwt.decode(token_encoded, public_key, algorithms=["RS256"])
except jwt.PyJWTError:
continue
raise jwt.PyJWTError("Verification failed for all public keys")
Comment on lines +120 to +125
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Overly broad exception handling catches non-signature errors.

Catching all PyJWTError means expired tokens (ExpiredSignatureError) and malformed tokens (DecodeError) will unnecessarily try every key before failing. Only signature-related errors should trigger trying the next key.

     for public_key in self.public_keys:
         try:
             return jwt.decode(token_encoded, public_key, algorithms=["RS256"])
-        except jwt.PyJWTError:
+        except jwt.InvalidSignatureError:
             continue
+        except jwt.PyJWTError:
+            raise  # Expired, malformed, etc. - fail immediately
     raise jwt.PyJWTError("Verification failed for all public keys")

This ensures expired or malformed tokens fail fast with the correct error, while signature mismatches correctly try the next key in the rotation.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
for public_key in self.public_keys:
try:
return jwt.decode(token_encoded, public_key, algorithms=["RS256"])
except jwt.PyJWTError:
continue
raise jwt.PyJWTError("Verification failed for all public keys")
for public_key in self.public_keys:
try:
return jwt.decode(token_encoded, public_key, algorithms=["RS256"])
except jwt.InvalidSignatureError:
continue
except jwt.PyJWTError:
raise # Expired, malformed, etc. - fail immediately
raise jwt.PyJWTError("Verification failed for all public keys")
🧰 Tools
🪛 Ruff (0.14.6)

125-125: Avoid specifying long messages outside the exception class

(TRY003)

🤖 Prompt for AI Agents
In src/handlers/handler_token.py around lines 120-125, the loop currently
catches all jwt.PyJWTError which masks expired or malformed token errors; change
the logic to only catch signature-related exceptions (e.g.,
jwt.InvalidSignatureError or jwt.exceptions.InvalidSignatureError) inside the
loop so only signature mismatches try the next public key, and let other
PyJWTError subclasses (ExpiredSignatureError, DecodeError, etc.) propagate
immediately; implement this by catching InvalidSignatureError in the except
branch (or store the last signature error and continue) and re-raising any other
PyJWTError after the loop (or re-raise immediately when caught).


def get_token_provider_info(self) -> Dict[str, Any]:
"""
Returns: A 303 redirect response to the token provider URL.
"""
logger.debug("Handling GET Token")
return {"statusCode": 303, "headers": {"Location": self.provider_url}}

@staticmethod
def extract_token(event_headers: Dict[str, str]) -> str:
"""
Extracts the bearer (custom/standard) token from event headers.
Args:
event_headers (Dict[str, str]): The event headers.
Returns:
str: The extracted bearer token, or an empty string if not found.
"""
if not event_headers:
return ""

# Normalize keys to lowercase for case-insensitive lookup
lowered = {str(k).lower(): v for k, v in event_headers.items()}

# Direct bearer header (raw token)
if "bearer" in lowered and isinstance(lowered["bearer"], str):
token_candidate = lowered["bearer"].strip()
if token_candidate:
return token_candidate

# Authorization header with Bearer scheme
auth_val = lowered.get("authorization", "")
if not isinstance(auth_val, str): # defensive
return ""
auth_val = auth_val.strip()
if not auth_val:
return ""

# Case-insensitive match for 'Bearer ' prefix
if not auth_val.lower().startswith("bearer "):
return ""
token_part = auth_val[7:].strip() # len('Bearer ')==7
return token_part
24 changes: 24 additions & 0 deletions src/utils/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#
# Copyright 2025 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
This module contains all constants and enums used across the project.
"""

# Token related configuration keys
TOKEN_PROVIDER_URL_KEY = "token_provider_url"
TOKEN_PUBLIC_KEY_URL_KEY = "token_public_key_url"
TOKEN_PUBLIC_KEYS_URL_KEY = "token_public_keys_url"
Loading