Skip to content
Merged
2 changes: 1 addition & 1 deletion datajoint/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -847,7 +847,7 @@ class U:
>>> dj.U().aggr(expr, n='count(*)')

The following expressions both yield one element containing the number `n` of distinct values of attribute `attr` in
query expressio `expr`.
query expression `expr`.

>>> dj.U().aggr(expr, n='count(distinct attr)')
>>> dj.U().aggr(dj.U('attr').aggr(expr), 'n=count(*)')
Expand Down
28 changes: 26 additions & 2 deletions datajoint/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,7 @@ def delete(
transaction: bool = True,
safemode: Union[bool, None] = None,
force_parts: bool = False,
include_parts: bool = True,
) -> int:
"""
Deletes the contents of the table and its dependent tables, recursively.
Expand All @@ -497,6 +498,8 @@ def delete(
safemode: If `True`, prohibit nested transactions and prompt to confirm. Default
is `dj.config['safemode']`.
force_parts: Delete from parts even when not deleting from their masters.
include_parts: If `True`, include part/master pairs in the cascade.
Default is `True`.

Returns:
Number of deleted rows (excluding those from dependent tables).
Expand All @@ -507,6 +510,7 @@ def delete(
DataJointError: Deleting a part table before its master.
"""
deleted = set()
visited_masters = set()

def cascade(table):
"""service function to perform cascading deletes recursively."""
Expand Down Expand Up @@ -565,11 +569,31 @@ def cascade(table):
)
else:
child &= table.proj()
cascade(child)

