Skip to content

Commit 1eb36b8

Browse files
committed
Handle nested context manager
1 parent cde38c4 commit 1eb36b8

File tree

5 files changed

+61
-42
lines changed

5 files changed

+61
-42
lines changed

chia/_tests/blockchain/blockchain_test_utils.py

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -17,24 +17,18 @@
1717
async def check_block_store_invariant(bc: Blockchain):
1818
in_chain = set()
1919
max_height = -1
20-
async with bc.consensus_store.transaction() as conn:
21-
async with conn.execute("SELECT height, in_main_chain FROM full_blocks") as cursor:
22-
rows = await cursor.fetchall()
23-
for row in rows:
24-
height = row[0]
25-
26-
# if this block is in-chain, ensure we haven't found another block
27-
# at this height that's also in chain. That would be an invariant
28-
# violation
29-
if row[1]:
30-
# make sure we don't have any duplicate heights. Each block
31-
# height can only have a single block with in_main_chain set
32-
assert height not in in_chain
33-
in_chain.add(height)
34-
max_height = max(max_height, height)
35-
36-
# make sure every height is represented in the set
37-
assert len(in_chain) == max_height + 1
20+
async for height in bc.consensus_store.get_block_heights_in_main_chain():
21+
# if this block is in-chain, ensure we haven't found another block
22+
# at this height that's also in chain. That would be an invariant
23+
# violation
24+
# make sure we don't have any duplicate heights. Each block
25+
# height can only have a single block with in_main_chain set
26+
assert height not in in_chain
27+
in_chain.add(height)
28+
max_height = max(max_height, height)
29+
30+
# make sure every height is represented in the set
31+
assert len(in_chain) == max_height + 1
3832

3933

