Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions api/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ All notable changes to the **Prowler API** are documented in this file.
### 🐞 Fixed

- Attack Paths: Recover `graph_data_ready` flag when scan fails during graph swap, preventing query endpoints from staying blocked until the next successful scan [(#10354)](https://github.com/prowler-cloud/prowler/pull/10354)
- Fix `rls_transaction` to retry mid-query read replica failures with primary DB fallback via `execute_wrapper`, preventing scan crashes during RDS replica recovery [(#10379)](https://github.com/prowler-cloud/prowler/pull/10379)

### 🔐 Security

Expand Down
158 changes: 137 additions & 21 deletions api/src/backend/api/db_utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import re
import secrets
import sys
import time
import uuid
from contextlib import contextmanager
from contextlib import contextmanager, nullcontext
from datetime import datetime, timedelta, timezone

from celery.utils.log import get_task_logger
Expand Down Expand Up @@ -78,34 +79,122 @@
retry_on_replica: bool = True,
):
"""
Creates a new database transaction setting the given configuration value for Postgres RLS. It validates
that the value is a valid UUID.
Context manager that opens an RLS-scoped database transaction.

Sets a Postgres configuration variable (``set_config``) so that
Row-Level Security policies can filter by tenant. When *using*
points to a read replica and *retry_on_replica* is True, two
layers of retry protect against replica failures:

1. **Pre-yield** (connection-setup failures): the function retries
up to ``REPLICA_MAX_ATTEMPTS`` times on the replica, then falls
back to the primary DB.
2. **Post-yield** (mid-query failures): an ``execute_wrapper``
intercepts ``OperationalError`` during ``cursor.execute()``
calls, retries on the replica with backoff, and falls back to
the primary if the replica stays down. The wrapper swaps the
inner psycopg2 cursor so ``fetchall()`` / ``fetchone()`` read
from the new connection transparently.

Limitation: server-side cursors (``.iterator()``) fetch rows via
``fetchmany()``, which the wrapper does not intercept. Call sites
that iterate large result sets with ``.iterator()`` on the replica
should add their own retry logic.

Args:
value (str): Database configuration parameter value.
parameter (str): Database configuration parameter name, by default is 'api.tenant_id'.
using (str | None): Optional database alias to run the transaction against. Defaults to the
active read alias (if any) or Django's default connection.
value: Database configuration parameter value (must be a valid UUID).
parameter: Database configuration parameter name.
using: Optional database alias. Defaults to the active read
alias or Django's default connection.
retry_on_replica: Whether to retry on replica failures.
"""
requested_alias = using or get_read_db_alias()
db_alias = requested_alias or DEFAULT_DB_ALIAS
if db_alias not in connections:
db_alias = DEFAULT_DB_ALIAS

alias = db_alias
is_replica = READ_REPLICA_ALIAS and alias == READ_REPLICA_ALIAS
max_attempts = REPLICA_MAX_ATTEMPTS if is_replica and retry_on_replica else 1
is_replica = bool(READ_REPLICA_ALIAS and alias == READ_REPLICA_ALIAS)
can_failover = is_replica and retry_on_replica
replica_alias = alias # captured before the loop mutates alias
max_attempts = (REPLICA_MAX_ATTEMPTS + 1) if can_failover else 1

for attempt in range(1, max_attempts + 1):
router_token = None
yielded_cursor = False
# State shared between the generator and the _query_failover closure
_fallback = {"succeeded": False, "atomic": None, "token": None}

def _query_failover(execute, sql, params, many, context):
"""execute_wrapper: retry failed replica queries, then fall back to primary."""
try:
return execute(sql, params, many, context)
except OperationalError as err:
# Phase 1 — retry on replica with exponential backoff
for retry in range(1, REPLICA_MAX_ATTEMPTS + 1):
try:
connections[replica_alias].close()
except Exception:
pass # Best-effort; connection may already be dead

delay = REPLICA_RETRY_BASE_DELAY * (2 ** (retry - 1))
logger.info(
f"Mid-query failure on replica (retry {retry}/{REPLICA_MAX_ATTEMPTS}), "
f"retrying in {delay:.1f}s. Error: {err}"
)
time.sleep(delay)

try:
replica_conn = connections[replica_alias]
replica_conn.ensure_connection()
replica_conn.connection.autocommit = False
raw = replica_conn.connection.cursor()
raw.execute(SET_CONFIG_QUERY, [parameter, value])
if many:
raw.executemany(sql, params)
else:
raw.execute(sql, params)
context["cursor"].cursor = raw
return None
except OperationalError as retry_err:
err = retry_err
continue

# Phase 2 — fall back to primary
try:
connections[replica_alias].close()
except Exception:
pass # Best-effort; connection may already be dead

# On final attempt, fallback to primary
if attempt == max_attempts and is_replica:
logger.warning(
f"RLS transaction failed after {attempt - 1} attempts on replica, "
f"falling back to primary DB"
"Mid-query replica retries exhausted, falling back to primary DB"
)
primary = connections[DEFAULT_DB_ALIAS]
primary.ensure_connection()
_fallback["atomic"] = transaction.atomic(using=DEFAULT_DB_ALIAS)
_fallback["atomic"].__enter__()
with primary.cursor() as setup_cursor:
setup_cursor.execute(SET_CONFIG_QUERY, [parameter, value])
_fallback["token"] = set_read_db_alias(DEFAULT_DB_ALIAS)

raw = primary.connection.cursor()
if many:
raw.executemany(sql, params)
else:
raw.execute(sql, params)
context["cursor"].cursor = raw
_fallback["succeeded"] = True
return None

for attempt in range(1, max_attempts + 1):
router_token = None
yielded_cursor = False
_caller_exited_cleanly = False

# On final attempt, fall back to primary
if attempt == max_attempts and can_failover:
if attempt > 1:
logger.warning(
f"RLS transaction failed after {attempt - 1} attempts on replica, "
f"falling back to primary DB"
)
alias = DEFAULT_DB_ALIAS

conn = connections[alias]
Expand All @@ -116,19 +205,36 @@
with transaction.atomic(using=alias):
with conn.cursor() as cursor:
try:
# just in case the value is a UUID object
uuid.UUID(str(value))
except ValueError:
raise ValidationError("Must be a valid UUID")
cursor.execute(SET_CONFIG_QUERY, [parameter, value])
yielded_cursor = True
yield cursor

wrapper_cm = (
conn.execute_wrapper(_query_failover)
if can_failover and alias == replica_alias
else nullcontext()
)
with wrapper_cm:
yielded_cursor = True
yield cursor
_caller_exited_cleanly = True
return
except OperationalError as e:
try:
connections[alias].close()
except Exception:
pass # Best-effort; connection may already be dead

if yielded_cursor:
if _fallback["succeeded"] and _caller_exited_cleanly:
# Caller's queries succeeded on primary via failover.
# This error is transaction.atomic() cleanup on the
# dead replica connection — suppress it.
return
raise
# If on primary or max attempts reached, raise
if not is_replica or attempt == max_attempts:

if not can_failover or attempt == max_attempts:
raise

# Retry with exponential backoff
Expand All @@ -139,6 +245,16 @@
)
time.sleep(delay)
finally:
if _fallback["atomic"] is not None:
try:
_fallback["atomic"].__exit__(*sys.exc_info())
except Exception:
pass # Best-effort; primary connection may be dead
_fallback["atomic"] = None
if _fallback["token"] is not None:
reset_read_db_alias(_fallback["token"])
_fallback["token"] = None

if router_token is not None:
reset_read_db_alias(router_token)

Expand Down
Loading
Loading