Commit 5351980

DL: remove data from DB on unsubscribe (#16786)

Authored by fchirica and altendky
Co-authored-by: Kyle Altendorf <[email protected]>
1 parent: fab8b21

File tree: 4 files changed (+282 / -5 lines)

chia/data_layer/data_layer.py

Lines changed: 5 additions & 2 deletions
@@ -647,12 +647,12 @@ async def remove_subscriptions(self, store_id: bytes32, urls: List[str]) -> None:
         async with self.subscription_lock:
             await self.data_store.remove_subscriptions(store_id, parsed_urls)
 
-    async def unsubscribe(self, tree_id: bytes32, retain_files: bool) -> None:
+    async def unsubscribe(self, tree_id: bytes32, retain_data: bool) -> None:
         subscriptions = await self.get_subscriptions()
         if tree_id not in (subscription.tree_id for subscription in subscriptions):
             raise RuntimeError("No subscription found for the given tree_id.")
         filenames: List[str] = []
-        if await self.data_store.tree_id_exists(tree_id) and not retain_files:
+        if await self.data_store.tree_id_exists(tree_id) and not retain_data:
             generation = await self.data_store.get_tree_generation(tree_id)
             all_roots = await self.data_store.get_roots_between(tree_id, 1, generation + 1)
             for root in all_roots:
@@ -663,6 +663,9 @@ async def unsubscribe(self, tree_id: bytes32, retain_files: bool) -> None:
         await self.wallet_rpc.dl_stop_tracking(tree_id)
         async with self.subscription_lock:
             await self.data_store.unsubscribe(tree_id)
+            if not retain_data:
+                await self.data_store.delete_store_data(tree_id)
+
         self.log.info(f"Unsubscribed to {tree_id}")
         for filename in filenames:
             file_path = self.server_files_location.joinpath(filename)
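For orientation, here is a minimal runnable sketch of the flow these two hunks produce. FakeDataStore and its fields are invented stand-ins for the SQLite-backed DataStore; only the ordering and the retain_data flag come from the diff:

import asyncio


class FakeDataStore:
    """Stand-in for the real SQLite-backed DataStore; fields are invented."""

    def __init__(self) -> None:
        self.subscribed = True
        self.rows = {"key/value rows for the store"}

    async def unsubscribe(self) -> None:
        self.subscribed = False

    async def delete_store_data(self) -> None:
        self.rows.clear()


async def unsubscribe(store: FakeDataStore, retain_data: bool) -> None:
    # Same ordering as the diff: drop the subscription first, then
    # (unless the caller opted out) remove the store's data from the DB.
    await store.unsubscribe()
    if not retain_data:
        await store.delete_store_data()


async def main() -> None:
    store = FakeDataStore()
    await unsubscribe(store, retain_data=False)
    print(store.subscribed, store.rows)  # -> False set()


asyncio.run(main())

With retain_data=True the subscription is dropped but the rows survive, matching the previous behavior of the retain flag.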

chia/data_layer/data_store.py

Lines changed: 93 additions & 1 deletion
@@ -1073,13 +1073,35 @@ async def delete(
 
         return new_root
 
+    async def clean_node_table(self, writer: aiosqlite.Connection) -> None:
+        await writer.execute(
+            """
+            WITH RECURSIVE pending_nodes AS (
+                SELECT node_hash AS hash FROM root
+                WHERE status = ?
+                UNION ALL
+                SELECT n.left FROM node n
+                INNER JOIN pending_nodes pn ON n.hash = pn.hash
+                WHERE n.left IS NOT NULL
+                UNION ALL
+                SELECT n.right FROM node n
+                INNER JOIN pending_nodes pn ON n.hash = pn.hash
+                WHERE n.right IS NOT NULL
+            )
+            DELETE FROM node
+            WHERE hash NOT IN (SELECT hash FROM ancestors)
+            AND hash NOT IN (SELECT hash FROM pending_nodes)
+            """,
+            (Status.PENDING.value,),
+        )
+
     async def insert_batch(
         self,
         tree_id: bytes32,
         changelist: List[Dict[str, Any]],
         status: Status = Status.PENDING,
     ) -> Optional[bytes32]:
-        async with self.db_wrapper.writer():
+        async with self.db_wrapper.writer() as writer:
             old_root = await self.get_tree_root(tree_id)
             root_hash = old_root.node_hash
             if old_root.node_hash is None:
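The clean_node_table query walks every tree hanging off a pending (not yet committed) root and deletes only those node rows that are reachable from neither the ancestors table nor a pending root. The toy script below reproduces the same CTE on a pared-down schema; table and column names match the diff, everything else (the data, the PENDING constant) is invented. Note the recursive CTE uses two recursive terms, which requires the Python sqlite3 module to be built against SQLite 3.34 or newer:

import sqlite3

PENDING = 1  # stand-in for Status.PENDING.value

db = sqlite3.connect(":memory:")
db.executescript(
    """
    CREATE TABLE node (hash TEXT PRIMARY KEY, left TEXT, right TEXT);
    CREATE TABLE root (node_hash TEXT, status INTEGER);
    CREATE TABLE ancestors (hash TEXT);
    -- a pending root 'r' with children 'a' and 'b', plus an orphan 'x'
    INSERT INTO node VALUES ('r', 'a', 'b'), ('a', NULL, NULL),
                            ('b', NULL, NULL), ('x', NULL, NULL);
    INSERT INTO root VALUES ('r', 1);
    """
)
db.execute(
    """
    WITH RECURSIVE pending_nodes AS (
        SELECT node_hash AS hash FROM root
        WHERE status = ?
        UNION ALL
        SELECT n.left FROM node n
        INNER JOIN pending_nodes pn ON n.hash = pn.hash
        WHERE n.left IS NOT NULL
        UNION ALL
        SELECT n.right FROM node n
        INNER JOIN pending_nodes pn ON n.hash = pn.hash
        WHERE n.right IS NOT NULL
    )
    DELETE FROM node
    WHERE hash NOT IN (SELECT hash FROM ancestors)
    AND hash NOT IN (SELECT hash FROM pending_nodes)
    """,
    (PENDING,),
)
print([row[0] for row in db.execute("SELECT hash FROM node ORDER BY hash")])
# -> ['a', 'b', 'r']: only the orphan 'x' is removed; the pending tree survives

The same protection is what the next hunk relies on when it calls clean_node_table at the end of insert_batch.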
@@ -1146,6 +1168,8 @@ async def insert_batch(
                     "Didn't get the expected generation after batch update: "
                     f"Expected: {old_root.generation + 1}. Got: {new_root.generation}"
                 )
+
+            await self.clean_node_table(writer)
             return root.node_hash
 
     async def _get_one_ancestor(
@@ -1443,6 +1467,74 @@ async def remove_subscriptions(self, tree_id: bytes32, urls: List[str]) -> None:
                 },
             )
 
+    async def delete_store_data(self, tree_id: bytes32) -> None:
+        async with self.db_wrapper.writer() as writer:
+            await self.clean_node_table(writer)
+            cursor = await writer.execute(
+                """
+                WITH RECURSIVE all_nodes AS (
+                    SELECT a.hash, n.left, n.right
+                    FROM ancestors AS a
+                    JOIN node AS n ON a.hash = n.hash
+                    WHERE a.tree_id = :tree_id
+                ),
+                pending_nodes AS (
+                    SELECT node_hash AS hash FROM root
+                    WHERE status = :status
+                    UNION ALL
+                    SELECT n.left FROM node n
+                    INNER JOIN pending_nodes pn ON n.hash = pn.hash
+                    WHERE n.left IS NOT NULL
+                    UNION ALL
+                    SELECT n.right FROM node n
+                    INNER JOIN pending_nodes pn ON n.hash = pn.hash
+                    WHERE n.right IS NOT NULL
+                )
+
+                SELECT hash, left, right
+                FROM all_nodes
+                WHERE hash NOT IN (SELECT hash FROM ancestors WHERE tree_id != :tree_id)
+                AND hash NOT IN (SELECT hash from pending_nodes)
+                """,
+                {"tree_id": tree_id, "status": Status.PENDING.value},
+            )
+            to_delete: Dict[bytes, Tuple[bytes, bytes]] = {}
+            ref_counts: Dict[bytes, int] = {}
+            async for row in cursor:
+                hash = row["hash"]
+                left = row["left"]
+                right = row["right"]
+                if hash in to_delete:
+                    prev_left, prev_right = to_delete[hash]
+                    assert prev_left == left
+                    assert prev_right == right
+                    continue
+                to_delete[hash] = (left, right)
+                if left is not None:
+                    ref_counts[left] = ref_counts.get(left, 0) + 1
+                if right is not None:
+                    ref_counts[right] = ref_counts.get(right, 0) + 1
+
+            await writer.execute("DELETE FROM ancestors WHERE tree_id == ?", (tree_id,))
+            await writer.execute("DELETE FROM root WHERE tree_id == ?", (tree_id,))
+            queue = [hash for hash in to_delete if ref_counts.get(hash, 0) == 0]
+            while queue:
+                hash = queue.pop(0)
+                if hash not in to_delete:
+                    continue
+                await writer.execute("DELETE FROM node WHERE hash == ?", (hash,))
+
+                left, right = to_delete[hash]
+                if left is not None:
+                    ref_counts[left] -= 1
+                    if ref_counts[left] == 0:
+                        queue.append(left)
+
+                if right is not None:
+                    ref_counts[right] -= 1
+                    if ref_counts[right] == 0:
+                        queue.append(right)
+
     async def unsubscribe(self, tree_id: bytes32) -> None:
         async with self.db_wrapper.writer() as writer:
             await writer.execute(
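The second half of delete_store_data is a reference-counted cascade: a candidate node is deleted breadth-first, and a shared child is only freed once no surviving parent still points at it, which is what protects internal nodes shared between stores. Here is the same algorithm in isolation on an invented five-node forest (names are illustrative; collections.deque replaces the diff's list.pop(0) for O(1) pops):

from collections import deque
from typing import Dict, Optional, Tuple

# hash -> (left child, right child); two roots share the child "c".
to_delete: Dict[str, Tuple[Optional[str], Optional[str]]] = {
    "r1": ("a", "c"),
    "r2": ("c", "b"),
    "a": (None, None),
    "b": (None, None),
    "c": (None, None),
}

# Count how many parents inside the candidate set point at each child.
ref_counts: Dict[str, int] = {}
for left, right in to_delete.values():
    for child in (left, right):
        if child is not None:
            ref_counts[child] = ref_counts.get(child, 0) + 1

deleted = []
# Seed with unreferenced roots, then release children as parents go away.
queue = deque(h for h in to_delete if ref_counts.get(h, 0) == 0)
while queue:
    node_hash = queue.popleft()
    deleted.append(node_hash)  # here the real code issues DELETE FROM node
    for child in to_delete[node_hash]:
        if child is not None:
            ref_counts[child] -= 1
            if ref_counts[child] == 0:
                queue.append(child)

print(deleted)  # -> ['r1', 'r2', 'a', 'c', 'b']; 'c' waits for both parents

A node still referenced from another store never enters to_delete at all (the SELECT excludes hashes present in other stores' ancestors rows), so the cascade stops at the shared boundary.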

chia/rpc/data_layer_rpc_api.py

Lines changed: 2 additions & 2 deletions
@@ -286,13 +286,13 @@ async def unsubscribe(self, request: Dict[str, Any]) -> EndpointResult:
         unsubscribe from singleton
         """
         store_id = request.get("id")
-        retain_files = request.get("retain", False)
+        retain_data = request.get("retain", False)
         if store_id is None:
             raise Exception("missing store id in request")
         if self.service is None:
             raise Exception("Data layer not created")
         store_id_bytes = bytes32.from_hexstr(store_id)
-        await self.service.unsubscribe(store_id_bytes, retain_files)
+        await self.service.unsubscribe(store_id_bytes, retain_data)
         return {}
 
     async def subscriptions(self, request: Dict[str, Any]) -> EndpointResult:
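Only the local variable name changed here; the wire format is untouched. The endpoint still reads an "id" (hex-encoded store id) and an optional "retain" flag. A standalone mirror of the handler's parsing, with bytes.fromhex standing in for chia's bytes32.from_hexstr and a placeholder store id:

# Standalone mirror of the handler's request parsing; the store id is
# a placeholder, not a real singleton.
request = {"id": "00" * 32, "retain": False}

store_id = request.get("id")
retain_data = request.get("retain", False)  # False -> data is deleted from the DB
if store_id is None:
    raise Exception("missing store id in request")
store_id_bytes = bytes.fromhex(store_id)
assert len(store_id_bytes) == 32
print(store_id_bytes.hex(), retain_data)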

tests/core/data_layer/test_data_store.py

Lines changed: 182 additions & 0 deletions
@@ -1428,3 +1428,185 @@ async def test_benchmark_batch_insert_speed(
             tree_id=tree_id,
             changelist=batch,
         )
+
+
+@pytest.mark.anyio
+async def test_delete_store_data(raw_data_store: DataStore) -> None:
+    tree_id = bytes32(b"\0" * 32)
+    tree_id_2 = bytes32(b"\0" * 31 + b"\1")
+    await raw_data_store.create_tree(tree_id=tree_id, status=Status.COMMITTED)
+    await raw_data_store.create_tree(tree_id=tree_id_2, status=Status.COMMITTED)
+    total_keys = 4
+    keys = [key.to_bytes(4, byteorder="big") for key in range(total_keys)]
+    batch1 = [
+        {"action": "insert", "key": keys[0], "value": keys[0]},
+        {"action": "insert", "key": keys[1], "value": keys[1]},
+    ]
+    batch2 = batch1.copy()
+    batch1.append({"action": "insert", "key": keys[2], "value": keys[2]})
+    batch2.append({"action": "insert", "key": keys[3], "value": keys[3]})
+    assert batch1 != batch2
+    await raw_data_store.insert_batch(tree_id, batch1, status=Status.COMMITTED)
+    await raw_data_store.insert_batch(tree_id_2, batch2, status=Status.COMMITTED)
+    keys_values_before = await raw_data_store.get_keys_values(tree_id_2)
+    async with raw_data_store.db_wrapper.reader() as reader:
+        result = await reader.execute("SELECT * FROM node")
+        nodes = await result.fetchall()
+    kv_nodes_before = {}
+    for node in nodes:
+        if node["key"] is not None:
+            kv_nodes_before[node["key"]] = node["value"]
+    assert [kv_nodes_before[key] for key in keys] == keys
+    await raw_data_store.delete_store_data(tree_id)
+    # Deleting from `node` table doesn't alter other stores.
+    keys_values_after = await raw_data_store.get_keys_values(tree_id_2)
+    assert keys_values_before == keys_values_after
+    async with raw_data_store.db_wrapper.reader() as reader:
+        result = await reader.execute("SELECT * FROM node")
+        nodes = await result.fetchall()
+    kv_nodes_after = {}
+    for node in nodes:
+        if node["key"] is not None:
+            kv_nodes_after[node["key"]] = node["value"]
+    for i in range(total_keys):
+        if i != 2:
+            assert kv_nodes_after[keys[i]] == keys[i]
+        else:
+            # `keys[2]` was only present in the first store.
+            assert keys[i] not in kv_nodes_after
+    assert not await raw_data_store.tree_id_exists(tree_id)
+    await raw_data_store.delete_store_data(tree_id_2)
+    async with raw_data_store.db_wrapper.reader() as reader:
+        async with reader.execute("SELECT COUNT(*) FROM node") as cursor:
+            row_count = await cursor.fetchone()
+            assert row_count is not None
+            assert row_count[0] == 0
+    assert not await raw_data_store.tree_id_exists(tree_id_2)
+
+
+@pytest.mark.anyio
+async def test_delete_store_data_multiple_stores(raw_data_store: DataStore) -> None:
+    # Make sure inserting and deleting the same data works
+    for repetition in range(2):
+        num_stores = 50
+        total_keys = 150
+        keys_deleted_per_store = 3
+        tree_ids = [bytes32(i.to_bytes(32, byteorder="big")) for i in range(num_stores)]
+        for tree_id in tree_ids:
+            await raw_data_store.create_tree(tree_id=tree_id, status=Status.COMMITTED)
+        original_keys = [key.to_bytes(4, byteorder="big") for key in range(total_keys)]
+        batches = []
+        for i in range(num_stores):
+            batch = [
+                {"action": "insert", "key": key, "value": key} for key in original_keys[i * keys_deleted_per_store :]
+            ]
+            batches.append(batch)
+
+        for tree_id, batch in zip(tree_ids, batches):
+            await raw_data_store.insert_batch(tree_id, batch, status=Status.COMMITTED)
+
+        for tree_index in range(num_stores):
+            async with raw_data_store.db_wrapper.reader() as reader:
+                result = await reader.execute("SELECT * FROM node")
+                nodes = await result.fetchall()
+
+            keys = {node["key"] for node in nodes if node["key"] is not None}
+            assert len(keys) == total_keys - tree_index * keys_deleted_per_store
+            keys_after_index = set(original_keys[tree_index * keys_deleted_per_store :])
+            keys_before_index = set(original_keys[: tree_index * keys_deleted_per_store])
+            assert keys_after_index.issubset(keys)
+            assert keys.isdisjoint(keys_before_index)
+            await raw_data_store.delete_store_data(tree_ids[tree_index])
+
+        async with raw_data_store.db_wrapper.reader() as reader:
+            async with reader.execute("SELECT COUNT(*) FROM node") as cursor:
+                row_count = await cursor.fetchone()
+                assert row_count is not None
+                assert row_count[0] == 0
+
+
+@pytest.mark.parametrize("common_keys_count", [1, 250, 499])
+@pytest.mark.anyio
+async def test_delete_store_data_with_common_values(raw_data_store: DataStore, common_keys_count: int) -> None:
+    tree_id_1 = bytes32(b"\x00" * 31 + b"\x01")
+    tree_id_2 = bytes32(b"\x00" * 31 + b"\x02")
+
+    await raw_data_store.create_tree(tree_id=tree_id_1, status=Status.COMMITTED)
+    await raw_data_store.create_tree(tree_id=tree_id_2, status=Status.COMMITTED)
+
+    key_offset = 1000
+    total_keys_per_store = 500
+    assert common_keys_count < key_offset
+    common_keys = {key.to_bytes(4, byteorder="big") for key in range(common_keys_count)}
+    unique_keys_1 = {
+        (key + key_offset).to_bytes(4, byteorder="big") for key in range(total_keys_per_store - common_keys_count)
+    }
+    unique_keys_2 = {
+        (key + (2 * key_offset)).to_bytes(4, byteorder="big") for key in range(total_keys_per_store - common_keys_count)
+    }
+
+    batch1 = [{"action": "insert", "key": key, "value": key} for key in common_keys.union(unique_keys_1)]
+    batch2 = [{"action": "insert", "key": key, "value": key} for key in common_keys.union(unique_keys_2)]
+
+    await raw_data_store.insert_batch(tree_id_1, batch1, status=Status.COMMITTED)
+    await raw_data_store.insert_batch(tree_id_2, batch2, status=Status.COMMITTED)
+
+    await raw_data_store.delete_store_data(tree_id_1)
+    async with raw_data_store.db_wrapper.reader() as reader:
+        result = await reader.execute("SELECT * FROM node")
+        nodes = await result.fetchall()
+
+    keys = {node["key"] for node in nodes if node["key"] is not None}
+    # Since one store got all its keys deleted, we're left only with the keys of the other store.
+    assert len(keys) == total_keys_per_store
+    assert keys.intersection(unique_keys_1) == set()
+    assert keys.symmetric_difference(common_keys.union(unique_keys_2)) == set()
+
+
+@pytest.mark.anyio
+async def test_delete_store_data_protects_pending_roots(raw_data_store: DataStore) -> None:
+    num_stores = 5
+    total_keys = 15
+    tree_ids = [bytes32(i.to_bytes(32, byteorder="big")) for i in range(num_stores)]
+    for tree_id in tree_ids:
+        await raw_data_store.create_tree(tree_id=tree_id, status=Status.COMMITTED)
+    original_keys = [key.to_bytes(4, byteorder="big") for key in range(total_keys)]
+    batches = []
+    keys_per_pending_root = 2
+
+    for i in range(num_stores - 1):
+        start_index = i * keys_per_pending_root
+        end_index = (i + 1) * keys_per_pending_root
+        batch = [{"action": "insert", "key": key, "value": key} for key in original_keys[start_index:end_index]]
+        batches.append(batch)
+    for tree_id, batch in zip(tree_ids, batches):
+        await raw_data_store.insert_batch(tree_id, batch, status=Status.PENDING)
+
+    tree_id = tree_ids[-1]
+    batch = [{"action": "insert", "key": key, "value": key} for key in original_keys]
+    await raw_data_store.insert_batch(tree_id, batch, status=Status.COMMITTED)
+
+    async with raw_data_store.db_wrapper.reader() as reader:
+        result = await reader.execute("SELECT * FROM node")
+        nodes = await result.fetchall()
+
+    keys = {node["key"] for node in nodes if node["key"] is not None}
+    assert keys == set(original_keys)
+
+    await raw_data_store.delete_store_data(tree_id)
+    async with raw_data_store.db_wrapper.reader() as reader:
+        result = await reader.execute("SELECT * FROM node")
+        nodes = await result.fetchall()
+
+    keys = {node["key"] for node in nodes if node["key"] is not None}
+    assert keys == set(original_keys[: (num_stores - 1) * keys_per_pending_root])
+
+    for index in range(num_stores - 1):
+        tree_id = tree_ids[index]
+        root = await raw_data_store.get_pending_root(tree_id)
+        assert root is not None
+        await raw_data_store.change_root_status(root, Status.COMMITTED)
+        kv = await raw_data_store.get_keys_values(tree_id=tree_id)
+        start_index = index * keys_per_pending_root
+        end_index = (index + 1) * keys_per_pending_root
+        assert {pair.key for pair in kv} == set(original_keys[start_index:end_index])