-
Notifications
You must be signed in to change notification settings - Fork 0
Token verification with previous public token #83
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
fefc931
e00a6f7
d23819c
a0295d4
788a90a
692b309
ae60d1e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,7 +1,7 @@ | ||
| { | ||
| "access_config": "s3://<redacted>/access.json", | ||
| "token_provider_url": "https://<redacted>", | ||
| "token_public_key_url": "https://<redacted>", | ||
| "token_public_keys_url": "https://<redacted>", | ||
| "kafka_bootstrap_server": "localhost:9092", | ||
| "event_bus_arn": "arn:aws:events:<redacted>" | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,15 @@ | ||
| # | ||
| # Copyright 2025 ABSA Group Limited | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| # |
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,167 @@ | ||||||||||||||||||||||||||||||
| # | ||||||||||||||||||||||||||||||
| # Copyright 2025 ABSA Group Limited | ||||||||||||||||||||||||||||||
| # | ||||||||||||||||||||||||||||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||||||||||||||||||||||||||||
| # you may not use this file except in compliance with the License. | ||||||||||||||||||||||||||||||
| # You may obtain a copy of the License at | ||||||||||||||||||||||||||||||
| # | ||||||||||||||||||||||||||||||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||||||||||||||||||||||||||||||
| # | ||||||||||||||||||||||||||||||
| # Unless required by applicable law or agreed to in writing, software | ||||||||||||||||||||||||||||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||||||||||||||||||||||||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||||||||||||||||||||||||||
| # See the License for the specific language governing permissions and | ||||||||||||||||||||||||||||||
| # limitations under the License. | ||||||||||||||||||||||||||||||
| # | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||
| This module provides the HandlerToken class for managing the token related operations. | ||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| import base64 | ||||||||||||||||||||||||||||||
| import logging | ||||||||||||||||||||||||||||||
| import os | ||||||||||||||||||||||||||||||
| from datetime import datetime, timedelta, timezone | ||||||||||||||||||||||||||||||
| from typing import Dict, Any, cast | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| import jwt | ||||||||||||||||||||||||||||||
| import requests | ||||||||||||||||||||||||||||||
| from cryptography.exceptions import UnsupportedAlgorithm | ||||||||||||||||||||||||||||||
| from cryptography.hazmat.primitives import serialization | ||||||||||||||||||||||||||||||
| from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| from src.utils.constants import TOKEN_PROVIDER_URL_KEY, TOKEN_PUBLIC_KEYS_URL_KEY, TOKEN_PUBLIC_KEY_URL_KEY | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| logger = logging.getLogger(__name__) | ||||||||||||||||||||||||||||||
| log_level = os.environ.get("LOG_LEVEL", "INFO") | ||||||||||||||||||||||||||||||
| logger.setLevel(log_level) | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| class HandlerToken: | ||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||
| HandlerToken manages token provider URL and public keys for JWT verification. | ||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| _REFRESH_INTERVAL = timedelta(minutes=28) | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| def __init__(self, config): | ||||||||||||||||||||||||||||||
| self.provider_url: str = config.get(TOKEN_PROVIDER_URL_KEY, "") | ||||||||||||||||||||||||||||||
| self.public_keys_url: str = config.get(TOKEN_PUBLIC_KEYS_URL_KEY) or config.get(TOKEN_PUBLIC_KEY_URL_KEY) | ||||||||||||||||||||||||||||||
| self.public_keys: list[RSAPublicKey] = [] | ||||||||||||||||||||||||||||||
| self._last_loaded_at: datetime | None = None | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| def _refresh_keys_if_needed(self) -> None: | ||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||
| Refresh the public keys if the refresh interval has passed. | ||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||
| logger.debug("Checking if the token public keys need refresh") | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| if self._last_loaded_at is None: | ||||||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||||||
| now = datetime.now(timezone.utc) | ||||||||||||||||||||||||||||||
| if now - self._last_loaded_at < self._REFRESH_INTERVAL: | ||||||||||||||||||||||||||||||
| logger.debug("Token public keys are up to date, no refresh needed") | ||||||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||
| logger.debug("Token public keys are stale, refreshing now") | ||||||||||||||||||||||||||||||
| self.load_public_keys() | ||||||||||||||||||||||||||||||
| except RuntimeError: | ||||||||||||||||||||||||||||||
| logger.warning("Token public key refresh failed, using existing keys") | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| def load_public_keys(self) -> "HandlerToken": | ||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||
| Load token public keys from the configured URL. | ||||||||||||||||||||||||||||||
| Returns: | ||||||||||||||||||||||||||||||
| HandlerToken: The current instance with loaded public keys. | ||||||||||||||||||||||||||||||
| Raises: | ||||||||||||||||||||||||||||||
| RuntimeError: If fetching or deserializing the public keys fails. | ||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||
| logger.debug("Loading token public keys from %s", self.public_keys_url) | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||
| response_json = requests.get(self.public_keys_url, verify=False, timeout=5).json() | ||||||||||||||||||||||||||||||
| raw_keys: list[str] = [] | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| if isinstance(response_json, dict): | ||||||||||||||||||||||||||||||
| if "keys" in response_json and isinstance(response_json["keys"], list): | ||||||||||||||||||||||||||||||
| for item in response_json["keys"]: | ||||||||||||||||||||||||||||||
| if "key" in item: | ||||||||||||||||||||||||||||||
| raw_keys.append(item["key"].strip()) | ||||||||||||||||||||||||||||||
| elif "key" in response_json: | ||||||||||||||||||||||||||||||
| raw_keys.append(response_json["key"].strip()) | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| if not raw_keys: | ||||||||||||||||||||||||||||||
| raise KeyError(f"No public keys found in {self.public_keys_url} endpoint response") | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| self.public_keys = [ | ||||||||||||||||||||||||||||||
| cast(RSAPublicKey, serialization.load_der_public_key(base64.b64decode(raw_key))) for raw_key in raw_keys | ||||||||||||||||||||||||||||||
| ] | ||||||||||||||||||||||||||||||
| logger.debug("Loaded %d token public keys", len(self.public_keys)) | ||||||||||||||||||||||||||||||
| self._last_loaded_at = datetime.now(timezone.utc) | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| return self | ||||||||||||||||||||||||||||||
| except (requests.RequestException, ValueError, KeyError, UnsupportedAlgorithm) as exc: | ||||||||||||||||||||||||||||||
| logger.exception("Failed to fetch or deserialize token public key from %s", self.public_keys_url) | ||||||||||||||||||||||||||||||
| raise RuntimeError("Token public key initialization failed") from exc | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| def decode_jwt(self, token_encoded: str) -> Dict[str, Any]: | ||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||
| Decode and verify a JWT using the loaded public keys. | ||||||||||||||||||||||||||||||
| Args: | ||||||||||||||||||||||||||||||
| token_encoded (str): The encoded JWT token. | ||||||||||||||||||||||||||||||
| Returns: | ||||||||||||||||||||||||||||||
| Dict[str, Any]: The decoded JWT payload. | ||||||||||||||||||||||||||||||
| Raises: | ||||||||||||||||||||||||||||||
| jwt.PyJWTError: If verification fails for all public keys. | ||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||
| self._refresh_keys_if_needed() | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| logger.debug("Decoding JWT") | ||||||||||||||||||||||||||||||
| for public_key in self.public_keys: | ||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||
| return jwt.decode(token_encoded, public_key, algorithms=["RS256"]) | ||||||||||||||||||||||||||||||
| except jwt.PyJWTError: | ||||||||||||||||||||||||||||||
| continue | ||||||||||||||||||||||||||||||
| raise jwt.PyJWTError("Verification failed for all public keys") | ||||||||||||||||||||||||||||||
|
Comment on lines
+120
to
+125
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Overly broad exception handling catches non-signature errors. Catching all for public_key in self.public_keys:
try:
return jwt.decode(token_encoded, public_key, algorithms=["RS256"])
- except jwt.PyJWTError:
+ except jwt.InvalidSignatureError:
continue
+ except jwt.PyJWTError:
+ raise # Expired, malformed, etc. - fail immediately
raise jwt.PyJWTError("Verification failed for all public keys")This ensures expired or malformed tokens fail fast with the correct error, while signature mismatches correctly try the next key in the rotation. 📝 Committable suggestion
Suggested change
🧰 Tools🪛 Ruff (0.14.6)125-125: Avoid specifying long messages outside the exception class (TRY003) 🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| def get_token_provider_info(self) -> Dict[str, Any]: | ||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||
| Returns: A 303 redirect response to the token provider URL. | ||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||
| logger.debug("Handling GET Token") | ||||||||||||||||||||||||||||||
| return {"statusCode": 303, "headers": {"Location": self.provider_url}} | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| @staticmethod | ||||||||||||||||||||||||||||||
| def extract_token(event_headers: Dict[str, str]) -> str: | ||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||
| Extracts the bearer (custom/standard) token from event headers. | ||||||||||||||||||||||||||||||
| Args: | ||||||||||||||||||||||||||||||
| event_headers (Dict[str, str]): The event headers. | ||||||||||||||||||||||||||||||
| Returns: | ||||||||||||||||||||||||||||||
| str: The extracted bearer token, or an empty string if not found. | ||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||
| if not event_headers: | ||||||||||||||||||||||||||||||
| return "" | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| # Normalize keys to lowercase for case-insensitive lookup | ||||||||||||||||||||||||||||||
| lowered = {str(k).lower(): v for k, v in event_headers.items()} | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| # Direct bearer header (raw token) | ||||||||||||||||||||||||||||||
| if "bearer" in lowered and isinstance(lowered["bearer"], str): | ||||||||||||||||||||||||||||||
| token_candidate = lowered["bearer"].strip() | ||||||||||||||||||||||||||||||
| if token_candidate: | ||||||||||||||||||||||||||||||
| return token_candidate | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| # Authorization header with Bearer scheme | ||||||||||||||||||||||||||||||
| auth_val = lowered.get("authorization", "") | ||||||||||||||||||||||||||||||
| if not isinstance(auth_val, str): # defensive | ||||||||||||||||||||||||||||||
| return "" | ||||||||||||||||||||||||||||||
| auth_val = auth_val.strip() | ||||||||||||||||||||||||||||||
| if not auth_val: | ||||||||||||||||||||||||||||||
| return "" | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| # Case-insensitive match for 'Bearer ' prefix | ||||||||||||||||||||||||||||||
| if not auth_val.lower().startswith("bearer "): | ||||||||||||||||||||||||||||||
| return "" | ||||||||||||||||||||||||||||||
| token_part = auth_val[7:].strip() # len('Bearer ')==7 | ||||||||||||||||||||||||||||||
| return token_part | ||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,24 @@ | ||
| # | ||
| # Copyright 2025 ABSA Group Limited | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| # | ||
|
|
||
| """ | ||
| This module contains all constants and enums used across the project. | ||
| """ | ||
|
|
||
| # Token related configuration keys | ||
| TOKEN_PROVIDER_URL_KEY = "token_provider_url" | ||
| TOKEN_PUBLIC_KEY_URL_KEY = "token_public_key_url" | ||
| TOKEN_PUBLIC_KEYS_URL_KEY = "token_public_keys_url" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Security concern: SSL certificate verification is disabled.
Using
verify=Falsedisables SSL certificate verification, making this request vulnerable to man-in-the-middle attacks. This is particularly concerning when fetching cryptographic public keys, as an attacker could inject malicious keys.Consider making SSL verification configurable or defaulting to
verify=True:If there's a legitimate need to disable verification in certain environments (e.g., development with self-signed certs), consider making it configurable via an environment variable or config parameter rather than unconditionally disabling it.
📝 Committable suggestion
🧰 Tools
🪛 Ruff (0.14.5)
60-60: Probable use of
requestscall withverify=Falsedisabling SSL certificate checks(S501)
🤖 Prompt for AI Agents