Skip to content

Commit 868acf5

Browse files
galvana and Adrian Galvan authored
Traversal optimizations (#7244)
Co-authored-by: Adrian Galvan <galvana@uci.edu>
1 parent 498adf8 commit 868acf5

File tree

7 files changed

+1372
-35
lines changed

7 files changed

+1372
-35
lines changed

changelog/7244.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
type: Changed
2+
description: Optimized graph traversal algorithms for improved performance with large dataset graphs
3+
pr: 7244
4+
labels: []

src/fides/api/graph/node_filters.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,10 +109,19 @@ def _compute_reachable_nodes(
109109

110110
try:
111111
# Create a traversal object with just this identity
112+
# Skip verification in the constructor and explicitly call traverse()
113+
# This avoids the overhead of _verify_traversal() when we're going
114+
# to traverse anyway to detect unreachable nodes.
112115
with suppress_logging():
113116
# Suppress the logs since we don't want to flood the logs
114117
# with traversal info for each identity we want to evaluate
115-
BaseTraversal(self.graph, {identity_key: "dummy_value"})
118+
traversal = BaseTraversal(
119+
self.graph,
120+
{identity_key: "dummy_value"},
121+
skip_verification=True,
122+
)
123+
# Now explicitly traverse to detect unreachable nodes
124+
traversal.traverse({}, lambda n, m: None)
116125

117126
# If successful, all nodes are reachable
118127
self.reachable_by_identity[identity_key] = all_addresses

src/fides/api/graph/traversal.py

Lines changed: 286 additions & 29 deletions
Large diffs are not rendered by default.

src/fides/api/task/create_request_tasks.py

Lines changed: 62 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -190,12 +190,46 @@ def build_consent_networkx_digraph(
190190
return networkx_graph
191191

192192

193+
def compute_all_descendants(
    graph: networkx.DiGraph,
) -> Dict[CollectionAddress, Set[CollectionAddress]]:
    """
    Compute the transitive descendants of every node in one pass over the graph.

    Nodes are visited in reverse topological order (leaves first), so each
    node's descendants can be assembled from its direct children plus the
    already-finished descendant sets of those children. This avoids running a
    separate graph search per node, as calling ``networkx.descendants()`` for
    every node would.

    Complexity note: the ordering itself is O(N+E), but every edge also
    triggers a set union whose cost is the size of the child's descendant
    set. Total work is therefore O(N + E + sum of descendant-set sizes over
    all edges), not strictly O(N+E); the returned mapping alone can hold
    O(N²) entries on deep graphs.

    Args:
        graph: Directed acyclic graph whose nodes are CollectionAddress values.

    Returns:
        A dict mapping each node in ``graph`` to the set of all nodes reachable
        from it (its transitive successors). Nodes without successors map to
        an empty set. ``base_task_data`` uses this mapping to fill the
        ``all_descendant_tasks`` field.

    Raises:
        networkx.NetworkXUnfeasible: if ``graph`` contains a cycle, raised by
            ``networkx.topological_sort``.
    """
    all_descendants: Dict[CollectionAddress, Set[CollectionAddress]] = {
        node: set() for node in graph.nodes
    }

    # Reverse topological order guarantees every child is processed before
    # any of its parents, so all_descendants[child] is complete when read.
    for node in reversed(list(networkx.topological_sort(graph))):
        node_descendants = all_descendants[node]
        for child in graph.successors(node):
            # A node's descendants = each direct child + that child's descendants
            node_descendants.add(child)
            node_descendants.update(all_descendants[child])

    return all_descendants
224+
225+
193226
def base_task_data(
194227
graph: networkx.DiGraph,
195228
dataset_graph: DatasetGraph,
196229
privacy_request: PrivacyRequest,
197230
node: CollectionAddress,
198231
traversal_nodes: Dict[CollectionAddress, TraversalNode],
232+
all_descendants: Dict[CollectionAddress, Set[CollectionAddress]],
199233
) -> Dict:
200234
"""Build a dictionary of common RequestTask attributes that are shared for building
201235
access, consent, and erasure tasks"""
@@ -236,7 +270,7 @@ def base_task_data(
236270
[downstream.value for downstream in graph.successors(node)]
237271
),
238272
"all_descendant_tasks": sorted(
239-
[descend.value for descend in list(networkx.descendants(graph, node))]
273+
[descend.value for descend in all_descendants.get(node, set())]
240274
),
241275
"collection_address": node.value,
242276
"dataset_name": node.dataset,
@@ -270,6 +304,9 @@ def persist_new_access_request_tasks(
270304
traversal_nodes, end_nodes, traversal
271305
)
272306

307+
# Pre-compute all descendants in O(N+E) instead of O(N²)
308+
all_descendants = compute_all_descendants(graph)
309+
273310
for node in list(networkx.topological_sort(graph)):
274311
if privacy_request.get_existing_request_task(
275312
session, action_type=ActionType.access, collection_address=node
@@ -280,7 +317,12 @@ def persist_new_access_request_tasks(
280317
session,
281318
data={
282319
**base_task_data(
283-
graph, dataset_graph, privacy_request, node, traversal_nodes
320+
graph,
321+
dataset_graph,
322+
privacy_request,
323+
node,
324+
traversal_nodes,
325+
all_descendants,
284326
),
285327
"access_data": (
286328
[traversal.seed_data] if node == ROOT_COLLECTION_ADDRESS else []
@@ -314,6 +356,9 @@ def persist_initial_erasure_request_tasks(
314356
)
315357
graph: networkx.DiGraph = build_erasure_networkx_digraph(traversal_nodes, end_nodes)
316358

359+
# Pre-compute all descendants in O(N+E) instead of O(N²)
360+
all_descendants = compute_all_descendants(graph)
361+
317362
for node in list(networkx.topological_sort(graph)):
318363
if privacy_request.get_existing_request_task(
319364
session, action_type=ActionType.erasure, collection_address=node
@@ -324,7 +369,12 @@ def persist_initial_erasure_request_tasks(
324369
session,
325370
data={
326371
**base_task_data(
327-
graph, dataset_graph, privacy_request, node, traversal_nodes
372+
graph,
373+
dataset_graph,
374+
privacy_request,
375+
node,
376+
traversal_nodes,
377+
all_descendants,
328378
),
329379
"action_type": ActionType.erasure,
330380
},
@@ -406,6 +456,9 @@ def persist_new_consent_request_tasks(
406456
"""
407457
graph: networkx.DiGraph = build_consent_networkx_digraph(traversal_nodes)
408458

459+
# Pre-compute all descendants in O(N+E) instead of O(N²)
460+
all_descendants = compute_all_descendants(graph)
461+
409462
for node in list(networkx.topological_sort(graph)):
410463
if privacy_request.get_existing_request_task(
411464
session, action_type=ActionType.consent, collection_address=node
@@ -415,7 +468,12 @@ def persist_new_consent_request_tasks(
415468
session,
416469
data={
417470
**base_task_data(
418-
graph, dataset_graph, privacy_request, node, traversal_nodes
471+
graph,
472+
dataset_graph,
473+
privacy_request,
474+
node,
475+
traversal_nodes,
476+
all_descendants,
419477
),
420478
# Consent nodes take in identity data from their upstream root node
421479
"access_data": ([identity] if node == ROOT_COLLECTION_ADDRESS else []),

src/fides/config/execution_settings.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,4 +89,8 @@ class ExecutionSettings(FidesSettings):
8989
default=False,
9090
description="Whether the memory watchdog is enabled to monitor and gracefully terminate tasks that approach memory limits.",
9191
)
92+
use_legacy_traversal: bool = Field(
93+
default=False,
94+
description="When enabled, falls back to the legacy traversal algorithm. Intended as a temporary safety net in case of regressions with the optimized traversal.",
95+
)
9296
model_config = SettingsConfigDict(env_prefix=ENV_PREFIX)

0 commit comments

Comments
 (0)