master_name = get_master(child.full_table_name)
if (
include_parts
and master_name
and master_name != table.full_table_name
and master_name not in visited_masters
):
master = FreeTable(table.connection, master_name)
master._restriction_attributes = set()
master._restriction = [
make_condition(  # fetch keys first: &= may cause the target tables to appear in a subquery
master,
(master.proj() & child.proj()).fetch(),
master._restriction_attributes,
)
]
visited_masters.add(master_name)
cascade(master)
else:
cascade(child)
else:
deleted.add(table.full_table_name)
logger.info(
"Deleting {count} rows from {table}".format(
"Deleting: {count} rows from {table}".format(
count=delete_count, table=table.full_table_name
)
)
Expand Down
1 change: 1 addition & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,7 @@ def schema_simp(connection_test, prefix):
schema(schema_simple.E)
schema(schema_simple.F)
schema(schema_simple.F)
schema(schema_simple.G)
schema(schema_simple.DataA)
schema(schema_simple.DataB)
schema(schema_simple.Website)
Expand Down
40 changes: 34 additions & 6 deletions tests/schema_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,17 +111,36 @@ class F(dj.Part):
-> B.C
"""

class G(dj.Part):
    # Part table exercising a "secondary fk reference": the primary key is the
    # reference to E plus id_g, and the reference to L is declared below the
    # `---` divider.
    definition = """ # test secondary fk reference
    -> E
    id_g :int
    ---
    -> L
    """

class H(dj.Part):
    # Part table whose only foreign key is the reference to E; id_h is the
    # sole additional attribute.
    definition = """ # test no additional fk reference
    -> E
    id_h :int
    """

def make(self, key):
random.seed(str(key))
self.insert1(dict(key, **random.choice(list(L().fetch("KEY")))))
sub = E.F()
references = list((B.C() & key).fetch("KEY"))
random.shuffle(references)
sub.insert(
l_contents = list(L().fetch("KEY"))
part_f, part_g, part_h = E.F(), E.G(), E.H()
bc_references = list((B.C() & key).fetch("KEY"))
random.shuffle(bc_references)

self.insert1(dict(key, **random.choice(l_contents)))
part_f.insert(
dict(key, id_f=i, **ref)
for i, ref in enumerate(references)
for i, ref in enumerate(bc_references)
if random.getrandbits(1)
)
g_inserts = [dict(key, id_g=i, **ref) for i, ref in enumerate(l_contents)]
part_g.insert(g_inserts)
part_h.insert(dict(key, id_h=i) for i in range(4))


class F(dj.Manual):
Expand All @@ -132,6 +151,15 @@ class F(dj.Manual):
"""


class G(dj.Computed):
    # Computed table whose definition consists solely of a reference to E.
    definition = """ # test downstream of complex master/parts
    -> E
    """

    def make(self, key):
        """Insert `key` unchanged; the definition declares nothing beyond the E reference."""
        self.insert1(key)


class DataA(dj.Lookup):
definition = """
idx : int
Expand Down
22 changes: 18 additions & 4 deletions tests/test_cascading_delete.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest
import datajoint as dj
from .schema_simple import A, B, D, E, L, Website, Profile
from .schema_simple import A, B, D, E, G, L, Website, Profile
from .schema import ComplexChild, ComplexParent


Expand All @@ -11,6 +11,7 @@ def schema_simp_pop(schema_simp):
B().populate()
D().populate()
E().populate()
G().populate()
yield schema_simp


Expand Down Expand Up @@ -96,7 +97,7 @@ def test_delete_complex_keys(schema_any):
**{
"child_id_{}".format(i + 1): (i + parent_key_count)
for i in range(child_key_count)
}
},
)
assert len(ComplexParent & restriction) == 1, "Parent record missing"
assert len(ComplexChild & restriction) == 1, "Child record missing"
Expand All @@ -110,11 +111,24 @@ def test_delete_master(schema_simp_pop):
Profile().delete()


def test_delete_parts(schema_simp_pop):
def test_delete_parts_error(schema_simp_pop):
"""test issue #151"""
with pytest.raises(dj.DataJointError):
Profile().populate_random()
Website().delete()
Website().delete(include_parts=False)


def test_delete_parts(schema_simp_pop):
    """Issue #151: `Website().delete(include_parts=True)` must complete without raising."""
    profile = Profile()
    profile.populate_random()
    website = Website()
    website.delete(include_parts=True)


def test_delete_parts_complex(schema_simp_pop):
    """Issue #151 / PR #1158: deleting one A entry must also remove rows from G."""
    rows_before = len(G())
    restricted_a = A() & "id_a=1"
    restricted_a.delete()
    rows_removed = rows_before - len(G())
    assert rows_removed == 16, "Failed to delete parts"


def test_drop_part(schema_simp_pop):
Expand Down
12 changes: 8 additions & 4 deletions tests/test_erd.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import datajoint as dj
from .schema_simple import LOCALS_SIMPLE, A, B, D, E, L, OutfitLaunch
from .schema_simple import LOCALS_SIMPLE, A, B, D, E, G, L, OutfitLaunch
from .schema_advanced import *


Expand All @@ -20,7 +20,7 @@ def test_dependencies(schema_simp):
assert set(D().parents(primary=True)) == set([A.full_table_name])
assert set(D().parents(primary=False)) == set([L.full_table_name])
assert set(deps.descendants(L.full_table_name)).issubset(
cls.full_table_name for cls in (L, D, E, E.F)
cls.full_table_name for cls in (L, D, E, E.F, E.G, E.H, G)
)


Expand All @@ -38,10 +38,14 @@ def test_erd_algebra(schema_simp):
erd3 = erd1 * erd2
erd4 = (erd0 + E).add_parts() - B - E
assert erd0.nodes_to_show == set(cls.full_table_name for cls in [B])
assert erd1.nodes_to_show == set(cls.full_table_name for cls in (B, B.C, E, E.F))
assert erd1.nodes_to_show == set(
cls.full_table_name for cls in (B, B.C, E, E.F, E.G, E.H, G)
)
assert erd2.nodes_to_show == set(cls.full_table_name for cls in (A, B, D, E, L))
assert erd3.nodes_to_show == set(cls.full_table_name for cls in (B, E))
assert erd4.nodes_to_show == set(cls.full_table_name for cls in (B.C, E.F))
assert erd4.nodes_to_show == set(
cls.full_table_name for cls in (B.C, E.F, E.G, E.H)
)


def test_repr_svg(schema_adv):
Expand Down
9 changes: 7 additions & 2 deletions tests/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def test_list_tables(schema_simp):
"""
https://github.com/datajoint/datajoint-python/issues/838
"""
assert set(
expected = set(
[
"reserved_word",
"#l",
Expand All @@ -194,6 +194,9 @@ def test_list_tables(schema_simp):
"__b__c",
"__e",
"__e__f",
"__e__g",
"__e__h",
"__g",
"#outfit_launch",
"#outfit_launch__outfit_piece",
"#i_j",
Expand All @@ -207,7 +210,9 @@ def test_list_tables(schema_simp):
"profile",
"profile__website",
]
) == set(schema_simp.list_tables())
)
actual = set(schema_simp.list_tables())
assert actual == expected, f"Missing from list_tables(): {expected - actual}"


def test_schema_save_any(schema_any):
Expand Down