Skip to content

Commit f5f1f1a

Browse files
authored
fix(attack-paths): recover graph_data_ready when scan fails during graph swap (#10354)
1 parent 65e745d commit f5f1f1a

File tree

6 files changed

+861
-54
lines changed

6 files changed

+861
-54
lines changed

api/CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,10 @@ All notable changes to the **Prowler API** are documented in this file.
1313
- Attack Paths: Complete migration to private graph labels and properties, removing deprecated dual-write support [(#10268)](https://github.com/prowler-cloud/prowler/pull/10268)
1414
- Attack Paths: Added tenant and provider related labels to the nodes so they can be easily filtered on custom queries [(#10308)](https://github.com/prowler-cloud/prowler/pull/10308)
1515

16+
### 🐞 Fixed
17+
18+
- Attack Paths: Recover `graph_data_ready` flag when scan fails during graph swap, preventing query endpoints from staying blocked until the next successful scan [(#10354)](https://github.com/prowler-cloud/prowler/pull/10354)
19+
1620
### 🔐 Security
1721

1822
- Use `psycopg2.sql` to safely compose DDL in `PostgresEnumMigration`, preventing SQL injection via f-string interpolation [(#10166)](https://github.com/prowler-cloud/prowler/pull/10166)

api/src/backend/api/attack_paths/database.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,22 @@
11
import atexit
22
import logging
33
import threading
4-
5-
from typing import Any
6-
74
from contextlib import contextmanager
8-
from typing import Iterator
5+
from typing import Any, Iterator
96
from uuid import UUID
107

118
import neo4j
129
import neo4j.exceptions
13-
14-
from django.conf import settings
15-
16-
from api.attack_paths.retryable_session import RetryableSession
1710
from config.env import env
11+
from django.conf import settings
1812
from tasks.jobs.attack_paths.config import (
1913
BATCH_SIZE,
2014
PROVIDER_ID_PROPERTY,
2115
PROVIDER_RESOURCE_LABEL,
2216
)
2317

18+
from api.attack_paths.retryable_session import RetryableSession
19+
2420
# Without this Celery goes crazy with Neo4j logging
2521
logging.getLogger("neo4j").setLevel(logging.ERROR)
2622
logging.getLogger("neo4j").propagate = False
@@ -197,6 +193,29 @@ def drop_subgraph(database: str, provider_id: str) -> int:
197193
return deleted_nodes
198194

199195

196+
def has_provider_data(database: str, provider_id: str) -> bool:
197+
"""
198+
Check if any ProviderResource node exists for this provider.
199+
200+
Returns `False` if the database doesn't exist.
201+
"""
202+
query = (
203+
f"MATCH (n:{PROVIDER_RESOURCE_LABEL} "
204+
f"{{{PROVIDER_ID_PROPERTY}: $provider_id}}) "
205+
"RETURN 1 LIMIT 1"
206+
)
207+
208+
try:
209+
with get_session(database, default_access_mode=neo4j.READ_ACCESS) as session:
210+
result = session.run(query, {"provider_id": provider_id})
211+
return result.single() is not None
212+
213+
except GraphDatabaseQueryException as exc:
214+
if exc.code == "Neo.ClientError.Database.DatabaseNotFound":
215+
return False
216+
raise
217+
218+
200219
def clear_cache(database: str) -> None:
201220
query = "CALL db.clearQueryCaches()"
202221

api/src/backend/api/tests/test_attack_paths_database.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -442,3 +442,78 @@ def call_init():
442442
# All threads got the same driver instance
443443
assert all(r is mock_driver for r in results)
444444
assert len(results) == 10
445+
446+
447+
class TestHasProviderData:
448+
"""Test has_provider_data helper for checking provider nodes in Neo4j."""
449+
450+
def test_returns_true_when_nodes_exist(self):
451+
import api.attack_paths.database as db_module
452+
453+
mock_session = MagicMock()
454+
mock_result = MagicMock()
455+
mock_result.single.return_value = MagicMock() # non-None record
456+
mock_session.run.return_value = mock_result
457+
458+
session_ctx = MagicMock()
459+
session_ctx.__enter__.return_value = mock_session
460+
session_ctx.__exit__.return_value = False
461+
462+
with patch(
463+
"api.attack_paths.database.get_session",
464+
return_value=session_ctx,
465+
):
466+
assert db_module.has_provider_data("db-tenant-abc", "provider-123") is True
467+
468+
mock_session.run.assert_called_once()
469+
470+
def test_returns_false_when_no_nodes(self):
471+
import api.attack_paths.database as db_module
472+
473+
mock_session = MagicMock()
474+
mock_result = MagicMock()
475+
mock_result.single.return_value = None
476+
mock_session.run.return_value = mock_result
477+
478+
session_ctx = MagicMock()
479+
session_ctx.__enter__.return_value = mock_session
480+
session_ctx.__exit__.return_value = False
481+
482+
with patch(
483+
"api.attack_paths.database.get_session",
484+
return_value=session_ctx,
485+
):
486+
assert db_module.has_provider_data("db-tenant-abc", "provider-123") is False
487+
488+
def test_returns_false_when_database_not_found(self):
489+
import api.attack_paths.database as db_module
490+
491+
session_ctx = MagicMock()
492+
session_ctx.__enter__.side_effect = db_module.GraphDatabaseQueryException(
493+
message="Database does not exist",
494+
code="Neo.ClientError.Database.DatabaseNotFound",
495+
)
496+
497+
with patch(
498+
"api.attack_paths.database.get_session",
499+
return_value=session_ctx,
500+
):
501+
assert (
502+
db_module.has_provider_data("db-tenant-gone", "provider-123") is False
503+
)
504+
505+
def test_raises_on_other_errors(self):
506+
import api.attack_paths.database as db_module
507+
508+
session_ctx = MagicMock()
509+
session_ctx.__enter__.side_effect = db_module.GraphDatabaseQueryException(
510+
message="Connection refused",
511+
code="Neo.TransientError.General.UnknownError",
512+
)
513+
514+
with patch(
515+
"api.attack_paths.database.get_session",
516+
return_value=session_ctx,
517+
):
518+
with pytest.raises(db_module.GraphDatabaseQueryException):
519+
db_module.has_provider_data("db-tenant-abc", "provider-123")

api/src/backend/tasks/jobs/attack_paths/db_utils.py

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,13 @@
33

44
from cartography.config import Config as CartographyConfig
55
from celery.utils.log import get_task_logger
6+
from tasks.jobs.attack_paths.config import is_provider_available
67

78
from api.attack_paths import database as graph_database
89
from api.db_utils import rls_transaction
9-
from api.models import (
10-
AttackPathsScan as ProwlerAPIAttackPathsScan,
11-
Provider as ProwlerAPIProvider,
12-
StateChoices,
13-
)
14-
from tasks.jobs.attack_paths.config import is_provider_available
10+
from api.models import AttackPathsScan as ProwlerAPIAttackPathsScan
11+
from api.models import Provider as ProwlerAPIProvider
12+
from api.models import StateChoices
1513

1614
logger = get_task_logger(__name__)
1715

@@ -155,6 +153,37 @@ def set_provider_graph_data_ready(
155153
attack_paths_scan.refresh_from_db(fields=["graph_data_ready"])
156154

157155

156+
def recover_graph_data_ready(
157+
attack_paths_scan: ProwlerAPIAttackPathsScan,
158+
) -> None:
159+
"""
160+
Best-effort recovery of `graph_data_ready` after a scan failure.
161+
162+
Queries Neo4j to check if the provider still has data in the tenant
163+
database. If data exists, restores `graph_data_ready=True` for all scans
164+
of this provider. Never raises.
165+
166+
Trade-off: if the worker crashed mid-sync, partial data may exist and
167+
this will re-enable queries against it. We accept that because leaving
168+
`graph_data_ready=False` permanently (blocking all queries until the
169+
next successful scan) is a worse outcome for the user.
170+
"""
171+
try:
172+
tenant_db = graph_database.get_database_name(attack_paths_scan.tenant_id)
173+
if graph_database.has_provider_data(
174+
tenant_db, str(attack_paths_scan.provider_id)
175+
):
176+
set_provider_graph_data_ready(attack_paths_scan, True)
177+
logger.info(
178+
f"Recovered `graph_data_ready` for provider {attack_paths_scan.provider_id}"
179+
)
180+
181+
except Exception:
182+
logger.exception(
183+
f"Failed to recover `graph_data_ready` for provider {attack_paths_scan.provider_id}"
184+
)
185+
186+
158187
def fail_attack_paths_scan(
159188
tenant_id: str,
160189
scan_id: str,
@@ -185,3 +214,5 @@ def fail_attack_paths_scan(
185214
StateChoices.FAILED,
186215
{"global_error": error},
187216
)
217+
218+
recover_graph_data_ready(attack_paths_scan)

api/src/backend/tasks/jobs/attack_paths/scan.py

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -55,24 +55,21 @@
5555

5656
import logging
5757
import time
58-
5958
from typing import Any
6059

6160
from cartography.config import Config as CartographyConfig
6261
from cartography.intel import analysis as cartography_analysis
6362
from cartography.intel import create_indexes as cartography_create_indexes
6463
from cartography.intel import ontology as cartography_ontology
6564
from celery.utils.log import get_task_logger
65+
from tasks.jobs.attack_paths import db_utils, findings, internet, sync, utils
66+
from tasks.jobs.attack_paths.config import get_cartography_ingestion_function
6667

6768
from api.attack_paths import database as graph_database
6869
from api.db_utils import rls_transaction
69-
from api.models import (
70-
Provider as ProwlerAPIProvider,
71-
StateChoices,
72-
)
70+
from api.models import Provider as ProwlerAPIProvider
71+
from api.models import StateChoices
7372
from api.utils import initialize_prowler_provider
74-
from tasks.jobs.attack_paths import db_utils, findings, internet, sync, utils
75-
from tasks.jobs.attack_paths.config import get_cartography_ingestion_function
7673

7774
# Without this Celery goes crazy with Cartography logging
7875
logging.getLogger("cartography").setLevel(logging.ERROR)
@@ -147,6 +144,10 @@ def run(tenant_id: str, scan_id: str, task_id: str) -> dict[str, Any]:
147144
attack_paths_scan, task_id, tenant_cartography_config
148145
)
149146

147+
subgraph_dropped = False
148+
sync_completed = False
149+
provider_gated = False
150+
150151
try:
151152
logger.info(
152153
f"Creating Neo4j database {tmp_cartography_config.neo4j_database} for tenant {prowler_api_provider.tenant_id}"
@@ -225,10 +226,12 @@ def run(tenant_id: str, scan_id: str, task_id: str) -> dict[str, Any]:
225226

226227
logger.info(f"Deleting existing provider graph in {tenant_database_name}")
227228
db_utils.set_provider_graph_data_ready(attack_paths_scan, False)
229+
provider_gated = True
228230
graph_database.drop_subgraph(
229231
database=tenant_database_name,
230232
provider_id=str(prowler_api_provider.id),
231233
)
234+
subgraph_dropped = True
232235
db_utils.update_attack_paths_scan_progress(attack_paths_scan, 98)
233236

234237
logger.info(
@@ -240,6 +243,7 @@ def run(tenant_id: str, scan_id: str, task_id: str) -> dict[str, Any]:
240243
tenant_id=str(prowler_api_provider.tenant_id),
241244
provider_id=str(prowler_api_provider.id),
242245
)
246+
sync_completed = True
243247
db_utils.set_graph_data_ready(attack_paths_scan, True)
244248
db_utils.update_attack_paths_scan_progress(attack_paths_scan, 99)
245249

@@ -264,23 +268,39 @@ def run(tenant_id: str, scan_id: str, task_id: str) -> dict[str, Any]:
264268
logger.exception(exception_message)
265269
ingestion_exceptions["global_error"] = exception_message
266270

267-
# Handling databases changes
271+
# Recover graph_data_ready based on how far the swap got.
272+
# Partial drop (mid-batch failure) may leave `subgraph_dropped=False`
273+
# with data partially deleted, so we prefer that over permanently blocked queries.
274+
try:
275+
if sync_completed:
276+
db_utils.set_graph_data_ready(attack_paths_scan, True)
277+
elif provider_gated and not subgraph_dropped:
278+
db_utils.set_provider_graph_data_ready(attack_paths_scan, True)
279+
280+
except Exception:
281+
logger.error(
282+
f"Failed to recover `graph_data_ready` for provider {attack_paths_scan.provider_id}",
283+
exc_info=True,
284+
)
285+
286+
# Dropping the temporary database if it still exists
268287
try:
269288
graph_database.drop_database(tmp_cartography_config.neo4j_database)
270289

271290
except Exception as e:
272291
logger.error(
273-
f"Failed to drop temporary Neo4j database {tmp_cartography_config.neo4j_database} during cleanup: {e}",
292+
f"Failed to drop temporary Neo4j database `{tmp_cartography_config.neo4j_database}` during cleanup: {e}",
274293
exc_info=True,
275294
)
276295

296+
# Set Attack Paths scan state to FAILED
277297
try:
278298
db_utils.finish_attack_paths_scan(
279299
attack_paths_scan, StateChoices.FAILED, ingestion_exceptions
280300
)
281301
except Exception as e:
282302
logger.error(
283-
f"Could not mark attack paths scan {attack_paths_scan.id} as FAILED (row may have been deleted): {e}",
303+
f"Could not mark Attack Paths scan {attack_paths_scan.id} as `FAILED` (row may have been deleted): {e}",
284304
exc_info=True,
285305
)
286306

0 commit comments

Comments
 (0)