Skip to content

Commit 3c9e897

Browse files
committed
first crack at async context manager for ConsensusStore
1 parent c286d0d commit 3c9e897

File tree

3 files changed

+168
-132
lines changed

3 files changed

+168
-132
lines changed

chia/consensus/blockchain.py

Lines changed: 48 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -538,56 +538,57 @@ async def _reconsider_peak(
538538
else:
539539
records_to_add = await self.consensus_store.get_block_records_by_hash(fork_info.block_hashes)
540540

541-
for fetched_block_record in records_to_add:
542-
if not fetched_block_record.is_transaction_block:
543-
# Coins are only created in TX blocks so there are no state updates for this block
544-
continue
545-
546-
height = fetched_block_record.height
547-
# We need to recompute the additions and removals, since they are
548-
# not stored on DB. We have all the additions and removals in the
549-
# fork_info object, we just need to pick the ones belonging to each
550-
# individual block height
551-
552-
# Apply the coin store changes for each block that is now in the blockchain
553-
included_reward_coins = [
554-
fork_add.coin
555-
for fork_add in fork_info.additions_since_fork.values()
556-
if fork_add.confirmed_height == height and fork_add.is_coinbase
557-
]
558-
tx_additions = [
559-
(coin_id, fork_add.coin, fork_add.same_as_parent)
560-
for coin_id, fork_add in fork_info.additions_since_fork.items()
561-
if fork_add.confirmed_height == height and not fork_add.is_coinbase
562-
]
563-
tx_removals = [
564-
coin_id for coin_id, fork_rem in fork_info.removals_since_fork.items() if fork_rem.height == height
565-
]
566-
assert fetched_block_record.timestamp is not None
567-
await self.consensus_store.new_block(
568-
height,
569-
fetched_block_record.timestamp,
570-
included_reward_coins,
571-
tx_additions,
572-
tx_removals,
573-
)
574-
if self._log_coins and (len(tx_removals) > 0 or len(tx_additions) > 0):
575-
log.info(
576-
f"adding new block to coin_store "
577-
f"(hh: {fetched_block_record.header_hash} "
578-
f"height: {fetched_block_record.height}), {len(tx_removals)} spends"
541+
async with self.consensus_store as writer:
542+
for fetched_block_record in records_to_add:
543+
if not fetched_block_record.is_transaction_block:
544+
# Coins are only created in TX blocks so there are no state updates for this block
545+
continue
546+
547+
height = fetched_block_record.height
548+
# We need to recompute the additions and removals, since they are
549+
# not stored on DB. We have all the additions and removals in the
550+
# fork_info object, we just need to pick the ones belonging to each
551+
# individual block height
552+
553+
# Apply the coin store changes for each block that is now in the blockchain
554+
included_reward_coins = [
555+
fork_add.coin
556+
for fork_add in fork_info.additions_since_fork.values()
557+
if fork_add.confirmed_height == height and fork_add.is_coinbase
558+
]
559+
tx_additions = [
560+
(coin_id, fork_add.coin, fork_add.same_as_parent)
561+
for coin_id, fork_add in fork_info.additions_since_fork.items()
562+
if fork_add.confirmed_height == height and not fork_add.is_coinbase
563+
]
564+
tx_removals = [
565+
coin_id for coin_id, fork_rem in fork_info.removals_since_fork.items() if fork_rem.height == height
566+
]
567+
assert fetched_block_record.timestamp is not None
568+
await writer.new_block(
569+
height,
570+
fetched_block_record.timestamp,
571+
included_reward_coins,
572+
tx_additions,
573+
tx_removals,
579574
)
580-
log.info("rewards: %s", ",".join([add.name().hex()[0:6] for add in included_reward_coins]))
581-
log.info("additions: %s", ",".join([add[0].hex()[0:6] for add in tx_additions]))
582-
log.info("removals: %s", ",".join([f"{rem}"[0:6] for rem in tx_removals]))
575+
if self._log_coins and (len(tx_removals) > 0 or len(tx_additions) > 0):
576+
log.info(
577+
f"adding new block to coin_store "
578+
f"(hh: {fetched_block_record.header_hash} "
579+
f"height: {fetched_block_record.height}), {len(tx_removals)} spends"
580+
)
581+
log.info("rewards: %s", ",".join([add.name().hex()[0:6] for add in included_reward_coins]))
582+
log.info("additions: %s", ",".join([add[0].hex()[0:6] for add in tx_additions]))
583+
log.info("removals: %s", ",".join([f"{rem}"[0:6] for rem in tx_removals]))
583584

584-
# we made it to the end successfully
585-
# Rollback sub_epoch_summaries
586-
await self.consensus_store.rollback(fork_info.fork_height)
587-
await self.consensus_store.set_in_chain([(br.header_hash,) for br in records_to_add])
585+
# we made it to the end successfully
586+
# Rollback sub_epoch_summaries
587+
await writer.rollback(fork_info.fork_height)
588+
await writer.set_in_chain([(br.header_hash,) for br in records_to_add])
588589

589-
# Changes the peak to be the new peak
590-
await self.consensus_store.set_peak(block_record.header_hash)
590+
# Changes the peak to be the new peak
591+
await writer.set_peak(block_record.header_hash)
591592

592593
return records_to_add, StateChangeSummary(
593594
block_record,

chia/consensus/consensus_store_protocol.py

Lines changed: 51 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
from __future__ import annotations
22

33
from collections.abc import Collection
4-
from contextlib import AbstractAsyncContextManager
54
from typing import Optional, Protocol
65

7-
import aiosqlite
86
from chia_rs import BlockRecord, FullBlock, SubEpochChallengeSegment, SubEpochSummary
97
from chia_rs.sized_bytes import bytes32
108
from chia_rs.sized_ints import uint32, uint64
@@ -13,54 +11,82 @@
1311
from chia.types.coin_record import CoinRecord
1412

1513

14+
class ConsensusStoreWriteProtocol(Protocol):
15+
"""
16+
Protocol for performing mutating operations on the consensus store.
17+
18+
Instances implementing this protocol should be acquired via the async
19+
context manager on ConsensusStoreProtocol to ensure atomic write
20+
operations (e.g., wrapping all writes in a single DB transaction).
21+
"""
22+
23+
# Block store writes
24+
async def add_full_block(self, header_hash: bytes32, block: FullBlock, block_record: BlockRecord) -> None: ...
25+
def rollback_cache_block(self, header_hash: bytes32) -> None: ...
26+
async def rollback(self, height: int) -> None: ...
27+
async def set_in_chain(self, header_hashes: list[tuple[bytes32]]) -> None: ...
28+
async def set_peak(self, header_hash: bytes32) -> None: ...
29+
async def persist_sub_epoch_challenge_segments(
30+
self, ses_block_hash: bytes32, segments: list[SubEpochChallengeSegment]
31+
) -> None: ...
32+
33+
# Coin store writes
34+
async def rollback_to_block(self, block_index: int) -> dict[bytes32, CoinRecord]: ...
35+
async def new_block(
36+
self,
37+
height: uint32,
38+
timestamp: uint64,
39+
included_reward_coins: Collection[Coin],
40+
tx_additions: Collection[tuple[bytes32, Coin, bool]],
41+
tx_removals: list[bytes32],
42+
) -> None: ...
43+
44+
1645
class ConsensusStoreProtocol(Protocol):
1746
"""
18-
Protocol for the consensus store, which provides methods to interact with
19-
the consensus-related data in the blockchain.
47+
Read-only protocol for the consensus store.
48+
49+
This protocol also acts as an async context manager. Entering the context
50+
yields a ConsensusStoreWriteProtocol instance, which must be used for
51+
performing write (mutating) operations. This ensures atomic writes and
52+
makes it harder to accidentally perform writes outside a transaction.
53+
54+
Example usage:
55+
async with store as writer:
56+
await writer.add_full_block(...)
57+
await writer.set_peak(...)
58+
59+
# Outside the context, only read methods are available
60+
br = await store.get_block_record(header_hash)
2061
"""
2162

22-
# Block store methods
23-
def transaction(self) -> AbstractAsyncContextManager[aiosqlite.Connection]: ...
63+
# Async context manager methods
64+
async def __aenter__(self) -> ConsensusStoreWriteProtocol: ...
65+
async def __aexit__(self, exc_type, exc, tb) -> Optional[bool]: ...
2466

67+
# Block store reads
2568
async def get_block_records_close_to_peak(
2669
self, blocks_n: int
2770
) -> tuple[dict[bytes32, BlockRecord], Optional[bytes32]]: ...
2871
async def get_full_block(self, header_hash: bytes32) -> Optional[FullBlock]: ...
29-
async def add_full_block(self, header_hash: bytes32, block: FullBlock, block_record: BlockRecord) -> None: ...
30-
def rollback_cache_block(self, header_hash: bytes32) -> None: ...
3172
async def get_block_records_by_hash(self, header_hashes: list[bytes32]) -> list[BlockRecord]: ...
32-
async def rollback(self, height: int) -> None: ...
33-
async def set_in_chain(self, header_hashes: list[tuple[bytes32]]) -> None: ...
34-
async def set_peak(self, header_hash: bytes32) -> None: ...
3573
async def get_block_records_in_range(self, start: int, stop: int) -> dict[bytes32, BlockRecord]: ...
3674
def get_block_from_cache(self, header_hash: bytes32) -> Optional[FullBlock]: ...
3775
async def get_blocks_by_hash(self, header_hashes: list[bytes32]) -> list[FullBlock]: ...
3876
async def get_block_record(self, header_hash: bytes32) -> Optional[BlockRecord]: ...
3977
async def get_prev_hash(self, header_hash: bytes32) -> bytes32: ...
40-
async def persist_sub_epoch_challenge_segments(
41-
self, ses_block_hash: bytes32, segments: list[SubEpochChallengeSegment]
42-
) -> None: ...
4378
async def get_sub_epoch_challenge_segments(
4479
self, ses_block_hash: bytes32
4580
) -> Optional[list[SubEpochChallengeSegment]]: ...
4681
async def get_generator(self, header_hash: bytes32) -> Optional[bytes]: ...
4782
async def get_generators_at(self, heights: set[uint32]) -> dict[uint32, bytes]: ...
4883

49-
# Coin store methods
84+
# Coin store reads
5085
async def get_coin_records(self, names: Collection[bytes32]) -> list[CoinRecord]: ...
51-
async def rollback_to_block(self, block_index: int) -> dict[bytes32, CoinRecord]: ...
52-
async def new_block(
53-
self,
54-
height: uint32,
55-
timestamp: uint64,
56-
included_reward_coins: Collection[Coin],
57-
tx_additions: Collection[tuple[bytes32, Coin, bool]],
58-
tx_removals: list[bytes32],
59-
) -> None: ...
6086
async def get_coins_added_at_height(self, height: uint32) -> list[CoinRecord]: ...
6187
async def get_coins_removed_at_height(self, height: uint32) -> list[CoinRecord]: ...
6288

63-
# Height map methods
89+
# Height map methods (kept here for now; non-async and maybe_flush remain on read protocol)
6490
def get_ses_heights(self) -> list[uint32]: ...
6591
def get_ses(self, height: uint32) -> SubEpochSummary: ...
6692
def contains_height(self, height: uint32) -> bool: ...

0 commit comments

Comments
 (0)