Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions benchmarks/block_ref.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from chia.full_node.coin_store import CoinStore
from chia.types.blockchain_format.serialized_program import SerializedProgram
from chia.util.db_version import lookup_db_version
from chia.util.db_wrapper import DBWrapper2
from chia.util.db_wrapper import DBWrapper2, Writer

# the first transaction block. Each byte in transaction_height_delta is the
# number of blocks to skip forward to get to the next transaction block
Expand Down Expand Up @@ -59,7 +59,7 @@ async def main(db_path: Path) -> None:
await connection.execute("pragma query_only=ON")
db_version: int = await lookup_db_version(connection)

db_wrapper = DBWrapper2(connection, db_version=db_version)
db_wrapper = DBWrapper2(Writer(_connection=connection), db_version=db_version)
await db_wrapper.add_connection(await aiosqlite.connect(db_path))

block_store = await BlockStore.create(db_wrapper)
Expand Down
23 changes: 15 additions & 8 deletions chia/_tests/db/test_db_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import asyncio
import contextlib
from dataclasses import dataclass
from typing import TYPE_CHECKING, Callable, Optional
from typing import TYPE_CHECKING, Callable, Optional, Union

import aiosqlite
import pytest
Expand All @@ -13,11 +13,18 @@

from chia._tests.util.db_connection import DBConnection, PathDBConnection
from chia._tests.util.misc import Marks, boolean_datacases, datacases
from chia.util.db_wrapper import DBWrapper2, ForeignKeyError, InternalError, NestedForeignKeyDelayedRequestError
from chia.util.db_wrapper import (
DBWrapper2,
ForeignKeyError,
InternalError,
NestedForeignKeyDelayedRequestError,
Reader,
Writer,
)
from chia.util.task_referencer import create_referenced_task

if TYPE_CHECKING:
ConnectionContextManager = contextlib.AbstractAsyncContextManager[aiosqlite.core.Connection]
ConnectionContextManager = contextlib.AbstractAsyncContextManager[Reader]
GetReaderMethod = Callable[[DBWrapper2], Callable[[], ConnectionContextManager]]


Expand Down Expand Up @@ -78,7 +85,7 @@ async def get_value(cursor: aiosqlite.Cursor) -> int:
return int(row[0])


async def query_value(connection: aiosqlite.Connection) -> int:
async def query_value(connection: Union[Reader, Writer]) -> int:
async with connection.execute("SELECT value FROM counter") as cursor:
return await get_value(cursor=cursor)

Expand Down Expand Up @@ -222,7 +229,7 @@ async def test_readers_nests_writer(get_reader_method: GetReaderMethod) -> None:

async with db_wrapper.writer_maybe_transaction() as conn1:
async with get_reader_method(db_wrapper)() as conn2:
assert conn1 == conn2
assert conn1._connection == conn2._connection
async with db_wrapper.writer_maybe_transaction() as conn3:
assert conn1 == conn3
async with conn3.execute("SELECT value FROM counter") as cursor:
Expand All @@ -246,7 +253,7 @@ async def test_only_transactioned_reader_ignores_writer(transactioned: bool) ->
async def write() -> None:
try:
async with db_wrapper.writer() as writer:
assert reader is not writer
assert reader._connection is not writer._connection

await writer.execute("UPDATE counter SET value = 1")
finally:
Expand Down Expand Up @@ -298,7 +305,7 @@ async def test_writer_in_reader_works() -> None:

async with db_wrapper.reader() as reader:
async with db_wrapper.writer() as writer:
assert writer is not reader
assert writer._connection is not reader._connection
await writer.execute("UPDATE counter SET value = 1")
assert await query_value(connection=writer) == 1
assert await query_value(connection=reader) == 0
Expand All @@ -313,7 +320,7 @@ async def test_reader_transaction_is_deferred() -> None:

async with db_wrapper.reader() as reader:
async with db_wrapper.writer() as writer:
assert writer is not reader
assert writer._connection is not reader._connection
await writer.execute("UPDATE counter SET value = 1")
assert await query_value(connection=writer) == 1

Expand Down
2 changes: 0 additions & 2 deletions chia/_tests/wallet/rpc/test_wallet_rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from typing import Any, Optional, cast
from unittest.mock import patch