4034
async def _validate_and_add_block(

chia/_tests/core/full_node/stores/test_block_store.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -410,7 +410,7 @@ def rand_vdf_proof() -> VDFProof:
410410

411411
# make sure we get the same result when we hit the database
412412
# itself (and not just the block cache)
413-
block_store.rollback_cache_block(block.header_hash)
413+
consensus_store.rollback_cache_block(block.header_hash)
414414
b = await block_store.get_full_block(block.header_hash)
415415
assert b is not None
416416
assert b.challenge_chain_ip_proof == proof

chia/consensus/blockchain.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -413,9 +413,9 @@ async def add_block(
413413

414414
try:
415415
# Always add the block to the database
416-
async with self.consensus_store.transaction():
416+
async with self.consensus_store as writer:
417417
# Perform the DB operations to update the state, and rollback if something goes wrong
418-
await self.consensus_store.add_full_block(header_hash, block, block_record)
418+
await writer.add_full_block(header_hash, block, block_record)
419419
records, state_change_summary = await self._reconsider_peak(block_record, genesis, fork_info)
420420

421421
# Then update the memory cache. It is important that this is not cancelled and does not throw

chia/consensus/consensus_store_protocol.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
from collections.abc import Collection
4-
from typing import Optional, Protocol
4+
from typing import AsyncIterator, Optional, Protocol
55

66
from chia_rs import BlockRecord, FullBlock, SubEpochChallengeSegment, SubEpochSummary
77
from chia_rs.sized_bytes import bytes32
@@ -22,7 +22,6 @@ class ConsensusStoreWriteProtocol(Protocol):
2222

2323
# Block store writes
2424
async def add_full_block(self, header_hash: bytes32, block: FullBlock, block_record: BlockRecord) -> None: ...
25-
def rollback_cache_block(self, header_hash: bytes32) -> None: ...
2625
async def rollback(self, height: int) -> None: ...
2726
async def set_in_chain(self, header_hashes: list[tuple[bytes32]]) -> None: ...
2827
async def set_peak(self, header_hash: bytes32) -> None: ...
@@ -80,6 +79,7 @@ async def get_sub_epoch_challenge_segments(
8079
) -> Optional[list[SubEpochChallengeSegment]]: ...
8180
async def get_generator(self, header_hash: bytes32) -> Optional[bytes]: ...
8281
async def get_generators_at(self, heights: set[uint32]) -> dict[uint32, bytes]: ...
82+
def get_block_heights_in_main_chain(self) -> AsyncIterator[int]: ...
8383

8484
# Coin store reads
8585
async def get_coin_records(self, names: Collection[bytes32]) -> list[CoinRecord]: ...
@@ -94,3 +94,4 @@ def get_hash(self, height: uint32) -> bytes32: ...
9494
def rollback_height_map(self, height: uint32) -> None: ...
9595
def update_height_map(self, height: uint32, block_hash: bytes32, ses: Optional[SubEpochSummary]) -> None: ...
9696
async def maybe_flush_height_map(self) -> None: ...
97+
def rollback_cache_block(self, header_hash: bytes32) -> None: ...

chia/full_node/consensus_store_sqlite3.py

Lines changed: 43 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import dataclasses
44
from collections.abc import Collection
55
from contextlib import AbstractAsyncContextManager
6-
from typing import Any, Optional, TYPE_CHECKING
6+
from typing import Any, AsyncIterator, Optional, TYPE_CHECKING
77

88
from chia_rs import BlockRecord, FullBlock, SubEpochChallengeSegment, SubEpochSummary
99
from chia_rs.sized_bytes import bytes32
@@ -24,9 +24,6 @@ def __init__(self, block_store: BlockStore, coin_store: CoinStore):
2424
async def add_full_block(self, header_hash: bytes32, block: FullBlock, block_record: BlockRecord) -> None:
2525
await self._block_store.add_full_block(header_hash, block, block_record)
2626

27-
def rollback_cache_block(self, header_hash: bytes32) -> None:
28-
self._block_store.rollback_cache_block(header_hash)
29-
3027
async def rollback(self, height: int) -> None:
3128
await self._block_store.rollback(height)
3229

@@ -54,6 +51,7 @@ async def new_block(
5451
) -> None:
5552
await self._coin_store.new_block(height, timestamp, included_reward_coins, tx_additions, tx_removals)
5653

54+
5755
@dataclasses.dataclass
5856
class ConsensusStoreSQLite3:
5957
"""
@@ -64,9 +62,10 @@ class ConsensusStoreSQLite3:
6462
coin_store: CoinStore
6563
height_map: BlockHeightMap
6664

67-
# Writer context and writer facade for transactional writes
65+
# Writer context and writer facade for transactional writes (re-entrant via depth counter)
6866
_writer_ctx: Optional[AbstractAsyncContextManager[Any]] = None
6967
_writer: Optional[Any] = None
68+
_txn_depth: int = 0
7069

7170
@classmethod
7271
async def create(
@@ -88,22 +87,34 @@ async def create(
8887

8988
# Async context manager yielding a writer for atomic writes
9089
async def __aenter__(self):
91-
# Begin a transaction via the block_store. CoinStore shares the same DB.
92-
self._writer_ctx = self.block_store.transaction()
93-
await self._writer_ctx.__aenter__()
94-
# Create and return a writer facade bound to this transaction
95-
self._writer = ConsensusStoreSQLite3Writer(self.block_store, self.coin_store)
96-
return self._writer
90+
# Re-entrant async context manager:
91+
# Begin a transaction only at the outermost level. CoinStore shares the same DB.
92+
if self._txn_depth == 0:
93+
self._writer_ctx = self.block_store.transaction()
94+
await self._writer_ctx.__aenter__()
95+
# Create writer facade bound to this transaction
96+
self._writer = ConsensusStoreSQLite3Writer(self.block_store, self.coin_store)
97+
self._txn_depth += 1
98+
return self._writer # Return the same writer for nested contexts
9799

98100
async def __aexit__(self, exc_type, exc, tb):
99-
# Commit on success, rollback on exception handled by transaction manager
100101
try:
101-
if self._writer_ctx is not None:
102-
return await self._writer_ctx.__aexit__(exc_type, exc, tb)
103-
return None
102+
# Check if we're at the outermost level before decrementing
103+
if self._txn_depth == 1:
104+
# This is the outermost context, handle transaction exit
105+
if self._writer_ctx is not None:
106+
return await self._writer_ctx.__aexit__(exc_type, exc, tb)
107+
return None
108+
else:
109+
# This is a nested context, just return None (don't suppress exceptions)
110+
return None
104111
finally:
105-
self._writer_ctx = None
106-
self._writer = None
112+
# Always decrement depth and clean up if we're at the outermost level
113+
if self._txn_depth > 0:
114+
self._txn_depth -= 1
115+
if self._txn_depth == 0:
116+
self._writer_ctx = None
117+
self._writer = None
107118

108119
# Block store methods
109120

@@ -154,6 +165,16 @@ async def get_coins_added_at_height(self, height: uint32) -> list[CoinRecord]:
154165
async def get_coins_removed_at_height(self, height: uint32) -> list[CoinRecord]:
155166
return await self.coin_store.get_coins_removed_at_height(height)
156167

168+
def get_block_heights_in_main_chain(self) -> AsyncIterator[int]:
169+
async def gen():
170+
async with self.block_store.transaction() as conn:
171+
async with conn.execute("SELECT height, in_main_chain FROM full_blocks") as cursor:
172+
async for row in cursor:
173+
if row[1]:
174+
yield row[0]
175+
176+
return gen()
177+
157178
# Height map methods
158179
def get_ses_heights(self) -> list[uint32]:
159180
return self.height_map.get_ses_heights()
@@ -179,9 +200,12 @@ async def maybe_flush_height_map(self) -> None:
179200
# BlockHeightMap.maybe_flush is asynchronous
180201
await self.height_map.maybe_flush()
181202

182-
if TYPE_CHECKING:
203+
def rollback_cache_block(self, header_hash: bytes32) -> None:
204+
self.block_store.rollback_cache_block(header_hash)
183205

184-
from typing import cast
206+
207+
if TYPE_CHECKING:
208+
from typing import cast
185209
from chia.consensus.consensus_store_protocol import ConsensusStoreProtocol
186210

187211
def _protocol_check(o: ConsensusStoreProtocol) -> None: ...

0 commit comments

Comments
 (0)