Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,11 @@
patched_frontend_url = "localhost" if cookie_domain is None else f"https://{cookie_domain}"

with _temporary_frontend_url(patched_frontend_url):
return original(response)
result = original(response)

result.headers["Cache-Control"] = "no-cache, no-store, must-revalidate"
result.headers["Pragma"] = "no-cache"
return result

authentication_controller._set_new_access_token_in_cookie = patched_set_new_access_token_in_cookie
_COOKIE_DOMAIN_PATCHED = True
Expand Down Expand Up @@ -192,7 +196,7 @@
g.m8flow_tenant_id = previous


def apply_refresh_token_tenant_patch() -> None:

Check failure on line 199 in extensions/m8flow-backend/src/m8flow_backend/routes/authentication_controller_patch.py

View check run for this annotation

SonarQubeCloud / SonarCloud Code Analysis

Refactor this function to reduce its Cognitive Complexity from 23 to the 15 allowed.

See more on https://sonarcloud.io/project/issues?id=AOT-Technologies_m8flow&issues=AZ0E6SJA1jy-UnSmeFH2&open=AZ0E6SJA1jy-UnSmeFH2&pullRequest=87
"""
Ensure refresh-token operations have tenant context during auth controller
flows that run before tenant-resolution hooks.
Expand All @@ -206,9 +210,29 @@

@wraps(original_login_return)
def patched_login_return(*args, **kwargs):
from flask import redirect
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Move inline imports to module level

from spiffworkflow_backend.services.authentication_service import AuthenticationService

state = kwargs.get("state")
if state is None and args:
state = args[0]

error = kwargs.get("error")
error_description = kwargs.get("error_description")
if error and error_description and "authentication_expired" in str(error_description):
try:
state_dict = ast.literal_eval(base64.b64decode(state).decode("utf-8"))
auth_id = state_dict.get("authentication_identifier")
final_url = state_dict.get("final_url", "/")
if auth_id:
login_url = AuthenticationService().get_login_redirect_url(
authentication_identifier=auth_id, final_url=final_url
)
logger.info("authentication_expired detected, retrying login for identifier=%s", auth_id)
return redirect(login_url)
except Exception:
logger.warning("Failed to auto-retry login after authentication_expired", exc_info=True)

tenant_id = _tenant_for_refresh_tokens(state=state if isinstance(state, str) else None)
auth_identifier = _authentication_identifier_from_state() or (
_decode_state_authentication_identifier(state) if isinstance(state, str) else None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
import ast
import base64
from contextlib import contextmanager
from functools import wraps
import json
import logging
import time
from urllib.parse import urlparse, urlunparse

import requests
Expand All @@ -18,17 +21,25 @@
AuthenticationService,
)

_logger = logging.getLogger(__name__)

_ON_DEMAND_PATCHED = False
_ORIGINAL_AUTH_OPTION_FOR_IDENTIFIER = None
_ORIGINAL_GET_AUTH_TOKEN_OBJECT = None
_TOKEN_ERROR_PATCHED = False
_OPENID_PATCHED = False
_REFRESH_TOKEN_TENANT_PATCHED = False
_PROMPT_LOGIN_PATCHED = False
_JWKS_TTL_PATCHED = False
_ORIGINAL_STORE_REFRESH_TOKEN = None
_ORIGINAL_GET_REFRESH_TOKEN = None
_MISSING = object()
MASTER_REALM_IDENTIFIER = "master"

CACHE_TTL_SECONDS = 300
_ENDPOINT_CACHE_TIMESTAMPS: dict[str, float] = {}
_JWKS_CACHE_TIMESTAMPS: dict[str, float] = {}


def _call_original_auth_option_for_identifier(cls, authentication_identifier: str):
if _ORIGINAL_AUTH_OPTION_FOR_IDENTIFIER is None:
Expand Down Expand Up @@ -155,12 +166,17 @@ def apply_auth_token_error_patch() -> None:
def _patched_open_id_endpoint_for_name(
cls, name: str, authentication_identifier: str, internal: bool = False
) -> str:
"""Same as original but raises OpenIdConnectionError when discovery returns non-200."""
"""Same as original but raises OpenIdConnectionError when discovery returns non-200, with TTL-based cache eviction."""
if authentication_identifier not in cls.ENDPOINT_CACHE:
cls.ENDPOINT_CACHE[authentication_identifier] = {}
if authentication_identifier not in cls.JSON_WEB_KEYSET_CACHE:
cls.JSON_WEB_KEYSET_CACHE[authentication_identifier] = {}

cached_ts = _ENDPOINT_CACHE_TIMESTAMPS.get(authentication_identifier, 0)
cache_expired = (time.monotonic() - cached_ts) > CACHE_TTL_SECONDS
if cache_expired and cls.ENDPOINT_CACHE[authentication_identifier]:
cls.ENDPOINT_CACHE[authentication_identifier] = {}

internal_server_url = cls.server_url(authentication_identifier, internal=True)
openid_config_url = f"{internal_server_url}/.well-known/openid-configuration"
if name not in cls.ENDPOINT_CACHE[authentication_identifier]:
Expand All @@ -173,6 +189,7 @@ def _patched_open_id_endpoint_for_name(
f"Body: {(response.text or '')[:200]}"
)
cls.ENDPOINT_CACHE[authentication_identifier] = response.json()
_ENDPOINT_CACHE_TIMESTAMPS[authentication_identifier] = time.monotonic()
except requests.exceptions.ConnectionError as ce:
raise OpenIdConnectionError(f"Cannot connect to given open id url: {openid_config_url}") from ce
if name not in cls.ENDPOINT_CACHE[authentication_identifier]:
Expand Down Expand Up @@ -468,3 +485,47 @@ def apply_refresh_token_tenant_patch() -> None:
AuthenticationService.store_refresh_token = staticmethod(_patched_store_refresh_token)
AuthenticationService.get_refresh_token = staticmethod(_patched_get_refresh_token)
_REFRESH_TOKEN_TENANT_PATCHED = True


def apply_prompt_login_patch() -> None:
"""Append prompt=login to authorization URLs so Keycloak always shows a fresh login form."""
global _PROMPT_LOGIN_PATCHED
if _PROMPT_LOGIN_PATCHED:
return

original = AuthenticationService.get_login_redirect_url

@wraps(original)
def _patched_get_login_redirect_url(self, authentication_identifier: str, final_url: str | None = None) -> str:
url = original(self, authentication_identifier, final_url)
if "prompt=" not in url:
url += "&prompt=login"
return url

AuthenticationService.get_login_redirect_url = _patched_get_login_redirect_url
_PROMPT_LOGIN_PATCHED = True
_logger.info("prompt_login_patch: authorization URLs will include prompt=login")


def apply_jwks_cache_ttl_patch() -> None:
"""Add TTL-based eviction to AuthenticationService.get_jwks_config_from_uri."""
global _JWKS_TTL_PATCHED
if _JWKS_TTL_PATCHED:
return

original = AuthenticationService.get_jwks_config_from_uri

@classmethod # type: ignore[misc]
def _patched_get_jwks_config_from_uri(cls, jwks_uri: str, force_refresh: bool = False):
cached_ts = _JWKS_CACHE_TIMESTAMPS.get(jwks_uri, 0)
if (time.monotonic() - cached_ts) > CACHE_TTL_SECONDS:
force_refresh = True

result = original.__func__(cls, jwks_uri, force_refresh=force_refresh)
if force_refresh or jwks_uri not in _JWKS_CACHE_TIMESTAMPS:
_JWKS_CACHE_TIMESTAMPS[jwks_uri] = time.monotonic()
return result

AuthenticationService.get_jwks_config_from_uri = _patched_get_jwks_config_from_uri
_JWKS_TTL_PATCHED = True
_logger.info("jwks_cache_ttl_patch: JWKS cache entries expire after %ds", CACHE_TTL_SECONDS)
8 changes: 8 additions & 0 deletions extensions/startup/patch_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,14 @@ def apply_patch_specs(
target="m8flow_backend.routes.authentication_controller_patch:apply_refresh_token_tenant_patch",
minimum_phase=BootPhase.APP_CREATED,
),
PatchSpec(
target="m8flow_backend.services.authentication_service_patch:apply_prompt_login_patch",
minimum_phase=BootPhase.APP_CREATED,
),
PatchSpec(
target="m8flow_backend.services.authentication_service_patch:apply_jwks_cache_ttl_patch",
minimum_phase=BootPhase.APP_CREATED,
),
)


Expand Down
Loading