import aiosqlite
import pytest
from chia_rs import G1Element, G2Element
from chia_rs.sized_bytes import bytes32
Expand Down Expand Up @@ -2483,7 +2482,6 @@ async def test_set_wallet_resync_schema(wallet_rpc_environment: WalletRpcTestEnv
"Schema has been changed, reset sync db won't work, please update WalletNode.reset_sync_db function"
)
dbw: DBWrapper2 = wallet_node.wallet_state_manager.db_wrapper
conn: aiosqlite.Connection
async with dbw.writer() as conn:
await conn.execute("CREATE TABLE blah(temp int)")
await wallet_node.reset_sync_db(db_path, fingerprint)
Expand Down
4 changes: 2 additions & 2 deletions chia/_tests/wallet/test_sign_coin_spends.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from chia.types.blockchain_format.serialized_program import SerializedProgram
from chia.types.coin_spend import CoinSpend, make_spend
from chia.types.condition_opcodes import ConditionOpcode
from chia.util.db_wrapper import DBWrapper2, manage_connection
from chia.util.db_wrapper import DBWrapper2, Writer, manage_connection
from chia.wallet.derivation_record import DerivationRecord
from chia.wallet.derive_keys import master_sk_to_wallet_sk, master_sk_to_wallet_sk_unhardened
from chia.wallet.puzzles.p2_delegated_puzzle_or_hidden_puzzle import (
Expand Down Expand Up @@ -70,7 +70,7 @@ async def test_wsm_sign_transaction() -> None:
async with manage_connection("file:temp.db?mode=memory&cache=shared", uri=True, name="writer") as writer_conn:
async with manage_connection("file:temp.db?mode=memory&cache=shared", uri=True, name="reader") as reader_conn:
wsm = WalletStateManager()
db = DBWrapper2(writer_conn)
db = DBWrapper2(Writer(_connection=writer_conn))
await db.add_connection(reader_conn)
wsm.puzzle_store = await WalletPuzzleStore.create(db)
wsm.constants = DEFAULT_CONSTANTS
Expand Down
6 changes: 3 additions & 3 deletions chia/data_layer/data_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
unspecified,
)
from chia.types.blockchain_format.program import Program
from chia.util.db_wrapper import SQLITE_MAX_VARIABLE_NUMBER, DBWrapper2
from chia.util.db_wrapper import SQLITE_MAX_VARIABLE_NUMBER, DBWrapper2, Reader, Writer

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -756,7 +756,7 @@ async def get_internal_nodes(self, store_id: bytes32, root_hash: Optional[bytes3

async def get_keys_values_cursor(
self,
reader: aiosqlite.Connection,
reader: Reader,
root_hash: Optional[bytes32],
only_keys: bool = False,
) -> aiosqlite.Cursor:
Expand Down Expand Up @@ -1404,7 +1404,7 @@ async def upsert(

return InsertResult(node_hash=new_terminal_node_hash, root=new_root)

async def clean_node_table(self, writer: Optional[aiosqlite.Connection] = None) -> None:
async def clean_node_table(self, writer: Optional[Writer] = None) -> None:
query = """
WITH RECURSIVE pending_nodes AS (
SELECT node_hash AS hash FROM root
Expand Down
8 changes: 3 additions & 5 deletions chia/util/action_scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@
from dataclasses import dataclass, field
from typing import Callable, Generic, Optional, Protocol, TypeVar, final

import aiosqlite

from chia.util.db_wrapper import DBWrapper2, execute_fetchone
from chia.util.db_wrapper import DBWrapper2, Writer, execute_fetchone


class ResourceManager(Protocol):
Expand All @@ -30,9 +28,9 @@ async def save_resource(self, resource: SideEffects) -> None: ...
@dataclass
class SQLiteResourceManager:
_db: DBWrapper2
_active_writer: Optional[aiosqlite.Connection] = field(init=False, default=None)
_active_writer: Optional[Writer] = field(init=False, default=None)

def get_active_writer(self) -> aiosqlite.Connection:
def get_active_writer(self) -> Writer:
if self._active_writer is None:
raise RuntimeError("Can only access resources while under `use()` context manager")

Expand Down
6 changes: 4 additions & 2 deletions chia/util/db_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

import aiosqlite

from chia.util.db_wrapper import Writer


async def lookup_db_version(db: aiosqlite.Connection) -> int:
try:
Expand All @@ -20,10 +22,10 @@ async def lookup_db_version(db: aiosqlite.Connection) -> int:
return 1


async def set_db_version_async(db: aiosqlite.Connection, version: int) -> None:
async def set_db_version_async(db: Writer, version: int) -> None:
await db.execute("CREATE TABLE database_version(version int)")
await db.execute("INSERT INTO database_version VALUES (?)", (version,))
await db.commit()
await db._connection.commit()


def set_db_version(db: sqlite3.Connection, version: int) -> None:
Expand Down
Loading
Loading