Skip to content

Commit ae77f25

Browse files
committed
mypy
1 parent 0094c6c commit ae77f25

File tree

3 files changed

+64
-36
lines changed

3 files changed

+64
-36
lines changed
Lines changed: 9 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
from contextlib import AbstractAsyncContextManager
34
from typing import Optional, Protocol
45

56
from chia_rs import BlockRecord, FullBlock, SubEpochChallengeSegment
@@ -11,62 +12,37 @@
1112

1213
class BlockStoreProtocol(Protocol):
1314
async def add_full_block(self, header_hash: bytes32, block: FullBlock, block_record: BlockRecord) -> None: ...
14-
15+
def get_block_from_cache(self, header_hash: bytes32) -> Optional[FullBlock]: ...
1516
async def get_full_block(self, header_hash: bytes32) -> Optional[FullBlock]: ...
16-
1717
async def get_full_block_bytes(self, header_hash: bytes32) -> Optional[bytes]: ...
18-
1918
async def get_full_blocks_at(self, heights: list[uint32]) -> list[FullBlock]: ...
20-
2119
async def get_block_info(self, header_hash: bytes32) -> Optional[GeneratorBlockInfo]: ...
22-
2320
async def get_generator(self, header_hash: bytes32) -> Optional[bytes]: ...
24-
2521
async def get_generators_at(self, heights: set[uint32]) -> dict[uint32, bytes]: ...
26-
2722
async def get_block_record(self, header_hash: bytes32) -> Optional[BlockRecord]: ...
28-
29-
async def get_block_records_in_range(
30-
self,
31-
start: int,
32-
stop: int,
33-
) -> dict[bytes32, BlockRecord]: ...
34-
23+
async def get_block_records_in_range(self, start: int, stop: int) -> dict[bytes32, BlockRecord]: ...
3524
async def get_block_records_by_hash(self, header_hashes: list[bytes32]) -> list[BlockRecord]: ...
36-
3725
async def get_block_bytes_by_hash(self, header_hashes: list[bytes32]) -> list[bytes]: ...
38-
3926
async def get_blocks_by_hash(self, header_hashes: list[bytes32]) -> list[FullBlock]: ...
40-
4127
async def get_peak(self) -> Optional[tuple[bytes32, uint32]]: ...
42-
43-
async def get_block_bytes_in_range(
44-
self,
45-
start: int,
46-
stop: int,
47-
) -> list[bytes]: ...
48-
28+
async def get_block_bytes_in_range(self, start: int, stop: int) -> list[bytes]: ...
4929
async def get_random_not_compactified(self, number: int) -> list[int]: ...
50-
5130
async def persist_sub_epoch_challenge_segments(
5231
self, ses_block_hash: bytes32, segments: list[SubEpochChallengeSegment]
5332
) -> None: ...
54-
5533
async def get_sub_epoch_challenge_segments(
5634
self,
5735
ses_block_hash: bytes32,
5836
) -> Optional[list[SubEpochChallengeSegment]]: ...
59-
6037
async def rollback(self, height: int) -> None: ...
61-
38+
def rollback_cache_block(self, header_hash: bytes32) -> None: ...
6239
async def set_in_chain(self, header_hashes: list[tuple[bytes32]]) -> None: ...
63-
6440
async def set_peak(self, header_hash: bytes32) -> None: ...
65-
6641
async def is_fully_compactified(self, header_hash: bytes32) -> Optional[bool]: ...
67-
6842
async def replace_proof(self, header_hash: bytes32, block: FullBlock) -> None: ...
69-
7043
async def count_compactified_blocks(self) -> int: ...
71-
7244
async def count_uncompactified_blocks(self) -> int: ...
45+
async def get_block_records_close_to_peak(self, blocks_n: int) -> tuple[dict[bytes32, BlockRecord], Optional[bytes32]]: ...
46+
def get_host_parameter_limit(self) -> int: ...
47+
async def get_prev_hash(self, header_hash: bytes32) -> bytes32: ...
48+
def transaction(self) -> AbstractAsyncContextManager[None]: ...

