3
3
import dataclasses
4
4
from collections .abc import Collection
5
5
from contextlib import AbstractAsyncContextManager
6
- from typing import Any , Optional , TYPE_CHECKING
6
+ from typing import Any , AsyncIterator , Optional , TYPE_CHECKING
7
7
8
8
from chia_rs import BlockRecord , FullBlock , SubEpochChallengeSegment , SubEpochSummary
9
9
from chia_rs .sized_bytes import bytes32
@@ -24,9 +24,6 @@ def __init__(self, block_store: BlockStore, coin_store: CoinStore):
24
24
async def add_full_block (self , header_hash : bytes32 , block : FullBlock , block_record : BlockRecord ) -> None :
25
25
await self ._block_store .add_full_block (header_hash , block , block_record )
26
26
27
- def rollback_cache_block (self , header_hash : bytes32 ) -> None :
28
- self ._block_store .rollback_cache_block (header_hash )
29
-
30
27
async def rollback (self , height : int ) -> None :
31
28
await self ._block_store .rollback (height )
32
29
@@ -54,6 +51,7 @@ async def new_block(
54
51
) -> None :
55
52
await self ._coin_store .new_block (height , timestamp , included_reward_coins , tx_additions , tx_removals )
56
53
54
+
57
55
@dataclasses .dataclass
58
56
class ConsensusStoreSQLite3 :
59
57
"""
@@ -64,9 +62,10 @@ class ConsensusStoreSQLite3:
64
62
coin_store : CoinStore
65
63
height_map : BlockHeightMap
66
64
67
- # Writer context and writer facade for transactional writes
65
+ # Writer context and writer facade for transactional writes (re-entrant via depth counter)
68
66
_writer_ctx : Optional [AbstractAsyncContextManager [Any ]] = None
69
67
_writer : Optional [Any ] = None
68
+ _txn_depth : int = 0
70
69
71
70
@classmethod
72
71
async def create (
@@ -88,22 +87,34 @@ async def create(
88
87
89
88
# Async context manager yielding a writer for atomic writes
90
89
async def __aenter__ (self ):
91
- # Begin a transaction via the block_store. CoinStore shares the same DB.
92
- self ._writer_ctx = self .block_store .transaction ()
93
- await self ._writer_ctx .__aenter__ ()
94
- # Create and return a writer facade bound to this transaction
95
- self ._writer = ConsensusStoreSQLite3Writer (self .block_store , self .coin_store )
96
- return self ._writer
90
+ # Re-entrant async context manager:
91
+ # Begin a transaction only at the outermost level. CoinStore shares the same DB.
92
+ if self ._txn_depth == 0 :
93
+ self ._writer_ctx = self .block_store .transaction ()
94
+ await self ._writer_ctx .__aenter__ ()
95
+ # Create writer facade bound to this transaction
96
+ self ._writer = ConsensusStoreSQLite3Writer (self .block_store , self .coin_store )
97
+ self ._txn_depth += 1
98
+ return self ._writer # Return the same writer for nested contexts
97
99
98
100
async def __aexit__ (self , exc_type , exc , tb ):
99
- # Commit on success, rollback on exception handled by transaction manager
100
101
try :
101
- if self ._writer_ctx is not None :
102
- return await self ._writer_ctx .__aexit__ (exc_type , exc , tb )
103
- return None
102
+ # Check if we're at the outermost level before decrementing
103
+ if self ._txn_depth == 1 :
104
+ # This is the outermost context, handle transaction exit
105
+ if self ._writer_ctx is not None :
106
+ return await self ._writer_ctx .__aexit__ (exc_type , exc , tb )
107
+ return None
108
+ else :
109
+ # This is a nested context, just return None (don't suppress exceptions)
110
+ return None
104
111
finally :
105
- self ._writer_ctx = None
106
- self ._writer = None
112
+ # Always decrement depth and clean up if we're at the outermost level
113
+ if self ._txn_depth > 0 :
114
+ self ._txn_depth -= 1
115
+ if self ._txn_depth == 0 :
116
+ self ._writer_ctx = None
117
+ self ._writer = None
107
118
108
119
# Block store methods
109
120
@@ -154,6 +165,16 @@ async def get_coins_added_at_height(self, height: uint32) -> list[CoinRecord]:
154
165
async def get_coins_removed_at_height (self , height : uint32 ) -> list [CoinRecord ]:
155
166
return await self .coin_store .get_coins_removed_at_height (height )
156
167
168
+ def get_block_heights_in_main_chain (self ) -> AsyncIterator [int ]:
169
+ async def gen ():
170
+ async with self .block_store .transaction () as conn :
171
+ async with conn .execute ("SELECT height, in_main_chain FROM full_blocks" ) as cursor :
172
+ async for row in cursor :
173
+ if row [1 ]:
174
+ yield row [0 ]
175
+
176
+ return gen ()
177
+
157
178
# Height map methods
158
179
def get_ses_heights (self ) -> list [uint32 ]:
159
180
return self .height_map .get_ses_heights ()
@@ -179,9 +200,12 @@ async def maybe_flush_height_map(self) -> None:
179
200
# BlockHeightMap.maybe_flush is asynchronous
180
201
await self .height_map .maybe_flush ()
181
202
182
- if TYPE_CHECKING :
203
+ def rollback_cache_block (self , header_hash : bytes32 ) -> None :
204
+ self .block_store .rollback_cache_block (header_hash )
183
205
184
- from typing import cast
206
+
207
+ if TYPE_CHECKING :
208
+ from typing import cast
185
209
from chia .consensus .consensus_store_protocol import ConsensusStoreProtocol
186
210
187
211
def _protocol_check (o : ConsensusStoreProtocol ) -> None : ...
0 commit comments