Skip to content

Commit 3bce7a6

Browse files
committed
Avoid revisit table
1 parent 25d368a commit 3bce7a6

File tree

4 files changed

+70
-26
lines changed

4 files changed

+70
-26
lines changed

datajoint/table.py

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -486,7 +486,7 @@ def delete(
486486
transaction: bool = True,
487487
safemode: Union[bool, None] = None,
488488
force_parts: bool = False,
489-
include_master: bool = True,
489+
include_parts: bool = True,
490490
) -> int:
491491
"""
492492
Deletes the contents of the table and its dependent tables, recursively.
@@ -498,7 +498,8 @@ def delete(
498498
safemode: If `True`, prohibit nested transactions and prompt to confirm. Default
499499
is `dj.config['safemode']`.
500500
force_parts: Delete from parts even when not deleting from their masters.
501-
include_master: If `True`, delete from the master table as well. Default is `True`.
501+
include_parts: If `True`, include part/master pairs in the cascade.
502+
Default is `True`.
502503
503504
Returns:
504505
Number of deleted rows (excluding those from dependent tables).
@@ -509,6 +510,7 @@ def delete(
509510
DataJointError: Deleting a part table before its master.
510511
"""
511512
deleted = set()
513+
visited_masters = set()
512514

513515
def cascade(table):
514516
"""service function to perform cascading deletes recursively."""
@@ -568,25 +570,30 @@ def cascade(table):
568570
else:
569571
child &= table.proj()
570572

571-
master = get_master(child.full_table_name)
572-
if include_master and master and master not in deleted:
573-
master_table = FreeTable(table.connection, master)
574-
master_table._restriction = [
575-
make_condition(
576-
master_table,
577-
(master_table & child).proj().fetch(),
578-
set(),
573+
master_name = get_master(child.full_table_name)
574+
if (
575+
include_parts
576+
and master_name
577+
and master_name != table.full_table_name
578+
and master_name not in visited_masters
579+
):
580+
master = FreeTable(table.connection, master_name)
581+
master._restriction_attributes = set()
582+
master._restriction = [
583+
make_condition( # &= may place the target table in a subquery
584+
master,
585+
(master.proj() & child.proj()).fetch(),
586+
master._restriction_attributes,
579587
)
580588
]
581-
582-
cascade(child)
583-
584-
if include_master and master and master not in deleted:
585-
cascade(master_table)
589+
visited_masters.add(master_name)
590+
cascade(master)
591+
else:
592+
cascade(child)
586593
else:
587594
deleted.add(table.full_table_name)
588595
logger.info(
589-
"Deleting {count} rows from {table}".format(
596+
"Deleting: {count} rows from {table}".format(
590597
count=delete_count, table=table.full_table_name
591598
)
592599
)

tests/conftest.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,7 @@ def schema_simp(connection_test, prefix):
317317
schema(schema_simple.E)
318318
schema(schema_simple.F)
319319
schema(schema_simple.F)
320+
schema(schema_simple.G)
320321
schema(schema_simple.DataA)
321322
schema(schema_simple.DataB)
322323
schema(schema_simple.Website)

tests/schema_simple.py

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -111,17 +111,36 @@ class F(dj.Part):
111111
-> B.C
112112
"""
113113

114+
class G(dj.Part):
115+
definition = """ # test secondary fk reference
116+
-> E
117+
id_g :int
118+
---
119+
-> L
120+
"""
121+
122+
class H(dj.Part):
123+
definition = """ # test no additional fk reference
124+
-> E
125+
id_h :int
126+
"""
127+
114128
def make(self, key):
115129
random.seed(str(key))
116-
self.insert1(dict(key, **random.choice(list(L().fetch("KEY")))))
117-
sub = E.F()
118-
references = list((B.C() & key).fetch("KEY"))
119-
random.shuffle(references)
120-
sub.insert(
130+
l_contents = list(L().fetch("KEY"))
131+
part_f, part_g, part_h = E.F(), E.G(), E.H()
132+
bc_references = list((B.C() & key).fetch("KEY"))
133+
random.shuffle(bc_references)
134+
135+
self.insert1(dict(key, **random.choice(l_contents)))
136+
part_f.insert(
121137
dict(key, id_f=i, **ref)
122-
for i, ref in enumerate(references)
138+
for i, ref in enumerate(bc_references)
123139
if random.getrandbits(1)
124140
)
141+
g_inserts = [dict(key, id_g=i, **ref) for i, ref in enumerate(l_contents)]
142+
part_g.insert(g_inserts)
143+
part_h.insert(dict(key, id_h=i) for i in range(4))
125144

126145

127146
class F(dj.Manual):
@@ -132,6 +151,15 @@ class F(dj.Manual):
132151
"""
133152

134153

154+
class G(dj.Computed):
155+
definition = """ # test downstream of complex master/parts
156+
-> E
157+
"""
158+
159+
def make(self, key):
160+
self.insert1(key)
161+
162+
135163
class DataA(dj.Lookup):
136164
definition = """
137165
idx : int

tests/test_cascading_delete.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import pytest
22
import datajoint as dj
3-
from .schema_simple import A, B, D, E, L, Website, Profile
3+
from .schema_simple import A, B, D, E, G, L, Website, Profile
44
from .schema import ComplexChild, ComplexParent
55

66

@@ -11,6 +11,7 @@ def schema_simp_pop(schema_simp):
1111
B().populate()
1212
D().populate()
1313
E().populate()
14+
G().populate()
1415
yield schema_simp
1516

1617

@@ -96,7 +97,7 @@ def test_delete_complex_keys(schema_any):
9697
**{
9798
"child_id_{}".format(i + 1): (i + parent_key_count)
9899
for i in range(child_key_count)
99-
}
100+
},
100101
)
101102
assert len(ComplexParent & restriction) == 1, "Parent record missing"
102103
assert len(ComplexChild & restriction) == 1, "Child record missing"
@@ -114,13 +115,20 @@ def test_delete_parts_error(schema_simp_pop):
114115
"""test issue #151"""
115116
with pytest.raises(dj.DataJointError):
116117
Profile().populate_random()
117-
Website().delete(include_master=False)
118+
Website().delete(include_parts=False)
118119

119120

120121
def test_delete_parts(schema_simp_pop):
121122
"""test issue #151"""
122123
Profile().populate_random()
123-
Website().delete(include_master=True)
124+
Website().delete(include_parts=True)
125+
126+
127+
def test_delete_parts_complex(schema_simp_pop):
128+
"""test issue #151 with complex master/part. PR #1158."""
129+
prev_len = len(G())
130+
(A() & "id_a=1").delete()
131+
assert prev_len - len(G()) == 16, "Failed to delete parts"
124132

125133

126134
def test_drop_part(schema_simp_pop):

0 commit comments

Comments
 (0)