Skip to content

Commit 559a87d

Browse files
feat(data-modeling): add cycle detection / dag mismatch detection into edge model directly (#42423)
1 parent 10cb2de commit 559a87d

File tree

5 files changed

+387
-9
lines changed

5 files changed

+387
-9
lines changed
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# manually created by andrewjmcgehee
2+
3+
from django.db import migrations
4+
5+
DROP_DETECT_CYCLES = """\
6+
DROP FUNCTION IF EXISTS posthog_datamodelingedge_detect_cycles();
7+
"""
8+
9+
10+
class Migration(migrations.Migration):
    """Drop the ``posthog_datamodelingedge_detect_cycles()`` SQL function.

    Runs after ``0003_create_detect_cycles_function``. No ``reverse_sql`` is
    passed to ``RunSQL``, so this migration cannot be unapplied.
    """

    dependencies = [("data_modeling", "0003_create_detect_cycles_function")]
    operations = [
        migrations.RunSQL(sql=DROP_DETECT_CYCLES),
    ]
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
0003_create_detect_cycles_function
1+
0004_drop_detect_cycles_function
Lines changed: 114 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,62 @@
1-
from django.db import models
1+
from django.db import connection, models, transaction
22

3+
from posthog.models import Team
34
from posthog.models.utils import CreatedMetaFields, UpdatedMetaFields, UUIDModel
45

6+
from .node import Node
7+
8+
DISALLOWED_UPDATE_FIELDS = ("dag_id", "source", "source_id", "target", "target_id", "team", "team_id")
9+
10+
11+
class CycleDetectionError(Exception):
    """Raised when saving an edge would introduce a cycle into a DAG."""
15+
16+
17+
class DAGMismatchError(Exception):
    """Raised when an edge would connect two different DAGs together."""
21+
22+
23+
class DataModelingEdgeQuerySet(models.QuerySet):
    """QuerySet that blocks bulk writes to an edge's structural fields.

    The error messages below direct callers to individual ``save()`` calls so
    that cycle detection runs. The guarded field names are listed in
    ``DISALLOWED_UPDATE_FIELDS``.
    """

    def update(self, **kwargs):
        """Update rows in bulk unless a structural field is being changed.

        Raises:
            NotImplementedError: if any keyword names a field listed in
                ``DISALLOWED_UPDATE_FIELDS``.
        """
        if any(key in kwargs for key in DISALLOWED_UPDATE_FIELDS):
            raise NotImplementedError(
                f"QuerySet.update() is disabled for fields ({DISALLOWED_UPDATE_FIELDS}) to ensure cycle detection. "
                "Use individual save() calls instead."
            )
        return super().update(**kwargs)

    def bulk_create(self, objs, *args, **kwargs):
        """Always refuse: edges cannot be created in bulk.

        Raises:
            NotImplementedError: unconditionally.
        """
        del objs, args, kwargs  # unused
        raise NotImplementedError("bulk_create() is disabled for Edge objects to ensure cycle detection.")

    def bulk_update(self, objs, fields, *args, **kwargs):
        """Bulk-update ``fields`` on ``objs`` unless a structural field is listed.

        The field names to write arrive in ``fields``, not in ``kwargs``, so the
        guard inspects ``fields`` and fails before any other work is done.

        Raises:
            NotImplementedError: if ``fields`` names a field listed in
                ``DISALLOWED_UPDATE_FIELDS``.
        """
        # Materialise once so the same names are both checked and forwarded.
        field_names = list(fields or ())
        if any(name in DISALLOWED_UPDATE_FIELDS for name in field_names):
            raise NotImplementedError(
                f"QuerySet.bulk_update() is disabled for fields ({DISALLOWED_UPDATE_FIELDS}) to ensure cycle detection. "
                "Use individual save() calls instead."
            )
        return super().bulk_update(objs, field_names, *args, **kwargs)
45+
46+
47+
class DataModelingEdgeManager(models.Manager):
    """Manager whose querysets are ``DataModelingEdgeQuerySet`` instances."""

    def get_queryset(self):
        """Return a ``DataModelingEdgeQuerySet`` bound to this manager's model and database."""
        queryset = DataModelingEdgeQuerySet(model=self.model, using=self._db)
        return queryset
50+
551

652
class Edge(UUIDModel, CreatedMetaFields, UpdatedMetaFields):
7-
team = models.ForeignKey("posthog.Team", on_delete=models.CASCADE, editable=False)
53+
objects = DataModelingEdgeManager()
54+
55+
team = models.ForeignKey(Team, on_delete=models.CASCADE, editable=False)
856
# the source node of the edge (i.e. the node this edge is pointed away from)
9-
source = models.ForeignKey(
10-
"data_modeling.Node", related_name="outgoing_edges", on_delete=models.CASCADE, editable=False
11-
)
57+
source = models.ForeignKey(Node, related_name="outgoing_edges", on_delete=models.CASCADE, editable=False)
1258
# the target node of the edge (i.e. the node this edge is pointed toward)
13-
target = models.ForeignKey(
14-
"data_modeling.Node", related_name="incoming_edges", on_delete=models.CASCADE, editable=False
15-
)
59+
target = models.ForeignKey(Node, related_name="incoming_edges", on_delete=models.CASCADE, editable=False)
1660
# the name of the DAG this edge belongs to
1761
dag_id = models.TextField(max_length=256, default="posthog", db_index=True, editable=False)
1862
properties = models.JSONField(default=dict)
@@ -22,3 +66,65 @@ class Meta:
2266
constraints = [
2367
models.UniqueConstraint(fields=["dag_id", "source", "target"], name="unique_within_dag"),
2468
]
69+
70+
def save(self, *args, **kwargs):
    """Validate this edge, then persist it, all in one transaction.

    Cycle detection runs first, then the team/DAG consistency check, and only
    then the actual write. Positional and keyword arguments are forwarded to
    the parent ``save()`` unchanged.

    Raises:
        CycleDetectionError: raised by ``_detect_cycles``.
        DAGMismatchError: raised by ``_detect_dag_mismatch``.
    """
    validators = (self._detect_cycles, self._detect_dag_mismatch)
    with transaction.atomic():
        # Order matters: the cycle check runs before the mismatch check.
        for validate in validators:
            validate()
        super().save(*args, **kwargs)
75+
76+
def _detect_cycles(self):
    """Raise if adding this edge would create a cycle in its DAG.

    Takes a transaction-scoped Postgres advisory lock keyed on the team id and
    a hash of the DAG id before running the checks.

    Raises:
        CycleDetectionError: if source and target are the same node, or if
            ``_creates_cycle()`` reports a cycle.
    """
    lock_sql = "SELECT pg_advisory_xact_lock(%s, hashtext(%s))"
    with connection.cursor() as cursor:
        cursor.execute(lock_sql, [self.team_id, self.dag_id])

    # Trivial case: an edge from a node to itself.
    is_self_loop = self.source_id == self.target_id
    if is_self_loop:
        raise CycleDetectionError(
            f"Self-loop detected: team={self.team_id} dag={self.dag_id} "
            f"source={self.source_id} target={self.target_id}"
        )

    # General case: delegate to the recursive reachability query.
    if not self._creates_cycle():
        return
    raise CycleDetectionError(
        f"Cycle detected: team={self.team_id} dag={self.dag_id} source={self.source_id} target={self.target_id}"
    )
90+
91+
def _creates_cycle(self):
    """Return True if this edge's source is already reachable from its target.

    Runs a recursive CTE over ``posthog_datamodelingedge``, restricted to this
    edge's team and DAG. It starts from the edges leaving ``target_id`` and
    follows outgoing edges; if ``source_id`` shows up among the reachable
    nodes, adding ``source -> target`` would close a cycle.

    All values are passed as bound query parameters rather than formatted into
    the SQL text, so a ``dag_id`` containing quotes cannot alter the statement.

    Returns:
        bool: True when a path from target to source exists, False otherwise.
    """
    sql = """
        WITH RECURSIVE reachable(node_id) AS (
            SELECT e.target_id
            FROM posthog_datamodelingedge e
            WHERE e.source_id = %(target_id)s
              AND e.team_id = %(team_id)s
              AND e.dag_id = %(dag_id)s
            UNION
            SELECT e.target_id
            FROM posthog_datamodelingedge e
            INNER JOIN reachable r
                ON e.source_id = r.node_id
            WHERE e.target_id <> %(target_id)s
              AND e.team_id = %(team_id)s
              AND e.dag_id = %(dag_id)s
        )
        SELECT 1 FROM reachable WHERE node_id = %(source_id)s
    """
    params = {
        "team_id": self.team_id,
        "dag_id": self.dag_id,
        "source_id": self.source_id,
        "target_id": self.target_id,
    }
    with connection.cursor() as cursor:
        cursor.execute(sql, params)
        # Any returned row means the source was found among reachable nodes.
        return cursor.fetchone() is not None
115+
116+
def _detect_dag_mismatch(self):
    """Raise if the edge's team or DAG differs from that of either endpoint node.

    Loads the source and target ``Node`` rows by id, then compares their
    ``team_id`` and ``dag_id`` with this edge's values. The team check runs
    before the DAG check.

    Raises:
        DAGMismatchError: if either node's ``team_id`` or ``dag_id`` does not
            equal the edge's.
    """
    source_node = Node.objects.get(id=self.source_id)
    target_node = Node.objects.get(id=self.target_id)

    teams_match = source_node.team_id == self.team_id and target_node.team_id == self.team_id
    if not teams_match:
        raise DAGMismatchError(
            f"Edge team_id ({self.team_id}) does not match "
            f"source node team_id ({source_node.team_id}) or "
            f"target node team_id ({target_node.team_id})"
        )

    dags_match = source_node.dag_id == self.dag_id and target_node.dag_id == self.dag_id
    if not dags_match:
        raise DAGMismatchError(
            f"Edge dag_id ({self.dag_id}) does not match "
            f"source node dag_id ({source_node.dag_id}) or "
            f"target node dag_id ({target_node.dag_id})"
        )
Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
from freezegun import freeze_time
2+
from posthog.test.base import BaseTest
3+
4+
from parameterized import parameterized
5+
6+
from products.data_modeling.backend.models import CycleDetectionError, Edge, Node
7+
from products.data_warehouse.backend.models import DataWarehouseSavedQuery
8+
9+
LINKED_LIST_DAG_ID = "linked_list"
10+
BALANCED_TREE_DAG_ID = "balanced_tree"
11+
12+
13+
def _basic_saved_query_with_label(label: str):
14+
return f"""SELECT '{label}'"""
15+
16+
17+
class LinkedListCycleDetectionTest(BaseTest):
    """Cycle detection on a 25-node chain ``ll_0 -> ll_1 -> ... -> ll_24``."""

    @classmethod
    def setUpTestData(cls):
        """Create 25 saved queries, one node per query, and an edge between each consecutive pair."""
        super().setUpTestData()
        with freeze_time("2025-01-01T12:00:00.000Z"):
            ll_queries = [
                DataWarehouseSavedQuery.objects.create(
                    name=f"ll_{i}",
                    team=cls.team,
                    query=_basic_saved_query_with_label(str(i)),
                )
                for i in range(25)
            ]
            ll_nodes = [
                Node.objects.create(team=cls.team, dag_id=LINKED_LIST_DAG_ID, saved_query=query, name=f"ll_{i}")
                for i, query in enumerate(ll_queries)
            ]
            # Link each node to its successor.
            for source, target in zip(ll_nodes, ll_nodes[1:]):
                Edge.objects.create(
                    team=cls.team,
                    dag_id=LINKED_LIST_DAG_ID,
                    source=source,
                    target=target,
                )

    @parameterized.expand(
        [
            # low index to high index is always fine
            ("0", "24", False),
            ("0", "5", False),
            ("0", "2", False),  # note 0, 1 would be a duplicate edge and would fail
            # self cycle case
            ("0", "0", True),
            # high index to low always causes a cycle
            ("1", "0", True),
            ("5", "0", True),
            ("24", "0", True),
        ],
    )
    def test_linked_list_dag(self, source_label, target_label, should_raise):
        """Adding ``source -> target`` raises ``CycleDetectionError`` exactly when expected."""
        source = Node.objects.get(saved_query__name=f"ll_{source_label}")
        target = Node.objects.get(saved_query__name=f"ll_{target_label}")
        if should_raise:
            # Assert the specific error so an unrelated failure cannot make this pass.
            with self.assertRaises(CycleDetectionError):
                Edge.objects.create(team=source.team, dag_id=LINKED_LIST_DAG_ID, source=source, target=target)
        else:
            edge = Edge.objects.create(team=source.team, dag_id=LINKED_LIST_DAG_ID, source=source, target=target)
            edge.delete()
65+
66+
67+
class TreeCycleDetectionTest(BaseTest):
    """Cycle detection on a tree: one root, 5 children, 5 grandchildren per child."""

    @classmethod
    def setUpTestData(cls):
        """Create the saved queries, then the nodes, then the root->child->grandchild edges."""
        super().setUpTestData()
        with freeze_time("2025-01-01T12:00:00.000Z"):
            # Every query is named "bt_<label>" and selects "<label>".
            labels = ["root"]
            labels += [f"child_{i}" for i in range(5)]
            labels += [f"child_{i}_child_{j}" for i in range(5) for j in range(5)]

            saved_queries = [
                DataWarehouseSavedQuery.objects.create(
                    name=f"bt_{label}",
                    team=cls.team,
                    query=_basic_saved_query_with_label(label),
                )
                for label in labels
            ]
            tree_nodes = [
                Node.objects.create(team=cls.team, dag_id=BALANCED_TREE_DAG_ID, saved_query=query, name=query.name)
                for query in saved_queries
            ]

            root_node, *descendants = tree_nodes
            child_nodes = descendants[:5]
            grandchild_nodes = descendants[5:]
            for index, child in enumerate(child_nodes):
                Edge.objects.create(team=cls.team, dag_id=BALANCED_TREE_DAG_ID, source=root_node, target=child)
                # Grandchildren are stored in groups of five, one group per child.
                for grandchild in grandchild_nodes[index * 5 : (index + 1) * 5]:
                    Edge.objects.create(
                        team=cls.team,
                        dag_id=BALANCED_TREE_DAG_ID,
                        source=child,
                        target=grandchild,
                    )

    @parameterized.expand(
        [
            # root to any grandchild is always fine
            ("root", "child_0_child_0", False),
            ("root", "child_2_child_0", False),
            ("root", "child_4_child_0", False),
            # child to any grandchild is always fine
            ("child_0", "child_1_child_0", False),
            ("child_0", "child_2_child_0", False),
            ("child_0", "child_4_child_0", False),
            # self cycle case
            ("root", "root", True),
            # child to root always causes a cycle
            ("child_0", "root", True),
            ("child_2", "root", True),
            ("child_4", "root", True),
            # grandchild to root always causes a cycle
            ("child_0_child_0", "root", True),
            ("child_2_child_0", "root", True),
            ("child_4_child_0", "root", True),
            # grandchild to any child not its parent is always fine (i.e. root -> 0 -> 0_0 -> 1 + root -> 1 = no cycle)
            ("child_1_child_0", "child_0", False),
            ("child_2_child_0", "child_0", False),
            ("child_4_child_0", "child_0", False),
            # any grandchild to its parent always causes a cycle
            ("child_0_child_0", "child_0", True),
            ("child_0_child_2", "child_0", True),
            ("child_0_child_4", "child_0", True),
            # any child to any other child is always fine
            ("child_0", "child_1", False),
            ("child_0", "child_2", False),
            ("child_0", "child_4", False),
            # any grandchild to any other grandchild is always fine
            ("child_0_child_0", "child_1_child_1", False),
            ("child_0_child_0", "child_2_child_2", False),
            ("child_0_child_0", "child_4_child_4", False),
        ],
    )
    def test_tree_like_dag(self, source_label, target_label, should_raise):
        """Adding ``source -> target`` raises ``CycleDetectionError`` exactly when expected."""
        start = Node.objects.get(saved_query__name=f"bt_{source_label}")
        end = Node.objects.get(saved_query__name=f"bt_{target_label}")

        def add_edge():
            return Edge.objects.create(team=start.team, dag_id=BALANCED_TREE_DAG_ID, source=start, target=end)

        if not should_raise:
            # Remove the edge again so the fixture stays as set up.
            add_edge().delete()
            return
        with self.assertRaises(CycleDetectionError):
            add_edge()

    def test_disallowed_object_functions(self):
        """``update``, ``bulk_update`` and ``bulk_create`` raise ``NotImplementedError`` for guarded fields."""
        team = self.team
        node = Node.objects.get(name="bt_root")
        tree_edges = Edge.objects.filter(dag_id=BALANCED_TREE_DAG_ID)
        disallowed = ("dag_id", "source", "source_id", "target", "target_id", "team", "team_id")

        # Value assigned to each field: "test" for *id keys, otherwise a node or the team.
        replacement = {
            key: "test" if key.endswith("id") else (node if key in ("source", "target") else team)
            for key in disallowed
        }

        for key in disallowed:
            value = replacement[key]
            # test update disallowed for each key
            with self.assertRaises(NotImplementedError):
                tree_edges.update(**{key: value})
            # test bulk_update disallowed for each key
            unsaved_edges = [Edge(source=node, target=node, team=team, dag_id="test") for _ in range(3)]
            for unsaved in unsaved_edges:
                setattr(unsaved, key, value)
            with self.assertRaises(NotImplementedError):
                Edge.objects.bulk_update(unsaved_edges, [key])

        # test bulk_create disallowed
        with self.assertRaises(NotImplementedError):
            Edge.objects.bulk_create(tree_edges)

0 commit comments

Comments
 (0)