chia/consensus/blockchain.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -421,7 +421,7 @@ async def add_block(
421421

422422
try:
423423
# Always add the block to the database
424-
async with self.block_store.db_wrapper.writer():
424+
async with self.block_store.transaction():
425425
# Perform the DB operations to update the state, and rollback if something goes wrong
426426
await self.block_store.add_full_block(header_hash, block, block_record)
427427
records, state_change_summary = await self._reconsider_peak(block_record, genesis, fork_info)
@@ -883,7 +883,7 @@ async def get_header_blocks_in_range(
883883

884884
blocks: list[FullBlock] = []
885885
for hash in hashes.copy():
886-
block = self.block_store.block_cache.get(hash)
886+
block = self.block_store.get_block_from_cache(hash)
887887
if block is not None:
888888
blocks.append(block)
889889
hashes.remove(hash)
@@ -932,7 +932,7 @@ async def get_block_records_at(self, heights: list[uint32], batch_size: int = 90
932932
"""
933933
records: list[BlockRecord] = []
934934
hashes: list[bytes32] = []
935-
assert batch_size < self.block_store.db_wrapper.host_parameter_limit
935+
assert batch_size < self.block_store.get_host_parameter_limit()
936936
for height in heights:
937937
header_hash: Optional[bytes32] = self.height_to_hash(height)
938938
if header_hash is None:

chia/full_node/block_store.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,39 @@ async def get_sub_epoch_challenge_segments(
189189
return challenge_segments
190190
return None
191191

192+
def get_host_parameter_limit(self) -> int:
193+
return self.db_wrapper.host_parameter_limit
194+
195+
def transaction(self) -> AbstractAsyncContextManager[None]:
196+
return self.db_wrapper.writer()
197+
198+
def get_block_from_cache(self, header_hash: bytes32) -> Optional[FullBlock]:
199+
return self.block_cache.get(header_hash)
200+
201+
async def get_block_records_close_to_peak(
202+
self, blocks_n: int
203+
) -> tuple[dict[bytes32, BlockRecord], Optional[bytes32]]:
204+
"""
205+
Returns a dictionary with all blocks that have height >= peak height - blocks_n, as well as the
206+
peak header hash. Only blocks that are part of the main chain/current peak are included.
207+
"""
208+
209+
peak = await self.get_peak()
210+
if peak is None:
211+
return {}, None
212+
213+
ret: dict[bytes32, BlockRecord] = {}
214+
async with self.db_wrapper.reader_no_transaction() as conn:
215+
async with conn.execute(
216+
"SELECT header_hash, block_record FROM full_blocks WHERE height >= ? AND in_main_chain=1",
217+
(peak[1] - blocks_n,),
218+
) as cursor:
219+
for row in await cursor.fetchall():
220+
header_hash = bytes32(row[0])
221+
ret[header_hash] = BlockRecord.from_bytes(row[1])
222+
223+
return ret, peak[0]
224+
192225
def rollback_cache_block(self, header_hash: bytes32) -> None:
193226
try:
194227
self.block_cache.remove(header_hash)
@@ -197,6 +230,25 @@ def rollback_cache_block(self, header_hash: bytes32) -> None:
197230
# block to the cache yet
198231
pass
199232

233+
async def get_prev_hash(self, header_hash: bytes32) -> bytes32:
234+
"""
235+
Returns the header hash preceeding the input header hash.
236+
Throws an exception if the block is not present
237+
"""
238+
cached = self.block_cache.get(header_hash)
239+
if cached is not None:
240+
return cached.prev_header_hash
241+
242+
async with self.db_wrapper.reader_no_transaction() as conn:
243+
async with conn.execute(
244+
"SELECT prev_hash FROM full_blocks WHERE header_hash=?",
245+
(header_hash,),
246+
) as cursor:
247+
row = await cursor.fetchone()
248+
if row is None:
249+
raise KeyError("missing block in chain")
250+
return bytes32(row[0])
251+
200252
async def get_full_block(self, header_hash: bytes32) -> Optional[FullBlock]:
201253
cached: Optional[FullBlock] = self.block_cache.get(header_hash)
202254
if cached is not None:

0 commit comments

Comments
 (0)