Skip to content

Commit 02b58d8

Browse files
authored
fix(api): mark attack paths scan as failed when celery task fails (#10065)
1 parent 3defbcd commit 02b58d8

File tree

6 files changed

+283
-55
lines changed

6 files changed

+283
-55
lines changed

api/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ All notable changes to the **Prowler API** are documented in this file.
1818
- Support CSA CCM 4.0 for the Azure provider [(#10039)](https://github.com/prowler-cloud/prowler/pull/10039)
1919
- Support CSA CCM 4.0 for the Oracle Cloud provider [(#10057)](https://github.com/prowler-cloud/prowler/pull/10057)
2020
- Support CSA CCM 4.0 for the Alibaba Cloud provider [(#10061)](https://github.com/prowler-cloud/prowler/pull/10061)
21+
- Attack Paths: Mark the Attack Paths scan as failed when the Celery task fails outside the job's own error handling [(#10065)](https://github.com/prowler-cloud/prowler/pull/10065)
2122

2223
### 🔐 Security
2324

api/poetry.lock

Lines changed: 8 additions & 48 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

api/src/backend/tasks/jobs/attack_paths/db_utils.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,11 @@ def finish_attack_paths_scan(
8686
) -> None:
8787
with rls_transaction(attack_paths_scan.tenant_id):
8888
now = datetime.now(tz=timezone.utc)
89-
duration = int((now - attack_paths_scan.started_at).total_seconds())
89+
duration = (
90+
int((now - attack_paths_scan.started_at).total_seconds())
91+
if attack_paths_scan.started_at
92+
else 0
93+
)
9094

9195
attack_paths_scan.state = state
9296
attack_paths_scan.progress = 100
@@ -144,3 +148,24 @@ def update_old_attack_paths_scan(
144148
with rls_transaction(old_attack_paths_scan.tenant_id):
145149
old_attack_paths_scan.is_graph_database_deleted = True
146150
old_attack_paths_scan.save(update_fields=["is_graph_database_deleted"])
151+
152+
153+
def fail_attack_paths_scan(
    tenant_id: str,
    scan_id: str,
    error: str,
) -> None:
    """
    Flag an `AttackPathsScan` row as `FAILED`, acting as a last-resort safety net.

    Meant for Celery task failures that happen outside the job's own error
    handling. Nothing is done when no scan row is retrieved, or when the row
    is already `COMPLETED` or `FAILED`.

    Args:
        tenant_id: Tenant identifier passed to `retrieve_attack_paths_scan`.
        scan_id: Scan identifier passed to `retrieve_attack_paths_scan`.
        error: Error text stored under the `global_error` key of the payload
            handed to `finish_attack_paths_scan`.
    """
    scan_row = retrieve_attack_paths_scan(tenant_id, scan_id)
    if not scan_row:
        return

    # Rows already in one of these states are left untouched.
    terminal_states = (StateChoices.COMPLETED, StateChoices.FAILED)
    if scan_row.state in terminal_states:
        return

    finish_attack_paths_scan(
        scan_row,
        StateChoices.FAILED,
        {"global_error": error},
    )

api/src/backend/tasks/jobs/attack_paths/scan.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -228,10 +228,16 @@ def run(tenant_id: str, scan_id: str, task_id: str) -> dict[str, Any]:
228228
except Exception as e:
229229
exception_message = utils.stringify_exception(e, "Cartography failed")
230230
logger.error(exception_message)
231-
ingestion_exceptions["global_cartography_error"] = exception_message
231+
ingestion_exceptions["global_error"] = exception_message
232232

233233
# Handling databases changes
234-
graph_database.drop_database(tmp_cartography_config.neo4j_database)
234+
try:
235+
graph_database.drop_database(tmp_cartography_config.neo4j_database)
236+
except Exception:
237+
logger.exception(
238+
f"Failed to drop temporary Neo4j database {tmp_cartography_config.neo4j_database} during cleanup"
239+
)
240+
235241
db_utils.finish_attack_paths_scan(
236242
attack_paths_scan, StateChoices.FAILED, ingestion_exceptions
237243
)

api/src/backend/tasks/tasks.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from django_celery_beat.models import PeriodicTask
1111
from tasks.jobs.attack_paths import (
1212
attack_paths_scan,
13+
db_utils as attack_paths_db_utils,
1314
can_provider_run_attack_paths_scan,
1415
)
1516
from tasks.jobs.backfill import (
@@ -359,8 +360,25 @@ def perform_scan_summary_task(tenant_id: str, scan_id: str):
359360
return aggregate_findings(tenant_id=tenant_id, scan_id=scan_id)
360361

361362

363+
class AttackPathsScanRLSTask(RLSTask):
    """
    RLS task that marks the `AttackPathsScan` DB row as `FAILED` when the Celery task fails.

    Covers failures that happen outside the job's own try/except (e.g. provider lookup,
    SDK initialization, or Neo4j configuration errors during setup).
    """

    def on_failure(self, exc, task_id, args, kwargs, _einfo):
        """
        Failure hook: record the task failure on the matching `AttackPathsScan` row.

        Args:
            exc: Exception raised by the task; its string form is stored as the error.
            task_id: Identifier of the failed task, used only for logging.
            args: Positional arguments of the failed task (not inspected).
            kwargs: Keyword arguments of the failed task; `tenant_id` and `scan_id`
                are read from here.
            _einfo: Exception info supplied by the caller (unused).
        """
        tenant_id = kwargs.get("tenant_id")
        scan_id = kwargs.get("scan_id")

        if not (tenant_id and scan_id):
            # Without both identifiers the row cannot be located; leave a trace
            # instead of silently skipping the update.
            logger.warning(
                f"Attack paths scan task {task_id} failed without tenant_id/scan_id "
                "in kwargs; scan row was not marked as failed"
            )
            return

        logger.error(f"Attack paths scan task {task_id} failed: {exc}")
        try:
            attack_paths_db_utils.fail_attack_paths_scan(tenant_id, scan_id, str(exc))
        except Exception:
            # This hook is a best-effort safety net: a failure while updating the
            # row must not escape and obscure the original task error.
            logger.exception(
                f"Could not mark attack paths scan {scan_id} as failed "
                f"after task {task_id} failure"
            )
378+
379+
362380
@shared_task(
363-
base=RLSTask,
381+
base=AttackPathsScanRLSTask,
364382
bind=True,
365383
name="attack-paths-scan-perform",
366384
queue="attack-paths-scans",

0 commit comments

Comments
 (0)