Skip to content

Commit 3caa075

Browse files
authored
add DBWrapper2.managed() (#16880)
1 parent 8e3f9db commit 3caa075

19 files changed

+498
-514
lines changed

benchmarks/block_store.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
from chia.types.blockchain_format.sized_bytes import bytes32
3131
from chia.types.blockchain_format.sub_epoch_summary import SubEpochSummary
3232
from chia.types.full_block import FullBlock
33-
from chia.util.db_wrapper import DBWrapper2
3433
from chia.util.ints import uint8, uint32, uint64, uint128
3534

3635
# to run this benchmark:
@@ -44,7 +43,6 @@
4443

4544
async def run_add_block_benchmark(version: int) -> None:
4645
verbose: bool = "--verbose" in sys.argv
47-
db_wrapper: DBWrapper2 = await setup_db("block-store-benchmark.db", version)
4846

4947
# keep track of benchmark total time
5048
all_test_time = 0.0
@@ -54,7 +52,7 @@ async def run_add_block_benchmark(version: int) -> None:
5452

5553
header_hashes = []
5654

57-
try:
55+
async with setup_db("block-store-benchmark.db", version) as db_wrapper:
5856
block_store = await BlockStore.create(db_wrapper)
5957

6058
block_height = 1
@@ -495,9 +493,6 @@ async def run_add_block_benchmark(version: int) -> None:
495493
db_size = os.path.getsize(Path("block-store-benchmark.db"))
496494
print(f"database size: {db_size/1000000:.3f} MB")
497495

498-
finally:
499-
await db_wrapper.close()
500-
501496

502497
if __name__ == "__main__":
503498
print("version 2")

benchmarks/coin_store.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from chia.full_node.coin_store import CoinStore
1313
from chia.types.blockchain_format.coin import Coin
1414
from chia.types.blockchain_format.sized_bytes import bytes32
15-
from chia.util.db_wrapper import DBWrapper2
1615
from chia.util.ints import uint32, uint64
1716

1817
# to run this benchmark:
@@ -41,12 +40,11 @@ def make_coins(num: int) -> Tuple[List[Coin], List[bytes32]]:
4140

4241
async def run_new_block_benchmark(version: int) -> None:
4342
verbose: bool = "--verbose" in sys.argv
44-
db_wrapper: DBWrapper2 = await setup_db("coin-store-benchmark.db", version)
4543

4644
# keep track of benchmark total time
4745
all_test_time = 0.0
4846

49-
try:
47+
async with setup_db("coin-store-benchmark.db", version) as db_wrapper:
5048
coin_store = await CoinStore.create(db_wrapper)
5149

5250
all_unspent: List[bytes32] = []
@@ -301,9 +299,6 @@ async def run_new_block_benchmark(version: int) -> None:
301299
all_test_time += total_time
302300
print(f"all tests completed in {all_test_time:0.4f}s")
303301

304-
finally:
305-
await db_wrapper.close()
306-
307302
db_size = os.path.getsize(Path("coin-store-benchmark.db"))
308303
print(f"database size: {db_size/1000000:.3f} MB")
309304

benchmarks/utils.py

Lines changed: 18 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,14 @@
11
from __future__ import annotations
22

3+
import contextlib
34
import enum
45
import os
56
import random
67
import subprocess
78
import sys
8-
from datetime import datetime
99
from pathlib import Path
10-
from typing import Any, Generic, Optional, Tuple, Type, TypeVar, Union
10+
from typing import Any, AsyncIterator, Generic, Optional, Tuple, Type, TypeVar, Union
1111

12-
import aiosqlite
1312
import click
1413
from chia_rs import AugSchemeMPL, G1Element, G2Element
1514

@@ -179,30 +178,29 @@ def rand_full_block() -> FullBlock:
179178
return full_block
180179

181180

182-
async def setup_db(name: Union[str, os.PathLike[str]], db_version: int) -> DBWrapper2:
181+
@contextlib.asynccontextmanager
182+
async def setup_db(name: Union[str, os.PathLike[str]], db_version: int) -> AsyncIterator[DBWrapper2]:
183183
db_filename = Path(name)
184184
try:
185185
os.unlink(db_filename)
186186
except FileNotFoundError:
187187
pass
188-
connection = await aiosqlite.connect(db_filename)
189-
190-
def sql_trace_callback(req: str) -> None:
191-
sql_log_path = "sql.log"
192-
timestamp = datetime.now().strftime("%H:%M:%S.%f")
193-
log = open(sql_log_path, "a")
194-
log.write(timestamp + " " + req + "\n")
195-
log.close()
196188

189+
log_path: Optional[Path]
197190
if "--sql-logging" in sys.argv:
198-
await connection.set_trace_callback(sql_trace_callback)
199-
200-
await connection.execute("pragma journal_mode=wal")
201-
await connection.execute("pragma synchronous=full")
202-
203-
ret = DBWrapper2(connection, db_version)
204-
await ret.add_connection(await aiosqlite.connect(db_filename))
205-
return ret
191+
log_path = Path("sql.log")
192+
else:
193+
log_path = None
194+
195+
async with DBWrapper2.managed(
196+
database=db_filename,
197+
log_path=log_path,
198+
db_version=db_version,
199+
reader_count=1,
200+
journal_mode="wal",
201+
synchronous="full",
202+
) as db_wrapper:
203+
yield db_wrapper
206204

207205

208206
def get_commit_hash() -> str:

chia/cmds/check_wallet_db.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -358,18 +358,18 @@ def check_unexpected_derivation_entries(
358358

359359
async def scan(self, db_path: Path) -> int:
360360
"""Returns number of lines of error output (not warnings)"""
361-
self.db_wrapper = await DBWrapper2.create(
361+
async with DBWrapper2.managed(
362362
database=db_path,
363363
reader_count=self.config.get("db_readers", 4),
364364
log_path=self.sql_log_path,
365365
synchronous=db_synchronous_on("auto"),
366-
)
367-
# TODO: Pass down db_wrapper
368-
wallets = await self.get_all_wallets()
369-
derivation_paths = await self.get_derivation_paths()
370-
errors = []
371-
warnings = []
372-
try:
366+
) as self.db_wrapper:
367+
# TODO: Pass down db_wrapper
368+
wallets = await self.get_all_wallets()
369+
derivation_paths = await self.get_derivation_paths()
370+
errors = []
371+
warnings = []
372+
373373
if self.verbose:
374374
await self.show_tables()
375375
print_min_max_derivation_for_wallets(derivation_paths)
@@ -387,8 +387,7 @@ async def scan(self, db_path: Path) -> int:
387387
if len(errors) > 0:
388388
print(f" ---- Errors Found for {db_path.name}----")
389389
print("\n".join(errors))
390-
finally:
391-
await self.db_wrapper.close()
390+
392391
return len(errors)
393392

394393

0 commit comments

Comments
 (0)