Skip to content

Commit e47f2b4

Browse files
authored
fix(api): harden security hub retries (#10144)
1 parent 7077a56 commit e47f2b4

File tree

5 files changed

+260
-77
lines changed

5 files changed

+260
-77
lines changed

api/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ All notable changes to the **Prowler API** are documented in this file.
3535

3636
- Attack Paths: Orphaned temporary Neo4j databases are now cleaned up on scan failure and provider deletion [(#10101)](https://github.com/prowler-cloud/prowler/pull/10101)
3737
- Attack Paths: scan no longer raises `DatabaseError` when provider is deleted mid-scan [(#10116)](https://github.com/prowler-cloud/prowler/pull/10116)
38+
- Security Hub export now retries transient read replica conflicts instead of failing the integration [(#10144)](https://github.com/prowler-cloud/prowler/pull/10144)
3839

3940
### 🔐 Security
4041

api/src/backend/api/db_utils.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ def rls_transaction(
7474
value: str,
7575
parameter: str = POSTGRES_TENANT_VAR,
7676
using: str | None = None,
77+
retry_on_replica: bool = True,
7778
):
7879
"""
7980
Creates a new database transaction setting the given configuration value for Postgres RLS. It validates the
@@ -92,10 +93,11 @@ def rls_transaction(
9293

9394
alias = db_alias
9495
is_replica = READ_REPLICA_ALIAS and alias == READ_REPLICA_ALIAS
95-
max_attempts = REPLICA_MAX_ATTEMPTS if is_replica else 1
96+
max_attempts = REPLICA_MAX_ATTEMPTS if is_replica and retry_on_replica else 1
9697

9798
for attempt in range(1, max_attempts + 1):
9899
router_token = None
100+
yielded_cursor = False
99101

100102
# On final attempt, fall back to primary
101103
if attempt == max_attempts and is_replica:
@@ -118,9 +120,12 @@ def rls_transaction(
118120
except ValueError:
119121
raise ValidationError("Must be a valid UUID")
120122
cursor.execute(SET_CONFIG_QUERY, [parameter, value])
123+
yielded_cursor = True
121124
yield cursor
122125
return
123126
except OperationalError as e:
127+
if yielded_cursor:
128+
raise
124129
# If on primary or max attempts reached, raise
125130
if not is_replica or attempt == max_attempts:
126131
raise

api/src/backend/api/tests/test_db_utils.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -550,6 +550,36 @@ def atomic_side_effect(*args, **kwargs):
550550
mock_sleep.assert_any_call(1.0)
551551
assert mock_logger.info.call_count == 2
552552

553+
def test_rls_transaction_operational_error_inside_context_no_retry(
554+
self, tenants_fixture, enable_read_replica
555+
):
556+
"""Test OperationalError raised inside context does not retry."""
557+
tenant = tenants_fixture[0]
558+
tenant_id = str(tenant.id)
559+
560+
with patch("api.db_utils.get_read_db_alias", return_value=enable_read_replica):
561+
with patch("api.db_utils.connections") as mock_connections:
562+
mock_conn = MagicMock()
563+
mock_cursor = MagicMock()
564+
mock_conn.cursor.return_value.__enter__.return_value = mock_cursor
565+
mock_connections.__getitem__.return_value = mock_conn
566+
mock_connections.__contains__.return_value = True
567+
568+
with patch("api.db_utils.transaction.atomic") as mock_atomic:
569+
mock_atomic.return_value.__enter__.return_value = None
570+
mock_atomic.return_value.__exit__.return_value = False
571+
572+
with patch("api.db_utils.time.sleep") as mock_sleep:
573+
with patch(
574+
"api.db_utils.set_read_db_alias", return_value="token"
575+
):
576+
with patch("api.db_utils.reset_read_db_alias"):
577+
with pytest.raises(OperationalError):
578+
with rls_transaction(tenant_id):
579+
raise OperationalError("Conflict with recovery")
580+
581+
mock_sleep.assert_not_called()
582+
553583
def test_rls_transaction_max_three_attempts_for_replica(
554584
self, tenants_fixture, enable_read_replica
555585
):
@@ -579,6 +609,38 @@ def test_rls_transaction_max_three_attempts_for_replica(
579609

580610
assert mock_atomic.call_count == 3
581611

612+
def test_rls_transaction_replica_no_retry_when_disabled(
613+
self, tenants_fixture, enable_read_replica
614+
):
615+
"""Test replica retry is disabled when retry_on_replica=False."""
616+
tenant = tenants_fixture[0]
617+
tenant_id = str(tenant.id)
618+
619+
with patch("api.db_utils.get_read_db_alias", return_value=enable_read_replica):
620+
with patch("api.db_utils.connections") as mock_connections:
621+
mock_conn = MagicMock()
622+
mock_cursor = MagicMock()
623+
mock_conn.cursor.return_value.__enter__.return_value = mock_cursor
624+
mock_connections.__getitem__.return_value = mock_conn
625+
mock_connections.__contains__.return_value = True
626+
627+
with patch("api.db_utils.transaction.atomic") as mock_atomic:
628+
mock_atomic.side_effect = OperationalError("Replica error")
629+
630+
with patch("api.db_utils.time.sleep") as mock_sleep:
631+
with patch(
632+
"api.db_utils.set_read_db_alias", return_value="token"
633+
):
634+
with patch("api.db_utils.reset_read_db_alias"):
635+
with pytest.raises(OperationalError):
636+
with rls_transaction(
637+
tenant_id, retry_on_replica=False
638+
):
639+
pass
640+
641+
assert mock_atomic.call_count == 1
642+
mock_sleep.assert_not_called()
643+
582644
def test_rls_transaction_only_one_attempt_for_primary(self, tenants_fixture):
583645
"""Test only 1 attempt for primary database."""
584646
tenant = tenants_fixture[0]

api/src/backend/tasks/jobs/integrations.py

Lines changed: 112 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
import os
2+
import time
23
from glob import glob
34

45
from celery.utils.log import get_task_logger
56
from config.django.base import DJANGO_FINDINGS_BATCH_SIZE
7+
from django.db import OperationalError
68
from tasks.utils import batched
79

810
from api.db_router import READ_REPLICA_ALIAS, MainRouter
9-
from api.db_utils import rls_transaction
11+
from api.db_utils import REPLICA_MAX_ATTEMPTS, REPLICA_RETRY_BASE_DELAY, rls_transaction
1012
from api.models import Finding, Integration, Provider
1113
from api.utils import initialize_prowler_integration, initialize_prowler_provider
1214
from prowler.lib.outputs.asff.asff import ASFF
@@ -17,11 +19,11 @@
1719
from prowler.lib.outputs.ocsf.ocsf import OCSF
1820
from prowler.providers.aws.aws_provider import AwsProvider
1921
from prowler.providers.aws.lib.s3.s3 import S3
20-
from prowler.providers.aws.lib.security_hub.security_hub import SecurityHub
21-
from prowler.providers.common.models import Connection
2222
from prowler.providers.aws.lib.security_hub.exceptions.exceptions import (
2323
SecurityHubNoEnabledRegionsError,
2424
)
25+
from prowler.providers.aws.lib.security_hub.security_hub import SecurityHub
26+
from prowler.providers.common.models import Connection
2527

2628
logger = get_task_logger(__name__)
2729

@@ -291,96 +293,130 @@ def upload_security_hub_integration(
291293
total_findings_sent[integration.id] = 0
292294

293295
# Process findings in batches to avoid memory issues
296+
max_attempts = REPLICA_MAX_ATTEMPTS if READ_REPLICA_ALIAS else 1
294297
has_findings = False
295298
batch_number = 0
296299

297-
with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
298-
qs = (
299-
Finding.all_objects.filter(tenant_id=tenant_id, scan_id=scan_id)
300-
.order_by("uid")
301-
.iterator()
302-
)
303-
304-
for batch, _ in batched(qs, DJANGO_FINDINGS_BATCH_SIZE):
305-
batch_number += 1
306-
has_findings = True
300+
for attempt in range(1, max_attempts + 1):
301+
read_alias = None
302+
if READ_REPLICA_ALIAS:
303+
read_alias = (
304+
READ_REPLICA_ALIAS
305+
if attempt < max_attempts
306+
else MainRouter.default_db
307+
)
307308

308-
# Transform findings for this batch
309-
transformed_findings = [
310-
FindingOutput.transform_api_finding(
311-
finding, prowler_provider
309+
try:
310+
batch_number = 0
311+
has_findings = False
312+
with rls_transaction(
313+
tenant_id,
314+
using=read_alias,
315+
retry_on_replica=False,
316+
):
317+
qs = (
318+
Finding.all_objects.filter(
319+
tenant_id=tenant_id, scan_id=scan_id
320+
)
321+
.order_by("uid")
322+
.iterator()
312323
)
313-
for finding in batch
314-
]
315-
316-
# Convert to ASFF format
317-
asff_transformer = ASFF(
318-
findings=transformed_findings,
319-
file_path="",
320-
file_extension="json",
321-
)
322-
asff_transformer.transform(transformed_findings)
323324

324-
# Get the batch of ASFF findings
325-
batch_asff_findings = asff_transformer.data
325+
for batch, _ in batched(qs, DJANGO_FINDINGS_BATCH_SIZE):
326+
batch_number += 1
327+
has_findings = True
326328

327-
if batch_asff_findings:
328-
# Create Security Hub client for first batch or reuse existing
329-
if not security_hub_client:
330-
connected, security_hub = (
331-
get_security_hub_client_from_integration(
332-
integration, tenant_id, batch_asff_findings
329+
# Transform findings for this batch
330+
transformed_findings = [
331+
FindingOutput.transform_api_finding(
332+
finding, prowler_provider
333333
)
334+
for finding in batch
335+
]
336+
337+
# Convert to ASFF format
338+
asff_transformer = ASFF(
339+
findings=transformed_findings,
340+
file_path="",
341+
file_extension="json",
334342
)
343+
asff_transformer.transform(transformed_findings)
344+
345+
# Get the batch of ASFF findings
346+
batch_asff_findings = asff_transformer.data
347+
348+
if batch_asff_findings:
349+
# Create Security Hub client for first batch or reuse existing
350+
if not security_hub_client:
351+
connected, security_hub = (
352+
get_security_hub_client_from_integration(
353+
integration,
354+
tenant_id,
355+
batch_asff_findings,
356+
)
357+
)
335358

336-
if not connected:
337-
if isinstance(
338-
security_hub.error,
339-
SecurityHubNoEnabledRegionsError,
340-
):
341-
logger.warning(
342-
f"Security Hub integration {integration.id} has no enabled regions"
359+
if not connected:
360+
if isinstance(
361+
security_hub.error,
362+
SecurityHubNoEnabledRegionsError,
363+
):
364+
logger.warning(
365+
f"Security Hub integration {integration.id} has no enabled regions"
366+
)
367+
else:
368+
logger.error(
369+
f"Security Hub connection failed for integration {integration.id}: "
370+
f"{security_hub.error}"
371+
)
372+
break # Skip this integration
373+
374+
security_hub_client = security_hub
375+
logger.info(
376+
f"Sending {'fail' if send_only_fails else 'all'} findings to Security Hub via "
377+
f"integration {integration.id}"
343378
)
344379
else:
345-
logger.error(
346-
f"Security Hub connection failed for integration {integration.id}: "
347-
f"{security_hub.error}"
380+
# Update findings in existing client for this batch
381+
security_hub_client._findings_per_region = (
382+
security_hub_client.filter(
383+
batch_asff_findings,
384+
send_only_fails,
385+
)
348386
)
349-
break # Skip this integration
350387

351-
security_hub_client = security_hub
352-
logger.info(
353-
f"Sending {'fail' if send_only_fails else 'all'} findings to Security Hub via "
354-
f"integration {integration.id}"
355-
)
356-
else:
357-
# Update findings in existing client for this batch
358-
security_hub_client._findings_per_region = (
359-
security_hub_client.filter(
360-
batch_asff_findings, send_only_fails
361-
)
362-
)
388+
# Send this batch to Security Hub
389+
try:
390+
findings_sent = security_hub_client.batch_send_to_security_hub()
391+
total_findings_sent[integration.id] += (
392+
findings_sent
393+
)
363394

364-
# Send this batch to Security Hub
365-
try:
366-
findings_sent = (
367-
security_hub_client.batch_send_to_security_hub()
368-
)
369-
total_findings_sent[integration.id] += findings_sent
395+
if findings_sent > 0:
396+
logger.debug(
397+
f"Sent batch {batch_number} with {findings_sent} findings to Security Hub"
398+
)
399+
except Exception as batch_error:
400+
logger.error(
401+
f"Failed to send batch {batch_number} to Security Hub: {str(batch_error)}"
402+
)
370403

371-
if findings_sent > 0:
372-
logger.debug(
373-
f"Sent batch {batch_number} with {findings_sent} findings to Security Hub"
374-
)
375-
except Exception as batch_error:
376-
logger.error(
377-
f"Failed to send batch {batch_number} to Security Hub: {str(batch_error)}"
378-
)
404+
# Clear memory after processing each batch
405+
asff_transformer._data.clear()
406+
del batch_asff_findings
407+
del transformed_findings
408+
409+
break
410+
except OperationalError as e:
411+
if attempt == max_attempts:
412+
raise
379413

380-
# Clear memory after processing each batch
381-
asff_transformer._data.clear()
382-
del batch_asff_findings
383-
del transformed_findings
414+
delay = REPLICA_RETRY_BASE_DELAY * (2 ** (attempt - 1))
415+
logger.info(
416+
"RLS query failed during Security Hub integration "
417+
f"(attempt {attempt}/{max_attempts}), retrying in {delay}s. Error: {e}"
418+
)
419+
time.sleep(delay)
384420

385421
if not has_findings:
386422
logger.info(

0 commit comments

Comments
 (0)