Skip to content

Commit 99ee3e4

Browse files
authored
add DataStore.managed() (#16890)
1 parent 4f55ffb commit 99ee3e4

File tree

5 files changed

+115
-131
lines changed

5 files changed

+115
-131
lines changed

chia/data_layer/data_layer.py

Lines changed: 19 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -188,27 +188,25 @@ async def manage(self) -> AsyncIterator[None]:
188188
sql_log_path = path_from_root(self.root_path, "log/data_sql.log")
189189
self.log.info(f"logging SQL commands to {sql_log_path}")
190190

191-
self._data_store = await DataStore.create(database=self.db_path, sql_log_path=sql_log_path)
192-
self._wallet_rpc = await self.wallet_rpc_init
193-
194-
self.periodically_manage_data_task = asyncio.create_task(self.periodically_manage_data())
195-
try:
196-
yield
197-
finally:
198-
# TODO: review for anything else we need to do here
199-
self._shut_down = True
200-
if self._wallet_rpc is not None:
201-
self.wallet_rpc.close()
202-
203-
if self.periodically_manage_data_task is not None:
204-
try:
205-
self.periodically_manage_data_task.cancel()
206-
except asyncio.CancelledError:
207-
pass
208-
if self._data_store is not None:
209-
await self.data_store.close()
210-
if self._wallet_rpc is not None:
211-
await self.wallet_rpc.await_closed()
191+
async with DataStore.managed(database=self.db_path, sql_log_path=sql_log_path) as self._data_store:
192+
self._wallet_rpc = await self.wallet_rpc_init
193+
194+
self.periodically_manage_data_task = asyncio.create_task(self.periodically_manage_data())
195+
try:
196+
yield
197+
finally:
198+
# TODO: review for anything else we need to do here
199+
self._shut_down = True
200+
if self._wallet_rpc is not None:
201+
self.wallet_rpc.close()
202+
203+
if self.periodically_manage_data_task is not None:
204+
try:
205+
self.periodically_manage_data_task.cancel()
206+
except asyncio.CancelledError:
207+
pass
208+
if self._wallet_rpc is not None:
209+
await self.wallet_rpc.await_closed()
212210

213211
def _set_state_changed_callback(self, callback: StateChangedProtocol) -> None:
214212
self.state_changed_callback = callback

chia/data_layer/data_store.py

Lines changed: 90 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import contextlib
34
import logging
45
from collections import defaultdict
56
from contextlib import asynccontextmanager
@@ -48,10 +49,11 @@ class DataStore:
4849
db_wrapper: DBWrapper2
4950

5051
@classmethod
51-
async def create(
52+
@contextlib.asynccontextmanager
53+
async def managed(
5254
cls, database: Union[str, Path], uri: bool = False, sql_log_path: Optional[Path] = None
53-
) -> DataStore:
54-
db_wrapper = await DBWrapper2.create(
55+
) -> AsyncIterator[DataStore]:
56+
async with DBWrapper2.managed(
5557
database=database,
5658
uri=uri,
5759
journal_mode="WAL",
@@ -63,100 +65,97 @@ async def create(
6365
foreign_keys=True,
6466
row_factory=aiosqlite.Row,
6567
log_path=sql_log_path,
66-
)
67-
self = cls(db_wrapper=db_wrapper)
68+
) as db_wrapper:
69+
self = cls(db_wrapper=db_wrapper)
6870

69-
async with db_wrapper.writer() as writer:
70-
await writer.execute(
71-
f"""
72-
CREATE TABLE IF NOT EXISTS node(
73-
hash BLOB PRIMARY KEY NOT NULL CHECK(length(hash) == 32),
74-
node_type INTEGER NOT NULL CHECK(
75-
(
76-
node_type == {int(NodeType.INTERNAL)}
77-
AND left IS NOT NULL
78-
AND right IS NOT NULL
79-
AND key IS NULL
80-
AND value IS NULL
81-
)
82-
OR
83-
(
84-
node_type == {int(NodeType.TERMINAL)}
85-
AND left IS NULL
86-
AND right IS NULL
87-
AND key IS NOT NULL
88-
AND value IS NOT NULL
89-
)
90-
),
91-
left BLOB REFERENCES node,
92-
right BLOB REFERENCES node,
93-
key BLOB,
94-
value BLOB
71+
async with db_wrapper.writer() as writer:
72+
await writer.execute(
73+
f"""
74+
CREATE TABLE IF NOT EXISTS node(
75+
hash BLOB PRIMARY KEY NOT NULL CHECK(length(hash) == 32),
76+
node_type INTEGER NOT NULL CHECK(
77+
(
78+
node_type == {int(NodeType.INTERNAL)}
79+
AND left IS NOT NULL
80+
AND right IS NOT NULL
81+
AND key IS NULL
82+
AND value IS NULL
83+
)
84+
OR
85+
(
86+
node_type == {int(NodeType.TERMINAL)}
87+
AND left IS NULL
88+
AND right IS NULL
89+
AND key IS NOT NULL
90+
AND value IS NOT NULL
91+
)
92+
),
93+
left BLOB REFERENCES node,
94+
right BLOB REFERENCES node,
95+
key BLOB,
96+
value BLOB
97+
)
98+
"""
9599
)
96-
"""
97-
)
98-
await writer.execute(
99-
"""
100-
CREATE TRIGGER IF NOT EXISTS no_node_updates
101-
BEFORE UPDATE ON node
102-
BEGIN
103-
SELECT RAISE(FAIL, 'updates not allowed to the node table');
104-
END
105-
"""
106-
)
107-
await writer.execute(
108-
f"""
109-
CREATE TABLE IF NOT EXISTS root(
110-
tree_id BLOB NOT NULL CHECK(length(tree_id) == 32),
111-
generation INTEGER NOT NULL CHECK(generation >= 0),
112-
node_hash BLOB,
113-
status INTEGER NOT NULL CHECK(
114-
{" OR ".join(f"status == {status}" for status in Status)}
115-
),
116-
PRIMARY KEY(tree_id, generation),
117-
FOREIGN KEY(node_hash) REFERENCES node(hash)
100+
await writer.execute(
101+
"""
102+
CREATE TRIGGER IF NOT EXISTS no_node_updates
103+
BEFORE UPDATE ON node
104+
BEGIN
105+
SELECT RAISE(FAIL, 'updates not allowed to the node table');
106+
END
107+
"""
118108
)
119-
"""
120-
)
121-
# TODO: Add ancestor -> hash relationship, this might involve temporarily
122-
# deferring the foreign key enforcement due to the insertion order
123-
# and the node table also enforcing a similar relationship in the
124-
# other direction.
125-
# FOREIGN KEY(ancestor) REFERENCES ancestors(ancestor)
126-
await writer.execute(
127-
"""
128-
CREATE TABLE IF NOT EXISTS ancestors(
129-
hash BLOB NOT NULL REFERENCES node,
130-
ancestor BLOB CHECK(length(ancestor) == 32),
131-
tree_id BLOB NOT NULL CHECK(length(tree_id) == 32),
132-
generation INTEGER NOT NULL,
133-
PRIMARY KEY(hash, tree_id, generation),
134-
FOREIGN KEY(ancestor) REFERENCES node(hash)
109+
await writer.execute(
110+
f"""
111+
CREATE TABLE IF NOT EXISTS root(
112+
tree_id BLOB NOT NULL CHECK(length(tree_id) == 32),
113+
generation INTEGER NOT NULL CHECK(generation >= 0),
114+
node_hash BLOB,
115+
status INTEGER NOT NULL CHECK(
116+
{" OR ".join(f"status == {status}" for status in Status)}
117+
),
118+
PRIMARY KEY(tree_id, generation),
119+
FOREIGN KEY(node_hash) REFERENCES node(hash)
120+
)
121+
"""
135122
)
136-
"""
137-
)
138-
await writer.execute(
139-
"""
140-
CREATE TABLE IF NOT EXISTS subscriptions(
141-
tree_id BLOB NOT NULL CHECK(length(tree_id) == 32),
142-
url TEXT,
143-
ignore_till INTEGER,
144-
num_consecutive_failures INTEGER,
145-
from_wallet tinyint CHECK(from_wallet == 0 OR from_wallet == 1),
146-
PRIMARY KEY(tree_id, url)
123+
# TODO: Add ancestor -> hash relationship, this might involve temporarily
124+
# deferring the foreign key enforcement due to the insertion order
125+
# and the node table also enforcing a similar relationship in the
126+
# other direction.
127+
# FOREIGN KEY(ancestor) REFERENCES ancestors(ancestor)
128+
await writer.execute(
129+
"""
130+
CREATE TABLE IF NOT EXISTS ancestors(
131+
hash BLOB NOT NULL REFERENCES node,
132+
ancestor BLOB CHECK(length(ancestor) == 32),
133+
tree_id BLOB NOT NULL CHECK(length(tree_id) == 32),
134+
generation INTEGER NOT NULL,
135+
PRIMARY KEY(hash, tree_id, generation),
136+
FOREIGN KEY(ancestor) REFERENCES node(hash)
137+
)
138+
"""
139+
)
140+
await writer.execute(
141+
"""
142+
CREATE TABLE IF NOT EXISTS subscriptions(
143+
tree_id BLOB NOT NULL CHECK(length(tree_id) == 32),
144+
url TEXT,
145+
ignore_till INTEGER,
146+
num_consecutive_failures INTEGER,
147+
from_wallet tinyint CHECK(from_wallet == 0 OR from_wallet == 1),
148+
PRIMARY KEY(tree_id, url)
149+
)
150+
"""
151+
)
152+
await writer.execute(
153+
"""
154+
CREATE INDEX IF NOT EXISTS node_hash ON root(node_hash)
155+
"""
147156
)
148-
"""
149-
)
150-
await writer.execute(
151-
"""
152-
CREATE INDEX IF NOT EXISTS node_hash ON root(node_hash)
153-
"""
154-
)
155-
156-
return self
157157

158-
async def close(self) -> None:
159-
await self.db_wrapper.close()
158+
yield self
160159

161160
@asynccontextmanager
162161
async def transaction(self) -> AsyncIterator[None]:

chia/data_layer/util/benchmark.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,7 @@ async def generate_datastore(num_nodes: int, slow_mode: bool) -> None:
2222
if os.path.exists(db_path):
2323
os.remove(db_path)
2424

25-
data_store = await DataStore.create(database=db_path)
26-
try:
25+
async with DataStore.managed(database=db_path) as data_store:
2726
hint_keys_values: Dict[bytes, bytes] = {}
2827

2928
tree_id = bytes32(b"0" * 32)
@@ -107,8 +106,6 @@ async def generate_datastore(num_nodes: int, slow_mode: bool) -> None:
107106
print(f"Total time for {num_nodes} operations: {insert_time + autoinsert_time + delete_time}")
108107
root = await data_store.get_tree_root(tree_id=tree_id)
109108
print(f"Root hash: {root.node_hash}")
110-
finally:
111-
await data_store.close()
112109

113110

114111
if __name__ == "__main__":

tests/core/data_layer/conftest.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,8 @@ def tree_id_fixture() -> bytes32:
6161

6262
@pytest.fixture(name="raw_data_store", scope="function")
6363
async def raw_data_store_fixture(database_uri: str) -> AsyncIterable[DataStore]:
64-
store = await DataStore.create(database=database_uri, uri=True)
65-
yield store
66-
await store.close()
64+
async with DataStore.managed(database=database_uri, uri=True) as store:
65+
yield store
6766

6867

6968
@pytest.fixture(name="data_store", scope="function")

tests/core/data_layer/test_data_store.py

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -88,14 +88,11 @@ async def test_create_creates_tables_and_columns(
8888
columns = await cursor.fetchall()
8989
assert columns == []
9090

91-
store = await DataStore.create(database=database_uri, uri=True)
92-
try:
91+
async with DataStore.managed(database=database_uri, uri=True):
9392
async with db_wrapper.reader() as reader:
9493
cursor = await reader.execute(query)
9594
columns = await cursor.fetchall()
9695
assert [column[1] for column in columns] == expected_columns
97-
finally:
98-
await store.close()
9996

10097

10198
@pytest.mark.anyio
@@ -379,8 +376,7 @@ async def test_batch_update(data_store: DataStore, tree_id: bytes32, use_optimiz
379376
saved_batches: List[List[Dict[str, Any]]] = []
380377

381378
db_uri = generate_in_memory_db_uri()
382-
single_op_data_store = await DataStore.create(database=db_uri, uri=True)
383-
try:
379+
async with DataStore.managed(database=db_uri, uri=True) as single_op_data_store:
384380
await single_op_data_store.create_tree(tree_id, status=Status.COMMITTED)
385381
random = Random()
386382
random.seed(100, version=2)
@@ -423,8 +419,6 @@ async def test_batch_update(data_store: DataStore, tree_id: bytes32, use_optimiz
423419
batch = []
424420
root = await single_op_data_store.get_tree_root(tree_id=tree_id)
425421
saved_roots.append(root)
426-
finally:
427-
await single_op_data_store.close()
428422

429423
for batch_number, batch in enumerate(saved_batches):
430424
assert len(batch) == num_ops_per_batch
@@ -1265,8 +1259,7 @@ async def test_data_server_files(data_store: DataStore, tree_id: bytes32, test_d
12651259
num_ops_per_batch = 100
12661260

12671261
db_uri = generate_in_memory_db_uri()
1268-
data_store_server = await DataStore.create(database=db_uri, uri=True)
1269-
try:
1262+
async with DataStore.managed(database=db_uri, uri=True) as data_store_server:
12701263
await data_store_server.create_tree(tree_id, status=Status.COMMITTED)
12711264
random = Random()
12721265
random.seed(100, version=2)
@@ -1291,8 +1284,6 @@ async def test_data_server_files(data_store: DataStore, tree_id: bytes32, test_d
12911284
root = await data_store_server.get_tree_root(tree_id)
12921285
await write_files_for_root(data_store_server, tree_id, root, tmp_path, 0)
12931286
roots.append(root)
1294-
finally:
1295-
await data_store_server.close()
12961287

12971288
generation = 1
12981289
assert len(roots) == num_batches

0 commit comments

Comments (0)