From 77413376f4f442eecde2621915b6bb2554750e52 Mon Sep 17 00:00:00 2001 From: almogdepaz Date: Wed, 16 Jul 2025 15:34:16 +0300 Subject: [PATCH 01/11] prover protocol and v2Prover --- chia/_tests/plot_sync/test_plot_sync.py | 8 +- chia/_tests/plot_sync/test_sync_simulated.py | 37 ++-- chia/_tests/plotting/test_plot_manager.py | 6 +- chia/plot_sync/sender.py | 2 +- chia/plotting/cache.py | 15 +- chia/plotting/check_plots.py | 4 +- chia/plotting/manager.py | 7 +- chia/plotting/prover.py | 184 +++++++++++++++++++ chia/plotting/util.py | 20 +- 9 files changed, 241 insertions(+), 42 deletions(-) create mode 100644 chia/plotting/prover.py diff --git a/chia/_tests/plot_sync/test_plot_sync.py b/chia/_tests/plot_sync/test_plot_sync.py index 9e76d26b7a5a..e889286ea300 100644 --- a/chia/_tests/plot_sync/test_plot_sync.py +++ b/chia/_tests/plot_sync/test_plot_sync.py @@ -65,7 +65,7 @@ class ExpectedResult: def add_valid(self, list_plots: list[MockPlotInfo]) -> None: def create_mock_plot(info: MockPlotInfo) -> Plot: return Plot( - info.prover.get_filename(), + str(info.prover.get_filename()), uint8(0), bytes32.zeros, None, @@ -77,7 +77,7 @@ def create_mock_plot(info: MockPlotInfo) -> Plot: ) self.valid_count += len(list_plots) - self.valid_delta.additions.update({x.prover.get_filename(): create_mock_plot(x) for x in list_plots}) + self.valid_delta.additions.update({str(x.prover.get_filename()): create_mock_plot(x) for x in list_plots}) def remove_valid(self, list_paths: list[Path]) -> None: self.valid_count -= len(list_paths) @@ -193,7 +193,7 @@ async def plot_sync_callback(self, peer_id: bytes32, delta: Optional[Delta]) -> assert path in delta.valid.additions plot = harvester.plot_manager.plots.get(Path(path), None) assert plot is not None - assert plot.prover.get_filename() == delta.valid.additions[path].filename + assert plot.prover.get_filename_str() == delta.valid.additions[path].filename assert plot.prover.get_size() == delta.valid.additions[path].size assert plot.prover.get_id() == delta.valid.additions[path].plot_id assert plot.prover.get_compression_level() == delta.valid.additions[path].compression_level @@ -254,7 +254,7 @@ async def run_sync_test(self) -> None: assert expected.duplicates_delta.empty() for path, plot_info in plot_manager.plots.items(): assert str(path) in receiver.plots() - assert plot_info.prover.get_filename() == receiver.plots()[str(path)].filename + assert plot_info.prover.get_filename_str() == receiver.plots()[str(path)].filename assert plot_info.prover.get_size() == receiver.plots()[str(path)].size assert plot_info.prover.get_id() == receiver.plots()[str(path)].plot_id assert plot_info.prover.get_compression_level() == receiver.plots()[str(path)].compression_level diff --git a/chia/_tests/plot_sync/test_sync_simulated.py b/chia/_tests/plot_sync/test_sync_simulated.py index 1be1b8a51591..c1228e8a3653 100644 --- a/chia/_tests/plot_sync/test_sync_simulated.py +++ b/chia/_tests/plot_sync/test_sync_simulated.py @@ -25,6 +25,7 @@ from chia.plot_sync.sender import Sender from chia.plot_sync.util import Constants from chia.plotting.manager import PlotManager +from chia.plotting.prover import V1Prover from chia.plotting.util import PlotInfo from chia.protocols.harvester_protocol import PlotSyncError, PlotSyncResponse from chia.protocols.outbound_message import make_msg @@ -79,7 +80,7 @@ async def run( removed_paths: list[Path] = [p.prover.get_filename() for p in removed] if removed is not None else [] invalid_dict: dict[Path, int] = {p.prover.get_filename(): 0 for p in self.invalid} keys_missing_set: set[Path] = {p.prover.get_filename() for p in self.keys_missing} - duplicates_set: set[str] = {p.prover.get_filename() for p in self.duplicates} + duplicates_set: set[str] = {p.prover.get_filename_str() for p in self.duplicates} # Inject invalid plots into `PlotManager` of the harvester so that the callback calls below can use them # to sync them to the farmer. @@ -131,30 +132,30 @@ def validate_plot_sync(self) -> None: assert len(self.invalid) == len(self.plot_sync_receiver.invalid()) assert len(self.keys_missing) == len(self.plot_sync_receiver.keys_missing()) for _, plot_info in self.plots.items(): - assert plot_info.prover.get_filename() not in self.plot_sync_receiver.invalid() - assert plot_info.prover.get_filename() not in self.plot_sync_receiver.keys_missing() - assert plot_info.prover.get_filename() in self.plot_sync_receiver.plots() - synced_plot = self.plot_sync_receiver.plots()[plot_info.prover.get_filename()] - assert plot_info.prover.get_filename() == synced_plot.filename + assert plot_info.prover.get_filename_str() not in self.plot_sync_receiver.invalid() + assert plot_info.prover.get_filename_str() not in self.plot_sync_receiver.keys_missing() + assert plot_info.prover.get_filename_str() in self.plot_sync_receiver.plots() + synced_plot = self.plot_sync_receiver.plots()[plot_info.prover.get_filename_str()] + assert plot_info.prover.get_filename_str() == synced_plot.filename assert plot_info.pool_public_key == synced_plot.pool_public_key assert plot_info.pool_contract_puzzle_hash == synced_plot.pool_contract_puzzle_hash assert plot_info.plot_public_key == synced_plot.plot_public_key assert plot_info.file_size == synced_plot.file_size assert uint64(int(plot_info.time_modified)) == synced_plot.time_modified for plot_info in self.invalid: - assert plot_info.prover.get_filename() not in self.plot_sync_receiver.plots() - assert plot_info.prover.get_filename() in self.plot_sync_receiver.invalid() - assert plot_info.prover.get_filename() not in self.plot_sync_receiver.keys_missing() - assert plot_info.prover.get_filename() not in self.plot_sync_receiver.duplicates() + assert plot_info.prover.get_filename_str() not in self.plot_sync_receiver.plots() + assert plot_info.prover.get_filename_str() in self.plot_sync_receiver.invalid() + assert plot_info.prover.get_filename_str() not in self.plot_sync_receiver.keys_missing() + assert plot_info.prover.get_filename_str() not in self.plot_sync_receiver.duplicates() for plot_info in self.keys_missing: - assert plot_info.prover.get_filename() not in self.plot_sync_receiver.plots() - assert plot_info.prover.get_filename() not in self.plot_sync_receiver.invalid() - assert plot_info.prover.get_filename() in self.plot_sync_receiver.keys_missing() - assert plot_info.prover.get_filename() not in self.plot_sync_receiver.duplicates() + assert plot_info.prover.get_filename_str() not in self.plot_sync_receiver.plots() + assert plot_info.prover.get_filename_str() not in self.plot_sync_receiver.invalid() + assert plot_info.prover.get_filename_str() in self.plot_sync_receiver.keys_missing() + assert plot_info.prover.get_filename_str() not in self.plot_sync_receiver.duplicates() for plot_info in self.duplicates: - assert plot_info.prover.get_filename() not in self.plot_sync_receiver.invalid() - assert plot_info.prover.get_filename() not in self.plot_sync_receiver.keys_missing() - assert plot_info.prover.get_filename() in self.plot_sync_receiver.duplicates() + assert plot_info.prover.get_filename_str() not in self.plot_sync_receiver.invalid() + assert plot_info.prover.get_filename_str() not in self.plot_sync_receiver.keys_missing() + assert plot_info.prover.get_filename_str() in self.plot_sync_receiver.duplicates() @dataclass @@ -284,7 +285,7 @@ def get_compression_level(self) -> uint8: return [ PlotInfo( - prover=DiskProver(f"{x}", bytes32.random(seeded_random), 25 + x % 26), + prover=V1Prover(DiskProver(f"{x}", bytes32.random(seeded_random), 25 + x % 26)), pool_public_key=None, pool_contract_puzzle_hash=None, plot_public_key=G1Element(), diff --git a/chia/_tests/plotting/test_plot_manager.py b/chia/_tests/plotting/test_plot_manager.py index 16fcb50bf66e..beb8a192eda6 100644 --- a/chia/_tests/plotting/test_plot_manager.py +++ b/chia/_tests/plotting/test_plot_manager.py @@ -42,8 +42,8 @@ class MockDiskProver: filename: str - def get_filename(self) -> str: - return self.filename + def get_filename(self) -> Path: + return Path(self.filename) @dataclass @@ -614,7 +614,7 @@ def assert_cache(expected: list[MockPlotInfo]) -> None: # Write the modified cache entries to the file cache_path.write_bytes(bytes(VersionedBlob(uint16(CURRENT_VERSION), bytes(cache_data)))) # And now test that plots in invalid_entries are not longer loaded - assert_cache([plot_info for plot_info in plot_infos if plot_info.prover.get_filename() not in invalid_entries]) + assert_cache([plot_info for plot_info in plot_infos if str(plot_info.prover.get_filename()) not in invalid_entries]) @pytest.mark.anyio diff --git a/chia/plot_sync/sender.py b/chia/plot_sync/sender.py index 74be15cae2cf..ec1c51c3937b 100644 --- a/chia/plot_sync/sender.py +++ b/chia/plot_sync/sender.py @@ -39,7 +39,7 @@ def _convert_plot_info_list(plot_infos: list[PlotInfo]) -> list[Plot]: for plot_info in plot_infos: converted.append( Plot( - filename=plot_info.prover.get_filename(), + filename=plot_info.prover.get_filename_str(), size=plot_info.prover.get_size(), plot_id=plot_info.prover.get_id(), pool_public_key=plot_info.pool_public_key, diff --git a/chia/plotting/cache.py b/chia/plotting/cache.py index c70504d4c31f..4d5a5aef8905 100644 --- a/chia/plotting/cache.py +++ b/chia/plotting/cache.py @@ -7,13 +7,16 @@ from dataclasses import dataclass, field from math import ceil from pathlib import Path -from typing import Optional +from typing import TYPE_CHECKING, Optional + +if TYPE_CHECKING: + from chia.plotting.prover import ProverProtocol from chia_rs import G1Element from chia_rs.sized_bytes import bytes32 from chia_rs.sized_ints import uint16, uint64 -from chiapos import DiskProver +from chia.plotting.prover import get_prover_from_bytes from chia.plotting.util import parse_plot_info from chia.types.blockchain_format.proof_of_space import generate_plot_public_key from chia.util.streamable import Streamable, VersionedBlob, streamable @@ -43,7 +46,7 @@ class CacheDataV1(Streamable): @dataclass class CacheEntry: - prover: DiskProver + prover: ProverProtocol farmer_public_key: G1Element pool_public_key: Optional[G1Element] pool_contract_puzzle_hash: Optional[bytes32] @@ -51,7 +54,8 @@ class CacheEntry: last_use: float @classmethod - def from_disk_prover(cls, prover: DiskProver) -> CacheEntry: + def from_prover(cls, prover: ProverProtocol) -> CacheEntry: + """Create CacheEntry from any prover implementation""" ( pool_public_key_or_puzzle_hash, farmer_public_key, @@ -149,8 +153,9 @@ def load(self) -> None: 39: 44367, } for path, cache_entry in cache_data.entries: + prover: ProverProtocol = get_prover_from_bytes(path, cache_entry.prover_data) new_entry = CacheEntry( - DiskProver.from_bytes(cache_entry.prover_data), + prover, cache_entry.farmer_public_key, cache_entry.pool_public_key, cache_entry.pool_contract_puzzle_hash, diff --git a/chia/plotting/check_plots.py b/chia/plotting/check_plots.py index fc4f1197bd5f..0fe7c2567fa5 100644 --- a/chia/plotting/check_plots.py +++ b/chia/plotting/check_plots.py @@ -9,7 +9,7 @@ from typing import Optional from chia_rs import G1Element -from chia_rs.sized_ints import uint32 +from chia_rs.sized_ints import uint8, uint32 from chiapos import Verifier from chia.plotting.manager import PlotManager @@ -133,7 +133,7 @@ def check_plots( log.info("") log.info("") log.info(f"Starting to test each plot with {num} challenges each\n") - total_good_plots: Counter[str] = Counter() + total_good_plots: Counter[uint8] = Counter() total_size = 0 bad_plots_list: list[Path] = [] diff --git a/chia/plotting/manager.py b/chia/plotting/manager.py index f2a9ab8565e5..5cd9acb6eb08 100644 --- a/chia/plotting/manager.py +++ b/chia/plotting/manager.py @@ -9,10 +9,11 @@ from typing import Any, Callable, Optional from chia_rs import G1Element -from chiapos import DiskProver, decompressor_context_queue +from chiapos import decompressor_context_queue from chia.consensus.pos_quality import UI_ACTUAL_SPACE_CONSTANT_FACTOR, _expected_plot_size from chia.plotting.cache import Cache, CacheEntry +from chia.plotting.prover import get_prover_from_file from chia.plotting.util import ( HarvestingMode, PlotInfo, @@ -323,7 +324,7 @@ def process_file(file_path: Path) -> Optional[PlotInfo]: cache_entry = self.cache.get(file_path) cache_hit = cache_entry is not None if not cache_hit: - prover = DiskProver(str(file_path)) + prover = get_prover_from_file(str(file_path)) log.debug(f"process_file {file_path!s}") @@ -343,7 +344,7 @@ def process_file(file_path: Path) -> Optional[PlotInfo]: ) return None - cache_entry = CacheEntry.from_disk_prover(prover) + cache_entry = CacheEntry.from_prover(prover) self.cache.update(file_path, cache_entry) assert cache_entry is not None diff --git a/chia/plotting/prover.py b/chia/plotting/prover.py new file mode 100644 index 000000000000..df70248b8dd8 --- /dev/null +++ b/chia/plotting/prover.py @@ -0,0 +1,184 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from pathlib import Path +from typing import TYPE_CHECKING + +from chia_rs.sized_bytes import bytes32 +from chia_rs.sized_ints import uint8 +from chiapos import DiskProver + +if TYPE_CHECKING: + from chiapos import DiskProver + + +class ProverProtocol(ABC): + """Abstract protocol for all prover implementations (V1 and V2)""" + + @abstractmethod + def get_filename(self) -> Path: + """Returns the filename of the plot""" + + @abstractmethod + def get_filename_str(self) -> str: + """Returns the filename of the plot""" + + @abstractmethod + def get_size(self) -> uint8: + """Returns the k size of the plot""" + + @abstractmethod + def get_memo(self) -> bytes: + """Returns the memo containing keys and other metadata""" + + @abstractmethod + def get_compression_level(self) -> uint8: + """Returns the compression level (0 for uncompressed)""" + + @abstractmethod + def get_version(self) -> int: + """Returns the plot version (1 for V1, 2 for V2)""" + + @abstractmethod + def __bytes__(self) -> bytes: + """Returns the prover serialized as bytes for caching""" + + @abstractmethod + def get_id(self) -> bytes32: + """Returns the plot ID""" + + @abstractmethod + def get_qualities_for_challenge(self, challenge: bytes32) -> list[bytes32]: + """Returns the qualities for a given challenge""" + + @abstractmethod + def get_full_proof(self, challenge: bytes, index: int, parallel_read: bool = True) -> bytes: + """Returns the full proof for a given challenge and index""" + + @classmethod + @abstractmethod + def from_bytes(cls, data: bytes) -> ProverProtocol: + """Create a prover from serialized bytes""" + + +class V2Prover(ProverProtocol): + """V2 Plot Prover implementation - currently stubbed""" + + def __init__(self, filename: str): + self._filename = filename + # TODO: Implement V2 plot file parsing and validation + + def get_filename(self) -> Path: + return Path(self._filename) + + def get_filename_str(self) -> str: + return str(self._filename) + + def get_size(self) -> uint8: + # TODO: Extract k size from V2 plot file + return uint8(32) # Stub value + + def get_memo(self) -> bytes: + # TODO: Extract memo from V2 plot file + return b"" # Stub value + + def get_compression_level(self) -> uint8: + # TODO: Extract compression level from V2 plot file + return uint8(0) # Stub value + + def get_version(self) -> int: + return 2 + + def __bytes__(self) -> bytes: + # TODO: Implement proper V2 prover serialization for caching + # For now, just serialize the filename as a placeholder + return self._filename.encode("utf-8") + + def get_id(self) -> bytes32: + # TODO: Extract plot ID from V2 plot file + return bytes32(b"") # Stub value + + def get_qualities_for_challenge(self, challenge: bytes) -> list[bytes32]: + # TODO: Implement V2 plot quality lookup + return [] # Stub value + + def get_full_proof(self, challenge: bytes, index: int, parallel_read: bool = True) -> bytes: + # TODO: Implement V2 plot proof generation + return b"" # Stub value + + @classmethod + def from_bytes(cls, data: bytes) -> V2Prover: + # TODO: Implement proper V2 prover deserialization from cache + # For now, just deserialize the filename + filename = data.decode("utf-8") + return cls(filename) + + +class V1Prover(ProverProtocol): + """Wrapper for existing DiskProver to implement ProverProtocol""" + + def __init__(self, disk_prover: DiskProver) -> None: + self._disk_prover = disk_prover + + def get_filename(self) -> Path: + return Path(self._disk_prover.get_filename()) + + def get_filename_str(self) -> str: + return str(self._disk_prover.get_filename()) + + def get_size(self) -> uint8: + return uint8(self._disk_prover.get_size()) + + def get_memo(self) -> bytes: + return bytes(self._disk_prover.get_memo()) + + def get_compression_level(self) -> uint8: + return uint8(self._disk_prover.get_compression_level()) + + def get_version(self) -> int: + return 1 + + def __bytes__(self) -> bytes: + return bytes(self._disk_prover) + + def get_id(self) -> bytes32: + return bytes32(self._disk_prover.get_id()) + + def get_qualities_for_challenge(self, challenge: bytes32) -> list[bytes32]: + return [bytes32(quality) for quality in self._disk_prover.get_qualities_for_challenge(challenge)] + + def get_full_proof(self, challenge: bytes, index: int, parallel_read: bool = True) -> bytes: + return bytes(self._disk_prover.get_full_proof(challenge, index, parallel_read)) + + @classmethod + def from_bytes(cls, data: bytes) -> V1Prover: + """Create V1ProverWrapper from serialized bytes""" + from chiapos import DiskProver + + disk_prover = DiskProver.from_bytes(data) + return cls(disk_prover) + + @property + def disk_prover(self) -> DiskProver: + """Access to underlying DiskProver for backwards compatibility""" + return self._disk_prover + + +def get_prover_from_bytes(filename: str, prover_data: bytes) -> ProverProtocol: + """Factory function to create appropriate prover based on plot version""" + if filename.endswith(".plot2"): + return V2Prover.from_bytes(prover_data) + elif filename.endswith(".plot"): + return V1Prover(DiskProver.from_bytes(prover_data)) + else: + raise ValueError(f"Unsupported plot file: {filename}") + + +def get_prover_from_file(filename: str) -> ProverProtocol: + """Factory function to create appropriate prover based on plot version""" + if filename.endswith(".plot2"): + return V2Prover(filename) + elif filename.endswith(".plot"): + return V1Prover(DiskProver(filename)) + else: + raise ValueError(f"Unsupported plot file: {filename}") diff --git a/chia/plotting/util.py b/chia/plotting/util.py index c2ad2a136e05..a7aa314dc44f 100644 --- a/chia/plotting/util.py +++ b/chia/plotting/util.py @@ -4,12 +4,14 @@ from dataclasses import dataclass, field from enum import Enum, IntEnum from pathlib import Path -from typing import Any, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union + +if TYPE_CHECKING: + from chia.plotting.prover import ProverProtocol from chia_rs import G1Element, PrivateKey from chia_rs.sized_bytes import bytes32 from chia_rs.sized_ints import uint32 -from chiapos import DiskProver from typing_extensions import final from chia.util.config import load_config, lock_and_load_config, save_config @@ -39,7 +41,7 @@ class PlotsRefreshParameter(Streamable): @dataclass class PlotInfo: - prover: DiskProver + prover: ProverProtocol pool_public_key: Optional[G1Element] pool_contract_puzzle_hash: Optional[bytes32] plot_public_key: G1Element @@ -233,16 +235,22 @@ def get_filenames(directory: Path, recursive: bool, follow_links: bool) -> list[ if follow_links and recursive: import glob - files = glob.glob(str(directory / "**" / "*.plot"), recursive=True) - for file in files: + v1_file_strs = glob.glob(str(directory / "**" / "*.plot"), recursive=True) + v2_file_strs = glob.glob(str(directory / "**" / "*.plot2"), recursive=True) + + for file in v1_file_strs + v2_file_strs: filepath = Path(file).resolve() if filepath.is_file() and not filepath.name.startswith("._"): all_files.append(filepath) else: glob_function = directory.rglob if recursive else directory.glob - all_files = [ + v1_files: list[Path] = [ child for child in glob_function("*.plot") if child.is_file() and not child.name.startswith("._") ] + v2_files: list[Path] = [ + child for child in glob_function("*.plot2") if child.is_file() and not child.name.startswith("._") + ] + all_files = v1_files + v2_files log.debug(f"get_filenames: {len(all_files)} files found in {directory}, recursive: {recursive}") except Exception as e: log.warning(f"Error reading directory {directory} {e}") From 8c4f57843f7d1def844f0778db39603e48bebf02 Mon Sep 17 00:00:00 2001 From: almogdepaz Date: Wed, 16 Jul 2025 15:57:50 +0300 Subject: [PATCH 02/11] format name --- chia/plotting/prover.py | 4 ++-- chia/plotting/util.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/chia/plotting/prover.py b/chia/plotting/prover.py index df70248b8dd8..f9de5d9b440d 100644 --- a/chia/plotting/prover.py +++ b/chia/plotting/prover.py @@ -166,7 +166,7 @@ def disk_prover(self) -> DiskProver: def get_prover_from_bytes(filename: str, prover_data: bytes) -> ProverProtocol: """Factory function to create appropriate prover based on plot version""" - if filename.endswith(".plot2"): + if filename.endswith(".plot_v2"): return V2Prover.from_bytes(prover_data) elif filename.endswith(".plot"): return V1Prover(DiskProver.from_bytes(prover_data)) @@ -176,7 +176,7 @@ def get_prover_from_bytes(filename: str, prover_data: bytes) -> ProverProtocol: def get_prover_from_file(filename: str) -> ProverProtocol: """Factory function to create appropriate prover based on plot version""" - if filename.endswith(".plot2"): + if filename.endswith(".plot_v2"): return V2Prover(filename) elif filename.endswith(".plot"): return V1Prover(DiskProver(filename)) diff --git a/chia/plotting/util.py b/chia/plotting/util.py index a7aa314dc44f..f195c255f4ae 100644 --- a/chia/plotting/util.py +++ b/chia/plotting/util.py @@ -236,7 +236,7 @@ def get_filenames(directory: Path, recursive: bool, follow_links: bool) -> list[ import glob v1_file_strs = glob.glob(str(directory / "**" / "*.plot"), recursive=True) - v2_file_strs = glob.glob(str(directory / "**" / "*.plot2"), recursive=True) + v2_file_strs = glob.glob(str(directory / "**" / "*.plot_v2"), recursive=True) for file in v1_file_strs + v2_file_strs: filepath = Path(file).resolve() @@ -248,7 +248,7 @@ def get_filenames(directory: Path, recursive: bool, follow_links: bool) -> list[ child for child in glob_function("*.plot") if child.is_file() and not child.name.startswith("._") ] v2_files: list[Path] = [ - child for child in glob_function("*.plot2") if child.is_file() and not child.name.startswith("._") + child for child in glob_function("*.plot_v2") if child.is_file() and not child.name.startswith("._") ] all_files = v1_files + v2_files log.debug(f"get_filenames: {len(all_files)} files found in {directory}, recursive: {recursive}") From 6aef294b6102db614f63e20bb0268712f6dd2548 Mon Sep 17 00:00:00 2001 From: almogdepaz Date: Wed, 16 Jul 2025 17:47:49 +0300 Subject: [PATCH 03/11] format --- chia/plotting/cache.py | 2 +- chia/plotting/prover.py | 36 ++++++++++++++---------------------- 2 files changed, 15 insertions(+), 23 deletions(-) diff --git a/chia/plotting/cache.py b/chia/plotting/cache.py index 4d5a5aef8905..d442ac5b07de 100644 --- a/chia/plotting/cache.py +++ b/chia/plotting/cache.py @@ -55,7 +55,7 @@ class CacheEntry: @classmethod def from_prover(cls, prover: ProverProtocol) -> CacheEntry: - """Create CacheEntry from any prover implementation""" + """Create CacheEntry from prover""" ( pool_public_key_or_puzzle_hash, farmer_public_key, diff --git a/chia/plotting/prover.py b/chia/plotting/prover.py index f9de5d9b440d..0c0bc671343a 100644 --- a/chia/plotting/prover.py +++ b/chia/plotting/prover.py @@ -13,15 +13,13 @@ class ProverProtocol(ABC): - """Abstract protocol for all prover implementations (V1 and V2)""" - @abstractmethod def get_filename(self) -> Path: - """Returns the filename of the plot""" + """Returns the filename for the plot""" @abstractmethod def get_filename_str(self) -> str: - """Returns the filename of the plot""" + """Returns the filename string for the plot""" @abstractmethod def get_size(self) -> uint8: @@ -29,19 +27,19 @@ def get_size(self) -> uint8: @abstractmethod def get_memo(self) -> bytes: - """Returns the memo containing keys and other metadata""" + """Returns the memo""" @abstractmethod def get_compression_level(self) -> uint8: - """Returns the compression level (0 for uncompressed)""" + """Returns the compression level""" @abstractmethod def get_version(self) -> int: - """Returns the plot version (1 for V1, 2 for V2)""" + """Returns the plot version""" @abstractmethod def __bytes__(self) -> bytes: - """Returns the prover serialized as bytes for caching""" + """Returns the prover bytes""" @abstractmethod def get_id(self) -> bytes32: @@ -62,11 +60,11 @@ def from_bytes(cls, data: bytes) -> ProverProtocol: class V2Prover(ProverProtocol): - """V2 Plot Prover implementation - currently stubbed""" + """V2 Plot Prover stubb""" def __init__(self, filename: str): self._filename = filename - # TODO: Implement V2 plot file parsing and validation + # TODO: todo_v2_plots Implement plot file parsing and validation def get_filename(self) -> Path: return Path(self._filename) @@ -75,11 +73,11 @@ def get_filename_str(self) -> str: return str(self._filename) def get_size(self) -> uint8: - # TODO: Extract k size from V2 plot file + # TODO: todo_v2_plots get k size from plot return uint8(32) # Stub value def get_memo(self) -> bytes: - # TODO: Extract memo from V2 plot file + # TODO: todo_v2_plots return b"" # Stub value def get_compression_level(self) -> uint8: @@ -90,7 +88,7 @@ def get_version(self) -> int: return 2 def __bytes__(self) -> bytes: - # TODO: Implement proper V2 prover serialization for caching + # TODO: todo_v2_plots Implement prover serialization for caching # For now, just serialize the filename as a placeholder return self._filename.encode("utf-8") @@ -99,17 +97,15 @@ def get_id(self) -> bytes32: return bytes32(b"") # Stub value def get_qualities_for_challenge(self, challenge: bytes) -> list[bytes32]: - # TODO: Implement V2 plot quality lookup + # TODO: todo_v2_plots Implement plot quality lookup return [] # Stub value def get_full_proof(self, challenge: bytes, index: int, parallel_read: bool = True) -> bytes: - # TODO: Implement V2 plot proof generation - return b"" # Stub value + # TODO: todo_v2_plots Implement plot proof generation + return b"" @classmethod def from_bytes(cls, data: bytes) -> V2Prover: - # TODO: Implement proper V2 prover deserialization from cache - # For now, just deserialize the filename filename = data.decode("utf-8") return cls(filename) @@ -152,7 +148,6 @@ def get_full_proof(self, challenge: bytes, index: int, parallel_read: bool = Tru @classmethod def from_bytes(cls, data: bytes) -> V1Prover: - """Create V1ProverWrapper from serialized bytes""" from chiapos import DiskProver disk_prover = DiskProver.from_bytes(data) @@ -160,12 +155,10 @@ def from_bytes(cls, data: bytes) -> V1Prover: @property def disk_prover(self) -> DiskProver: - """Access to underlying DiskProver for backwards compatibility""" return self._disk_prover def get_prover_from_bytes(filename: str, prover_data: bytes) -> ProverProtocol: - """Factory function to create appropriate prover based on plot version""" if filename.endswith(".plot_v2"): return V2Prover.from_bytes(prover_data) elif filename.endswith(".plot"): @@ -175,7 +168,6 @@ def get_prover_from_bytes(filename: str, prover_data: bytes) -> ProverProtocol: def get_prover_from_file(filename: str) -> ProverProtocol: - """Factory function to create appropriate prover based on plot version""" if filename.endswith(".plot_v2"): return V2Prover(filename) elif filename.endswith(".plot"): From 3f3b134acb521324856813afe0fa9e4881dd92d7 Mon Sep 17 00:00:00 2001 From: almogdepaz Date: Wed, 16 Jul 2025 23:06:22 +0300 Subject: [PATCH 04/11] refactor filename --- chia/plotting/prover.py | 4 ++-- chia/plotting/util.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/chia/plotting/prover.py b/chia/plotting/prover.py index 0c0bc671343a..b844fd52204c 100644 --- a/chia/plotting/prover.py +++ b/chia/plotting/prover.py @@ -159,7 +159,7 @@ def disk_prover(self) -> DiskProver: def get_prover_from_bytes(filename: str, prover_data: bytes) -> ProverProtocol: - if filename.endswith(".plot_v2"): + if filename.endswith(".plot2"): return V2Prover.from_bytes(prover_data) elif filename.endswith(".plot"): return V1Prover(DiskProver.from_bytes(prover_data)) @@ -168,7 +168,7 @@ def get_prover_from_bytes(filename: str, prover_data: bytes) -> ProverProtocol: def get_prover_from_file(filename: str) -> ProverProtocol: - if filename.endswith(".plot_v2"): + if filename.endswith(".plot2"): return V2Prover(filename) elif filename.endswith(".plot"): return V1Prover(DiskProver(filename)) diff --git a/chia/plotting/util.py b/chia/plotting/util.py index f195c255f4ae..a7aa314dc44f 100644 --- a/chia/plotting/util.py +++ b/chia/plotting/util.py @@ -236,7 +236,7 @@ def get_filenames(directory: Path, recursive: bool, follow_links: bool) -> list[ import glob v1_file_strs = glob.glob(str(directory / "**" / "*.plot"), recursive=True) - v2_file_strs = glob.glob(str(directory / "**" / "*.plot_v2"), recursive=True) + v2_file_strs = glob.glob(str(directory / "**" / "*.plot2"), recursive=True) for file in v1_file_strs + v2_file_strs: filepath = Path(file).resolve() @@ -248,7 +248,7 @@ def get_filenames(directory: Path, recursive: bool, follow_links: bool) -> list[ child for child in glob_function("*.plot") if child.is_file() and not child.name.startswith("._") ] v2_files: list[Path] = [ - child for child in glob_function("*.plot_v2") if child.is_file() and not child.name.startswith("._") + child for child in glob_function("*.plot2") if child.is_file() and not child.name.startswith("._") ] all_files = v1_files + v2_files log.debug(f"get_filenames: {len(all_files)} files found in {directory}, recursive: {recursive}") From 6e28f8c04b112c350e6225eb9b8f3ee464dce68c Mon Sep 17 00:00:00 2001 From: almogdepaz Date: Thu, 17 Jul 2025 13:03:33 +0300 Subject: [PATCH 05/11] tests/raise unimplemented --- chia/_tests/plotting/test_prover.py | 76 +++++++++++++++++++++++++++++ chia/plotting/prover.py | 46 +++++++++-------- 2 files changed, 102 insertions(+), 20 deletions(-) create mode 100644 chia/_tests/plotting/test_prover.py diff --git a/chia/_tests/plotting/test_prover.py b/chia/_tests/plotting/test_prover.py new file mode 100644 index 000000000000..68c34b8f5fe5 --- /dev/null +++ b/chia/_tests/plotting/test_prover.py @@ -0,0 +1,76 @@ +from __future__ import annotations + +import tempfile +from pathlib import Path + +import pytest + +from chia.plotting.prover import PlotVersion, V2Prover, get_prover_from_file + + +class TestProver: + def test_v2_prover_init_with_nonexistent_file(self) -> None: + prover = V2Prover("/nonexistent/path/test.plot2") + assert prover.get_version() == PlotVersion.V2 + assert prover.get_filename() == Path("/nonexistent/path/test.plot2") + assert prover.get_filename_str() == "/nonexistent/path/test.plot2" + + def test_v2_prover_get_size_raises_error(self) -> None: + prover = V2Prover("/nonexistent/path/test.plot2") + with pytest.raises(NotImplementedError, match="V2 plot format is not yet implemented"): + prover.get_size() + + def test_v2_prover_get_memo_raises_error(self) -> None: + prover = V2Prover("/nonexistent/path/test.plot2") + with pytest.raises(NotImplementedError, match="V2 plot format is not yet implemented"): + prover.get_memo() + + def test_v2_prover_get_compression_level_raises_error(self) -> None: + prover = V2Prover("/nonexistent/path/test.plot2") + with pytest.raises(NotImplementedError, match="V2 plot format is not yet implemented"): + prover.get_compression_level() + + def test_v2_prover_get_id_raises_error(self) -> None: + prover = V2Prover("/nonexistent/path/test.plot2") + with pytest.raises(NotImplementedError, match="V2 plot format is not yet implemented"): + prover.get_id() + + def test_v2_prover_get_qualities_for_challenge_raises_error(self) -> None: + prover = V2Prover("/nonexistent/path/test.plot2") + with pytest.raises(NotImplementedError, match="V2 plot format is not yet implemented"): + prover.get_qualities_for_challenge(b"challenge") + + def test_v2_prover_get_full_proof_raises_error(self) -> None: + prover = V2Prover("/nonexistent/path/test.plot2") + with pytest.raises(NotImplementedError, match="V2 plot format is not yet implemented"): + prover.get_full_proof(b"challenge", 0) + + def test_v2_prover_bytes_raises_error(self) -> None: + prover = V2Prover("/nonexistent/path/test.plot2") + with pytest.raises(NotImplementedError, match="V2 plot format is not yet implemented"): + bytes(prover) + + def test_v2_prover_from_bytes_raises_error(self) -> None: + with pytest.raises(NotImplementedError, match="V2 plot format is not yet implemented"): + V2Prover.from_bytes(b"test_data") + + def test_get_prover_from_file(self) -> None: + prover = get_prover_from_file("/nonexistent/path/test.plot2") + assert prover.get_version() == PlotVersion.V2 + with pytest.raises(NotImplementedError, match="V2 plot format is not yet implemented"): + prover.get_size() + + def test_get_prover_from_file_with_plot1_still_works(self) -> None: + with tempfile.NamedTemporaryFile(suffix=".plot", delete=False) as f: + temp_path = f.name + try: + with pytest.raises(Exception) as exc_info: + get_prover_from_file(temp_path) + assert not isinstance(exc_info.value, NotImplementedError) + finally: + Path(temp_path).unlink() + + def test_unsupported_file_extension_raises_value_error(self) -> None: + """Test that unsupported file extensions raise ValueError""" + with pytest.raises(ValueError, match="Unsupported plot file"): + get_prover_from_file("/nonexistent/path/test.txt") diff --git a/chia/plotting/prover.py b/chia/plotting/prover.py index b844fd52204c..79fc5a174723 100644 --- a/chia/plotting/prover.py +++ b/chia/plotting/prover.py @@ -1,6 +1,7 @@ from __future__ import annotations from abc import ABC, abstractmethod +from enum import IntEnum from pathlib import Path from typing import TYPE_CHECKING @@ -12,6 +13,13 @@ from chiapos import DiskProver +class PlotVersion(IntEnum): + """Enum for plot format versions""" + + V1 = 1 + V2 = 2 + + class ProverProtocol(ABC): @abstractmethod def get_filename(self) -> Path: @@ -34,7 +42,7 @@ def get_compression_level(self) -> uint8: """Returns the compression level""" @abstractmethod - def get_version(self) -> int: + def get_version(self) -> PlotVersion: """Returns the plot version""" @abstractmethod @@ -64,7 +72,6 @@ class V2Prover(ProverProtocol): def __init__(self, filename: str): self._filename = filename - # TODO: todo_v2_plots Implement plot file parsing and validation def get_filename(self) -> Path: return Path(self._filename) @@ -74,40 +81,39 @@ def get_filename_str(self) -> str: def get_size(self) -> uint8: # TODO: todo_v2_plots get k size from plot - return uint8(32) # Stub value + raise NotImplementedError("V2 plot format is not yet implemented") def get_memo(self) -> bytes: # TODO: todo_v2_plots - return b"" # Stub value + raise NotImplementedError("V2 plot format is not yet implemented") def get_compression_level(self) -> uint8: - # TODO: Extract compression level from V2 plot file - return uint8(0) # Stub value + # TODO: todo_v2_plots implement compression level retrieval + raise NotImplementedError("V2 plot format is not yet implemented") - def get_version(self) -> int: - return 2 + def get_version(self) -> PlotVersion: + return PlotVersion.V2 def __bytes__(self) -> bytes: # TODO: todo_v2_plots Implement prover serialization for caching - # For now, just serialize the filename as a placeholder - return self._filename.encode("utf-8") + raise NotImplementedError("V2 plot format is not yet implemented") def get_id(self) -> bytes32: # TODO: Extract plot ID from V2 plot file - return bytes32(b"") # Stub value + raise NotImplementedError("V2 plot format is not yet implemented") - def get_qualities_for_challenge(self, challenge: bytes) -> list[bytes32]: + def get_qualities_for_challenge(self, _challenge: bytes) -> list[bytes32]: # TODO: todo_v2_plots Implement plot quality lookup - return [] # Stub value + raise NotImplementedError("V2 plot format is not yet implemented") - def get_full_proof(self, challenge: bytes, index: int, parallel_read: bool = True) -> bytes: + def get_full_proof(self, _challenge: bytes, _index: int, _parallel_read: bool = True) -> bytes: # TODO: todo_v2_plots Implement plot proof generation - return b"" + raise NotImplementedError("V2 plot format is not yet implemented") @classmethod - def from_bytes(cls, data: bytes) -> V2Prover: - filename = data.decode("utf-8") - return cls(filename) + def from_bytes(cls, _data: bytes) -> V2Prover: + # TODO: todo_v2_plots Implement prover deserialization from cache + raise NotImplementedError("V2 plot format is not yet implemented") class V1Prover(ProverProtocol): @@ -131,8 +137,8 @@ def get_memo(self) -> bytes: def get_compression_level(self) -> uint8: return uint8(self._disk_prover.get_compression_level()) - def get_version(self) -> int: - return 1 + def get_version(self) -> PlotVersion: + return PlotVersion.V1 def __bytes__(self) -> bytes: return bytes(self._disk_prover) From 7811dbb843f993ac27159b4e95f87e0be82df8ea Mon Sep 17 00:00:00 2001 From: almogdepaz Date: Thu, 17 Jul 2025 13:51:30 +0300 Subject: [PATCH 06/11] add get_filename_str to mock --- chia/_tests/plot_sync/test_plot_sync.py | 4 ++-- chia/_tests/plotting/test_plot_manager.py | 3 +++ 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/chia/_tests/plot_sync/test_plot_sync.py b/chia/_tests/plot_sync/test_plot_sync.py index e889286ea300..8d6f42ebc5a6 100644 --- a/chia/_tests/plot_sync/test_plot_sync.py +++ b/chia/_tests/plot_sync/test_plot_sync.py @@ -65,7 +65,7 @@ class ExpectedResult: def add_valid(self, list_plots: list[MockPlotInfo]) -> None: def create_mock_plot(info: MockPlotInfo) -> Plot: return Plot( - str(info.prover.get_filename()), + info.prover.get_filename_str(), uint8(0), bytes32.zeros, None, @@ -77,7 +77,7 @@ def create_mock_plot(info: MockPlotInfo) -> Plot: ) self.valid_count += len(list_plots) - self.valid_delta.additions.update({str(x.prover.get_filename()): create_mock_plot(x) for x in list_plots}) + self.valid_delta.additions.update({x.prover.get_filename_str(): create_mock_plot(x) for x in list_plots}) def remove_valid(self, list_paths: list[Path]) -> None: self.valid_count -= len(list_paths) diff --git a/chia/_tests/plotting/test_plot_manager.py b/chia/_tests/plotting/test_plot_manager.py index beb8a192eda6..164348524677 100644 --- a/chia/_tests/plotting/test_plot_manager.py +++ b/chia/_tests/plotting/test_plot_manager.py @@ -45,6 +45,9 @@ class MockDiskProver: def get_filename(self) -> Path: return Path(self.filename) + def get_filename_str(self) -> str: + return self.filename + @dataclass class MockPlotInfo: From 441aed5ef253e3e5513d9b1c4774e4b26e4fcebc Mon Sep 17 00:00:00 2001 From: almogdepaz Date: Thu, 17 Jul 2025 13:56:00 +0300 Subject: [PATCH 07/11] rename methods --- chia/_tests/plot_sync/test_plot_sync.py | 4 +- chia/_tests/plot_sync/test_sync_simulated.py | 56 ++++++++++---------- chia/_tests/plotting/test_plot_manager.py | 4 +- chia/_tests/plotting/test_prover.py | 4 +- chia/plot_sync/sender.py | 2 +- chia/plotting/manager.py | 6 +-- chia/plotting/prover.py | 12 ++--- 7 files changed, 44 insertions(+), 44 deletions(-) diff --git a/chia/_tests/plot_sync/test_plot_sync.py b/chia/_tests/plot_sync/test_plot_sync.py index 8d6f42ebc5a6..ff63e1e519f7 100644 --- a/chia/_tests/plot_sync/test_plot_sync.py +++ b/chia/_tests/plot_sync/test_plot_sync.py @@ -193,7 +193,7 @@ async def plot_sync_callback(self, peer_id: bytes32, delta: Optional[Delta]) -> assert path in delta.valid.additions plot = harvester.plot_manager.plots.get(Path(path), None) assert plot is not None - assert plot.prover.get_filename_str() == delta.valid.additions[path].filename + assert plot.prover.get_filename() == delta.valid.additions[path].filename assert plot.prover.get_size() == delta.valid.additions[path].size assert plot.prover.get_id() == delta.valid.additions[path].plot_id assert plot.prover.get_compression_level() == delta.valid.additions[path].compression_level @@ -254,7 +254,7 @@ async def run_sync_test(self) -> None: assert expected.duplicates_delta.empty() for path, plot_info in plot_manager.plots.items(): assert str(path) in receiver.plots() - assert plot_info.prover.get_filename_str() == receiver.plots()[str(path)].filename + assert plot_info.prover.get_filename() == receiver.plots()[str(path)].filename assert plot_info.prover.get_size() == receiver.plots()[str(path)].size assert plot_info.prover.get_id() == receiver.plots()[str(path)].plot_id assert plot_info.prover.get_compression_level() == receiver.plots()[str(path)].compression_level diff --git a/chia/_tests/plot_sync/test_sync_simulated.py b/chia/_tests/plot_sync/test_sync_simulated.py index c1228e8a3653..d48ab796ba09 100644 --- a/chia/_tests/plot_sync/test_sync_simulated.py +++ b/chia/_tests/plot_sync/test_sync_simulated.py @@ -69,18 +69,18 @@ async def run( initial: bool, ) -> None: for plot_info in loaded: - assert plot_info.prover.get_filename() not in self.plots + assert plot_info.prover.get_filepath() not in self.plots for plot_info in removed: - assert plot_info.prover.get_filename() in self.plots + assert plot_info.prover.get_filepath() in self.plots self.invalid = invalid self.keys_missing = keys_missing self.duplicates = duplicates - removed_paths: list[Path] = [p.prover.get_filename() for p in removed] if removed is not None else [] - invalid_dict: dict[Path, int] = {p.prover.get_filename(): 0 for p in self.invalid} - keys_missing_set: set[Path] = {p.prover.get_filename() for p in self.keys_missing} - duplicates_set: set[str] = {p.prover.get_filename_str() for p in self.duplicates} + removed_paths: list[Path] = [p.prover.get_filepath() for p in removed] if removed is not None else [] + invalid_dict: dict[Path, int] = {p.prover.get_filepath(): 0 for p in self.invalid} + keys_missing_set: set[Path] = {p.prover.get_filepath() for p in self.keys_missing} + duplicates_set: set[str] = {p.prover.get_filename() for p in self.duplicates} # Inject invalid plots into `PlotManager` of the harvester so that the callback calls below can use them # to sync them to the farmer. @@ -91,7 +91,7 @@ async def run( # Inject duplicated plots into `PlotManager` of the harvester so that the callback calls below can use them # to sync them to the farmer. for plot_info in loaded: - plot_path = Path(plot_info.prover.get_filename()) + plot_path = Path(plot_info.prover.get_filepath()) self.harvester.plot_manager.plot_filename_paths[plot_path.name] = (str(plot_path.parent), set()) for duplicate in duplicates_set: plot_path = Path(duplicate) @@ -123,39 +123,39 @@ async def sync_done() -> bool: await time_out_assert(60, sync_done) for plot_info in loaded: - self.plots[plot_info.prover.get_filename()] = plot_info + self.plots[plot_info.prover.get_filepath()] = plot_info for plot_info in removed: - del self.plots[plot_info.prover.get_filename()] + del self.plots[plot_info.prover.get_filepath()] def validate_plot_sync(self) -> None: assert len(self.plots) == len(self.plot_sync_receiver.plots()) assert len(self.invalid) == len(self.plot_sync_receiver.invalid()) assert len(self.keys_missing) == len(self.plot_sync_receiver.keys_missing()) for _, plot_info in self.plots.items(): - assert plot_info.prover.get_filename_str() not in self.plot_sync_receiver.invalid() - assert plot_info.prover.get_filename_str() not in self.plot_sync_receiver.keys_missing() - assert plot_info.prover.get_filename_str() in self.plot_sync_receiver.plots() - synced_plot = self.plot_sync_receiver.plots()[plot_info.prover.get_filename_str()] - assert plot_info.prover.get_filename_str() == synced_plot.filename + assert plot_info.prover.get_filename() not in self.plot_sync_receiver.invalid() + assert plot_info.prover.get_filename() not in self.plot_sync_receiver.keys_missing() + assert plot_info.prover.get_filename() in self.plot_sync_receiver.plots() + synced_plot = self.plot_sync_receiver.plots()[plot_info.prover.get_filename()] + assert plot_info.prover.get_filename() == synced_plot.filename assert plot_info.pool_public_key == synced_plot.pool_public_key assert plot_info.pool_contract_puzzle_hash == synced_plot.pool_contract_puzzle_hash assert plot_info.plot_public_key == synced_plot.plot_public_key assert plot_info.file_size == synced_plot.file_size assert uint64(int(plot_info.time_modified)) == synced_plot.time_modified for plot_info in self.invalid: - assert plot_info.prover.get_filename_str() not in self.plot_sync_receiver.plots() - assert plot_info.prover.get_filename_str() in self.plot_sync_receiver.invalid() - assert plot_info.prover.get_filename_str() not in self.plot_sync_receiver.keys_missing() - assert plot_info.prover.get_filename_str() not in self.plot_sync_receiver.duplicates() + assert plot_info.prover.get_filename() not in self.plot_sync_receiver.plots() + assert plot_info.prover.get_filename() in self.plot_sync_receiver.invalid() + assert plot_info.prover.get_filename() not in self.plot_sync_receiver.keys_missing() + assert plot_info.prover.get_filename() not in self.plot_sync_receiver.duplicates() for plot_info in self.keys_missing: - assert plot_info.prover.get_filename_str() not in self.plot_sync_receiver.plots() - assert plot_info.prover.get_filename_str() not in self.plot_sync_receiver.invalid() - assert plot_info.prover.get_filename_str() in self.plot_sync_receiver.keys_missing() - assert plot_info.prover.get_filename_str() not in self.plot_sync_receiver.duplicates() + assert plot_info.prover.get_filename() not in self.plot_sync_receiver.plots() + assert plot_info.prover.get_filename() not in self.plot_sync_receiver.invalid() + assert plot_info.prover.get_filename() in self.plot_sync_receiver.keys_missing() + assert plot_info.prover.get_filename() not in self.plot_sync_receiver.duplicates() for plot_info in self.duplicates: - assert plot_info.prover.get_filename_str() not in self.plot_sync_receiver.invalid() - assert plot_info.prover.get_filename_str() not in self.plot_sync_receiver.keys_missing() - assert plot_info.prover.get_filename_str() in self.plot_sync_receiver.duplicates() + assert plot_info.prover.get_filename() not in self.plot_sync_receiver.invalid() + assert plot_info.prover.get_filename() not in self.plot_sync_receiver.keys_missing() + assert plot_info.prover.get_filename() in self.plot_sync_receiver.duplicates() @dataclass @@ -417,7 +417,7 @@ async def test_sync_reset_cases( # Inject some data into `PlotManager` of the harvester so that we can validate the reset worked and triggered a # fresh sync of all available data of the plot manager for plot_info in plots[0:10]: - test_data.plots[plot_info.prover.get_filename()] = plot_info + test_data.plots[plot_info.prover.get_filepath()] = plot_info plot_manager.plots = test_data.plots test_data.invalid = plots[10:20] test_data.keys_missing = plots[20:30] @@ -425,8 +425,8 @@ async def test_sync_reset_cases( sender: Sender = test_runner.test_data[0].plot_sync_sender started_sync_id: uint64 = uint64(0) - plot_manager.failed_to_open_filenames = {p.prover.get_filename(): 0 for p in test_data.invalid} - plot_manager.no_key_filenames = {p.prover.get_filename() for p in test_data.keys_missing} + plot_manager.failed_to_open_filenames = {p.prover.get_filepath(): 0 for p in test_data.invalid} + plot_manager.no_key_filenames = {p.prover.get_filepath() for p in test_data.keys_missing} async def wait_for_reset() -> bool: assert started_sync_id != 0 diff --git a/chia/_tests/plotting/test_plot_manager.py b/chia/_tests/plotting/test_plot_manager.py index 164348524677..dbae9a0a879a 100644 --- a/chia/_tests/plotting/test_plot_manager.py +++ b/chia/_tests/plotting/test_plot_manager.py @@ -115,7 +115,7 @@ def refresh_callback(self, event: PlotRefreshEvents, refresh_result: PlotRefresh for value in actual_value: if type(value) is PlotInfo: for plot_info in expected_list: - if plot_info.prover.get_filename() == value.prover.get_filename(): + if plot_info.prover.get_filename() == value.prover.get_filepath(): values_found += 1 continue else: @@ -507,7 +507,7 @@ async def test_plot_info_caching(environment, bt): await refresh_tester.run(expected_result) for path, plot_info in env.refresh_tester.plot_manager.plots.items(): assert path in plot_manager.plots - assert plot_manager.plots[path].prover.get_filename() == plot_info.prover.get_filename() + assert plot_manager.plots[path].prover.get_filepath() == plot_info.prover.get_filepath() assert plot_manager.plots[path].prover.get_id() == plot_info.prover.get_id() assert plot_manager.plots[path].prover.get_memo() == plot_info.prover.get_memo() assert plot_manager.plots[path].prover.get_size() == plot_info.prover.get_size() diff --git a/chia/_tests/plotting/test_prover.py b/chia/_tests/plotting/test_prover.py index 68c34b8f5fe5..68053430d906 100644 --- a/chia/_tests/plotting/test_prover.py +++ b/chia/_tests/plotting/test_prover.py @@ -12,8 +12,8 @@ class TestProver: def test_v2_prover_init_with_nonexistent_file(self) -> None: prover = V2Prover("/nonexistent/path/test.plot2") assert prover.get_version() == PlotVersion.V2 - assert prover.get_filename() == Path("/nonexistent/path/test.plot2") - assert prover.get_filename_str() == "/nonexistent/path/test.plot2" + assert prover.get_filepath() == Path("/nonexistent/path/test.plot2") + assert prover.get_filename() == "/nonexistent/path/test.plot2" def test_v2_prover_get_size_raises_error(self) -> None: prover = V2Prover("/nonexistent/path/test.plot2") diff --git a/chia/plot_sync/sender.py b/chia/plot_sync/sender.py index ec1c51c3937b..74be15cae2cf 100644 --- a/chia/plot_sync/sender.py +++ b/chia/plot_sync/sender.py @@ -39,7 +39,7 @@ def _convert_plot_info_list(plot_infos: list[PlotInfo]) -> list[Plot]: for plot_info in plot_infos: converted.append( Plot( - filename=plot_info.prover.get_filename_str(), + filename=plot_info.prover.get_filename(), size=plot_info.prover.get_size(), plot_id=plot_info.prover.get_id(), pool_public_key=plot_info.pool_public_key, diff --git a/chia/plotting/manager.py b/chia/plotting/manager.py index 5cd9acb6eb08..52fac2d8c25c 100644 --- a/chia/plotting/manager.py +++ b/chia/plotting/manager.py @@ -386,10 +386,10 @@ def process_file(file_path: Path) -> Optional[PlotInfo]: with self.plot_filename_paths_lock: paths: Optional[tuple[str, set[str]]] = self.plot_filename_paths.get(file_path.name) if paths is None: - paths = (str(Path(cache_entry.prover.get_filename()).parent), set()) + paths = (str(Path(cache_entry.prover.get_filepath()).parent), set()) self.plot_filename_paths[file_path.name] = paths else: - paths[1].add(str(Path(cache_entry.prover.get_filename()).parent)) + paths[1].add(str(Path(cache_entry.prover.get_filepath()).parent)) log.warning(f"Have multiple copies of the plot {file_path.name} in {[paths[0], *paths[1]]}.") return None @@ -423,7 +423,7 @@ def process_file(file_path: Path) -> Optional[PlotInfo]: plots_refreshed: dict[Path, PlotInfo] = {} for new_plot in executor.map(process_file, plot_paths): if new_plot is not None: - plots_refreshed[Path(new_plot.prover.get_filename())] = new_plot + plots_refreshed[Path(new_plot.prover.get_filepath())] = new_plot self.plots.update(plots_refreshed) result.duration = time.time() - start_time diff --git a/chia/plotting/prover.py b/chia/plotting/prover.py index 79fc5a174723..a3deeaffd338 100644 --- a/chia/plotting/prover.py +++ b/chia/plotting/prover.py @@ -22,11 +22,11 @@ class PlotVersion(IntEnum): class ProverProtocol(ABC): @abstractmethod - def get_filename(self) -> Path: + def get_filepath(self) -> Path: """Returns the filename for the plot""" @abstractmethod - def get_filename_str(self) -> str: + def get_filename(self) -> str: """Returns the filename string for the plot""" @abstractmethod @@ -73,10 +73,10 @@ class V2Prover(ProverProtocol): def __init__(self, filename: str): self._filename = filename - def get_filename(self) -> Path: + def get_filepath(self) -> Path: return Path(self._filename) - def get_filename_str(self) -> str: + def get_filename(self) -> str: return str(self._filename) def get_size(self) -> uint8: @@ -122,10 +122,10 @@ class V1Prover(ProverProtocol): def __init__(self, disk_prover: DiskProver) -> None: self._disk_prover = disk_prover - def get_filename(self) -> Path: + def get_filepath(self) -> Path: return Path(self._disk_prover.get_filename()) - def get_filename_str(self) -> str: + def get_filename(self) -> str: return str(self._disk_prover.get_filename()) def get_size(self) -> uint8: From bffc7326e82416764dc14a41d3104f4f2218df51 Mon Sep 17 00:00:00 2001 From: almogdepaz Date: Thu, 17 Jul 2025 13:57:22 +0300 Subject: [PATCH 08/11] rename --- chia/_tests/plot_sync/test_plot_sync.py | 4 ++-- chia/_tests/plotting/test_plot_manager.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/chia/_tests/plot_sync/test_plot_sync.py b/chia/_tests/plot_sync/test_plot_sync.py index ff63e1e519f7..9e76d26b7a5a 100644 --- a/chia/_tests/plot_sync/test_plot_sync.py +++ b/chia/_tests/plot_sync/test_plot_sync.py @@ -65,7 +65,7 @@ class ExpectedResult: def add_valid(self, list_plots: list[MockPlotInfo]) -> None: def create_mock_plot(info: MockPlotInfo) -> Plot: return Plot( - info.prover.get_filename_str(), + info.prover.get_filename(), uint8(0), bytes32.zeros, None, @@ -77,7 +77,7 @@ def create_mock_plot(info: MockPlotInfo) -> Plot: ) self.valid_count += len(list_plots) - self.valid_delta.additions.update({x.prover.get_filename_str(): create_mock_plot(x) for x in list_plots}) + self.valid_delta.additions.update({x.prover.get_filename(): create_mock_plot(x) for x in list_plots}) def remove_valid(self, list_paths: list[Path]) -> None: self.valid_count -= len(list_paths) diff --git a/chia/_tests/plotting/test_plot_manager.py b/chia/_tests/plotting/test_plot_manager.py index dbae9a0a879a..e19a18ab6e01 100644 --- a/chia/_tests/plotting/test_plot_manager.py +++ b/chia/_tests/plotting/test_plot_manager.py @@ -42,10 +42,10 @@ class MockDiskProver: filename: str - def get_filename(self) -> Path: + def get_filepath(self) -> Path: return Path(self.filename) - def get_filename_str(self) -> str: + def get_filename(self) -> str: return self.filename @@ -598,7 +598,7 @@ def assert_cache(expected: list[MockPlotInfo]) -> None: test_cache.load() assert len(test_cache) == len(expected) for plot_info in expected: - assert test_cache.get(Path(plot_info.prover.get_filename())) is not None + assert test_cache.get(Path(plot_info.prover.get_filepath())) is not None # Modify two entries, with and without memo modification, they both should remain in the cache after load modify_cache_entry(0, 1500, modify_memo=False) @@ -617,7 +617,7 @@ def assert_cache(expected: list[MockPlotInfo]) -> None: # Write the modified cache entries to the file cache_path.write_bytes(bytes(VersionedBlob(uint16(CURRENT_VERSION), bytes(cache_data)))) # And now test that plots in invalid_entries are not longer loaded - assert_cache([plot_info for plot_info in plot_infos if str(plot_info.prover.get_filename()) not in invalid_entries]) + assert_cache([plot_info for plot_info in plot_infos if str(plot_info.prover.get_filepath()) not in invalid_entries]) @pytest.mark.anyio From 17b0b3435e5c347ea1f1ad7d9e9d41cb9ec59d6c Mon Sep 17 00:00:00 2001 From: almogdepaz Date: Sun, 20 Jul 2025 13:13:47 +0300 Subject: [PATCH 09/11] refactor --- chia/_tests/plot_sync/test_sync_simulated.py | 22 +++--- chia/_tests/plotting/test_plot_manager.py | 11 +-- chia/_tests/plotting/test_prover.py | 25 +++++- chia/harvester/harvester_api.py | 4 +- chia/plotting/cache.py | 1 - chia/plotting/manager.py | 6 +- chia/plotting/prover.py | 82 ++++++-------------- 7 files changed, 66 insertions(+), 85 deletions(-) diff --git a/chia/_tests/plot_sync/test_sync_simulated.py b/chia/_tests/plot_sync/test_sync_simulated.py index d48ab796ba09..7063bf3b74ff 100644 --- a/chia/_tests/plot_sync/test_sync_simulated.py +++ b/chia/_tests/plot_sync/test_sync_simulated.py @@ -69,17 +69,17 @@ async def run( initial: bool, ) -> None: for plot_info in loaded: - assert plot_info.prover.get_filepath() not in self.plots + assert Path(plot_info.prover.get_filename()) not in self.plots for plot_info in removed: - assert plot_info.prover.get_filepath() in self.plots + assert Path(plot_info.prover.get_filename()) in self.plots self.invalid = invalid self.keys_missing = keys_missing self.duplicates = duplicates - removed_paths: list[Path] = [p.prover.get_filepath() for p in removed] if removed is not None else [] - invalid_dict: dict[Path, int] = {p.prover.get_filepath(): 0 for p in self.invalid} - keys_missing_set: set[Path] = {p.prover.get_filepath() for p in self.keys_missing} + removed_paths: list[Path] = [Path(p.prover.get_filename()) for p in removed] if removed is not None else [] + invalid_dict: dict[Path, int] = {Path(p.prover.get_filename()): 0 for p in self.invalid} + keys_missing_set: set[Path] = {Path(p.prover.get_filename()) for p in self.keys_missing} duplicates_set: set[str] = {p.prover.get_filename() for p in self.duplicates} # Inject invalid plots into `PlotManager` of the harvester so that the callback calls below can use them @@ -91,7 +91,7 @@ async def run( # Inject duplicated plots into `PlotManager` of the harvester so that the callback calls below can use them # to sync them to the farmer. for plot_info in loaded: - plot_path = Path(plot_info.prover.get_filepath()) + plot_path = Path(plot_info.prover.get_filename()) self.harvester.plot_manager.plot_filename_paths[plot_path.name] = (str(plot_path.parent), set()) for duplicate in duplicates_set: plot_path = Path(duplicate) @@ -123,9 +123,9 @@ async def sync_done() -> bool: await time_out_assert(60, sync_done) for plot_info in loaded: - self.plots[plot_info.prover.get_filepath()] = plot_info + self.plots[Path(plot_info.prover.get_filename())] = plot_info for plot_info in removed: - del self.plots[plot_info.prover.get_filepath()] + del self.plots[Path(plot_info.prover.get_filename())] def validate_plot_sync(self) -> None: assert len(self.plots) == len(self.plot_sync_receiver.plots()) @@ -417,7 +417,7 @@ async def test_sync_reset_cases( # Inject some data into `PlotManager` of the harvester so that we can validate the reset worked and triggered a # fresh sync of all available data of the plot manager for plot_info in plots[0:10]: - test_data.plots[plot_info.prover.get_filepath()] = plot_info + test_data.plots[Path(plot_info.prover.get_filename())] = plot_info plot_manager.plots = test_data.plots test_data.invalid = plots[10:20] test_data.keys_missing = plots[20:30] @@ -425,8 +425,8 @@ async def test_sync_reset_cases( sender: Sender = test_runner.test_data[0].plot_sync_sender started_sync_id: uint64 = uint64(0) - plot_manager.failed_to_open_filenames = {p.prover.get_filepath(): 0 for p in test_data.invalid} - plot_manager.no_key_filenames = {p.prover.get_filepath() for p in test_data.keys_missing} + plot_manager.failed_to_open_filenames = {Path(p.prover.get_filename()): 0 for p in test_data.invalid} + plot_manager.no_key_filenames = {Path(p.prover.get_filename()) for p in test_data.keys_missing} async def wait_for_reset() -> bool: assert started_sync_id != 0 diff --git a/chia/_tests/plotting/test_plot_manager.py b/chia/_tests/plotting/test_plot_manager.py index e19a18ab6e01..16fcb50bf66e 100644 --- a/chia/_tests/plotting/test_plot_manager.py +++ b/chia/_tests/plotting/test_plot_manager.py @@ -42,9 +42,6 @@ class MockDiskProver: filename: str - def get_filepath(self) -> Path: - return Path(self.filename) - def get_filename(self) -> str: return self.filename @@ -115,7 +112,7 @@ def refresh_callback(self, event: PlotRefreshEvents, refresh_result: PlotRefresh for value in actual_value: if type(value) is PlotInfo: for plot_info in expected_list: - if plot_info.prover.get_filename() == value.prover.get_filepath(): + if plot_info.prover.get_filename() == value.prover.get_filename(): values_found += 1 continue else: @@ -507,7 +504,7 @@ async def test_plot_info_caching(environment, bt): await refresh_tester.run(expected_result) for path, plot_info in env.refresh_tester.plot_manager.plots.items(): assert path in plot_manager.plots - assert plot_manager.plots[path].prover.get_filepath() == plot_info.prover.get_filepath() + assert plot_manager.plots[path].prover.get_filename() == plot_info.prover.get_filename() assert plot_manager.plots[path].prover.get_id() == plot_info.prover.get_id() assert plot_manager.plots[path].prover.get_memo() == plot_info.prover.get_memo() assert plot_manager.plots[path].prover.get_size() == plot_info.prover.get_size() @@ -598,7 +595,7 @@ def assert_cache(expected: list[MockPlotInfo]) -> None: test_cache.load() assert len(test_cache) == len(expected) for plot_info in expected: - assert test_cache.get(Path(plot_info.prover.get_filepath())) is not None + assert test_cache.get(Path(plot_info.prover.get_filename())) is not None # Modify two entries, with and without memo modification, they both should remain in the cache after load modify_cache_entry(0, 1500, modify_memo=False) @@ -617,7 +614,7 @@ def assert_cache(expected: list[MockPlotInfo]) -> None: # Write the modified cache entries to the file cache_path.write_bytes(bytes(VersionedBlob(uint16(CURRENT_VERSION), bytes(cache_data)))) # And now test that plots in invalid_entries are not longer loaded - assert_cache([plot_info for plot_info in plot_infos if str(plot_info.prover.get_filepath()) not in invalid_entries]) + assert_cache([plot_info for plot_info in plot_infos if plot_info.prover.get_filename() not in invalid_entries]) @pytest.mark.anyio diff --git a/chia/_tests/plotting/test_prover.py b/chia/_tests/plotting/test_prover.py index 68053430d906..ce5f30e2cf96 100644 --- a/chia/_tests/plotting/test_prover.py +++ b/chia/_tests/plotting/test_prover.py @@ -2,17 +2,17 @@ import tempfile from pathlib import Path +from unittest.mock import MagicMock, patch import pytest -from chia.plotting.prover import PlotVersion, V2Prover, get_prover_from_file +from chia.plotting.prover import PlotVersion, V1Prover, V2Prover, get_prover_from_bytes, get_prover_from_file class TestProver: def test_v2_prover_init_with_nonexistent_file(self) -> None: prover = V2Prover("/nonexistent/path/test.plot2") assert prover.get_version() == PlotVersion.V2 - assert prover.get_filepath() == Path("/nonexistent/path/test.plot2") assert prover.get_filename() == "/nonexistent/path/test.plot2" def test_v2_prover_get_size_raises_error(self) -> None: @@ -71,6 +71,25 @@ def test_get_prover_from_file_with_plot1_still_works(self) -> None: Path(temp_path).unlink() def test_unsupported_file_extension_raises_value_error(self) -> None: - """Test that unsupported file extensions raise ValueError""" with pytest.raises(ValueError, match="Unsupported plot file"): get_prover_from_file("/nonexistent/path/test.txt") + + +class TestGetProverFromBytes: + def test_get_prover_from_bytes_v2_plot(self) -> None: + with patch("chia.plotting.prover.V2Prover.from_bytes") as mock_v2_from_bytes: + mock_prover = MagicMock() + mock_v2_from_bytes.return_value = mock_prover + result = get_prover_from_bytes("test.plot2", b"test_data") + assert result == mock_prover + + def test_get_prover_from_bytes_v1_plot(self) -> None: + with patch("chia.plotting.prover.DiskProver") as mock_disk_prover_class: + mock_disk_prover = MagicMock() + mock_disk_prover_class.from_bytes.return_value = mock_disk_prover + result = get_prover_from_bytes("test.plot", b"test_data") + assert isinstance(result, V1Prover) + + def test_get_prover_from_bytes_unsupported_extension(self) -> None: + with pytest.raises(ValueError, match="Unsupported plot file"): + get_prover_from_bytes("test.txt", b"test_data") diff --git a/chia/harvester/harvester_api.py b/chia/harvester/harvester_api.py index b547102b4354..d238a98634eb 100644 --- a/chia/harvester/harvester_api.py +++ b/chia/harvester/harvester_api.py @@ -97,7 +97,7 @@ async def new_signage_point_harvester( loop = asyncio.get_running_loop() def blocking_lookup(filename: Path, plot_info: PlotInfo) -> list[tuple[bytes32, ProofOfSpace]]: - # Uses the DiskProver object to lookup qualities. This is a blocking call, + # Uses the Prover object to lookup qualities. This is a blocking call, # so it should be run in a thread pool. try: plot_id = plot_info.prover.get_id() @@ -218,7 +218,7 @@ def blocking_lookup(filename: Path, plot_info: PlotInfo) -> list[tuple[bytes32, async def lookup_challenge( filename: Path, plot_info: PlotInfo ) -> tuple[Path, list[harvester_protocol.NewProofOfSpace]]: - # Executes a DiskProverLookup in a thread pool, and returns responses + # Executes a ProverLookup in a thread pool, and returns responses all_responses: list[harvester_protocol.NewProofOfSpace] = [] if self.harvester._shut_down: return filename, [] diff --git a/chia/plotting/cache.py b/chia/plotting/cache.py index d442ac5b07de..2c5dfbdd6a72 100644 --- a/chia/plotting/cache.py +++ b/chia/plotting/cache.py @@ -55,7 +55,6 @@ class CacheEntry: @classmethod def from_prover(cls, prover: ProverProtocol) -> CacheEntry: - """Create CacheEntry from prover""" ( pool_public_key_or_puzzle_hash, farmer_public_key, diff --git a/chia/plotting/manager.py b/chia/plotting/manager.py index 52fac2d8c25c..5cd9acb6eb08 100644 --- a/chia/plotting/manager.py +++ b/chia/plotting/manager.py @@ -386,10 +386,10 @@ def process_file(file_path: Path) -> Optional[PlotInfo]: with self.plot_filename_paths_lock: paths: Optional[tuple[str, set[str]]] = self.plot_filename_paths.get(file_path.name) if paths is None: - paths = (str(Path(cache_entry.prover.get_filepath()).parent), set()) + paths = (str(Path(cache_entry.prover.get_filename()).parent), set()) self.plot_filename_paths[file_path.name] = paths else: - paths[1].add(str(Path(cache_entry.prover.get_filepath()).parent)) + paths[1].add(str(Path(cache_entry.prover.get_filename()).parent)) log.warning(f"Have multiple copies of the plot {file_path.name} in {[paths[0], *paths[1]]}.") return None @@ -423,7 +423,7 @@ def process_file(file_path: Path) -> Optional[PlotInfo]: plots_refreshed: dict[Path, PlotInfo] = {} for new_plot in executor.map(process_file, plot_paths): if new_plot is not None: - plots_refreshed[Path(new_plot.prover.get_filepath())] = new_plot + plots_refreshed[Path(new_plot.prover.get_filename())] = new_plot self.plots.update(plots_refreshed) result.duration = time.time() - start_time diff --git a/chia/plotting/prover.py b/chia/plotting/prover.py index a3deeaffd338..5d095af07843 100644 --- a/chia/plotting/prover.py +++ b/chia/plotting/prover.py @@ -1,9 +1,7 @@ from __future__ import annotations -from abc import ABC, abstractmethod from enum import IntEnum -from pathlib import Path -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, ClassVar, Protocol, cast from chia_rs.sized_bytes import bytes32 from chia_rs.sized_ints import uint8 @@ -20,62 +18,30 @@ class PlotVersion(IntEnum): V2 = 2 -class ProverProtocol(ABC): - @abstractmethod - def get_filepath(self) -> Path: - """Returns the filename for the plot""" - - @abstractmethod - def get_filename(self) -> str: - """Returns the filename string for the plot""" - - @abstractmethod - def get_size(self) -> uint8: - """Returns the k size of the plot""" - - @abstractmethod - def get_memo(self) -> bytes: - """Returns the memo""" - - @abstractmethod - def get_compression_level(self) -> uint8: - """Returns the compression level""" - - @abstractmethod - def get_version(self) -> PlotVersion: - """Returns the plot version""" - - @abstractmethod - def __bytes__(self) -> bytes: - """Returns the prover bytes""" - - @abstractmethod - def get_id(self) -> bytes32: - """Returns the plot ID""" - - @abstractmethod - def get_qualities_for_challenge(self, challenge: bytes32) -> list[bytes32]: - """Returns the qualities for a given challenge""" - - @abstractmethod - def get_full_proof(self, challenge: bytes, index: int, parallel_read: bool = True) -> bytes: - """Returns the full proof for a given challenge and index""" +class ProverProtocol(Protocol): + def get_filename(self) -> str: ... + def get_size(self) -> uint8: ... + def get_memo(self) -> bytes: ... + def get_compression_level(self) -> uint8: ... + def get_version(self) -> PlotVersion: ... + def __bytes__(self) -> bytes: ... + def get_id(self) -> bytes32: ... + def get_qualities_for_challenge(self, challenge: bytes32) -> list[bytes32]: ... + def get_full_proof(self, challenge: bytes, index: int, parallel_read: bool = True) -> bytes: ... @classmethod - @abstractmethod - def from_bytes(cls, data: bytes) -> ProverProtocol: - """Create a prover from serialized bytes""" + def from_bytes(cls, data: bytes) -> ProverProtocol: ... -class V2Prover(ProverProtocol): - """V2 Plot Prover stubb""" +class V2Prover: + """Placeholder for future V2 plot format support""" + + if TYPE_CHECKING: + _protocol_check: ClassVar[ProverProtocol] = cast("V2Prover", None) def __init__(self, filename: str): self._filename = filename - def get_filepath(self) -> Path: - return Path(self._filename) - def get_filename(self) -> str: return str(self._filename) @@ -102,29 +68,29 @@ def get_id(self) -> bytes32: # TODO: Extract plot ID from V2 plot file raise NotImplementedError("V2 plot format is not yet implemented") - def get_qualities_for_challenge(self, _challenge: bytes) -> list[bytes32]: + def get_qualities_for_challenge(self, challenge: bytes) -> list[bytes32]: # TODO: todo_v2_plots Implement plot quality lookup raise NotImplementedError("V2 plot format is not yet implemented") - def get_full_proof(self, _challenge: bytes, _index: int, _parallel_read: bool = True) -> bytes: + def get_full_proof(self, challenge: bytes, index: int, parallel_read: bool = True) -> bytes: # TODO: todo_v2_plots Implement plot proof generation raise NotImplementedError("V2 plot format is not yet implemented") @classmethod - def from_bytes(cls, _data: bytes) -> V2Prover: + def from_bytes(cls, data: bytes) -> V2Prover: # TODO: todo_v2_plots Implement prover deserialization from cache raise NotImplementedError("V2 plot format is not yet implemented") -class V1Prover(ProverProtocol): +class V1Prover: """Wrapper for existing DiskProver to implement ProverProtocol""" + if TYPE_CHECKING: + _protocol_check: ClassVar[ProverProtocol] = cast("V1Prover", None) + def __init__(self, disk_prover: DiskProver) -> None: self._disk_prover = disk_prover - def get_filepath(self) -> Path: - return Path(self._disk_prover.get_filename()) - def get_filename(self) -> str: return str(self._disk_prover.get_filename()) From 42f33480006cf7890ab824c26059651e326b931f Mon Sep 17 00:00:00 2001 From: almogdepaz Date: Sun, 20 Jul 2025 15:08:59 +0300 Subject: [PATCH 10/11] improve coverage --- chia/_tests/plotting/test_prover.py | 8 ++++++++ chia/plotting/prover.py | 9 +-------- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/chia/_tests/plotting/test_prover.py b/chia/_tests/plotting/test_prover.py index ce5f30e2cf96..592280d2df52 100644 --- a/chia/_tests/plotting/test_prover.py +++ b/chia/_tests/plotting/test_prover.py @@ -75,6 +75,14 @@ def test_unsupported_file_extension_raises_value_error(self) -> None: get_prover_from_file("/nonexistent/path/test.txt") +class TestV1Prover: + def test_v1_prover_get_version(self) -> None: + """Test that V1Prover.get_version() returns PlotVersion.V1""" + mock_disk_prover = MagicMock() + prover = V1Prover(mock_disk_prover) + assert prover.get_version() == PlotVersion.V1 + + class TestGetProverFromBytes: def test_get_prover_from_bytes_v2_plot(self) -> None: with patch("chia.plotting.prover.V2Prover.from_bytes") as mock_v2_from_bytes: diff --git a/chia/plotting/prover.py b/chia/plotting/prover.py index 5d095af07843..cd9474b0d9cd 100644 --- a/chia/plotting/prover.py +++ b/chia/plotting/prover.py @@ -120,14 +120,7 @@ def get_full_proof(self, challenge: bytes, index: int, parallel_read: bool = Tru @classmethod def from_bytes(cls, data: bytes) -> V1Prover: - from chiapos import DiskProver - - disk_prover = DiskProver.from_bytes(data) - return cls(disk_prover) - - @property - def disk_prover(self) -> DiskProver: - return self._disk_prover + return cls(DiskProver.from_bytes(data)) def get_prover_from_bytes(filename: str, prover_data: bytes) -> ProverProtocol: From e3a8cfc875e9ffaa32385e3d17be106406e6fd48 Mon Sep 17 00:00:00 2001 From: almogdepaz Date: Tue, 22 Jul 2025 12:06:31 +0300 Subject: [PATCH 11/11] test from bytes --- chia/_tests/plotting/test_plot_manager.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/chia/_tests/plotting/test_plot_manager.py b/chia/_tests/plotting/test_plot_manager.py index 16fcb50bf66e..3108da756943 100644 --- a/chia/_tests/plotting/test_plot_manager.py +++ b/chia/_tests/plotting/test_plot_manager.py @@ -13,12 +13,14 @@ import pytest from chia_rs import G1Element from chia_rs.sized_ints import uint16, uint32 +from chiapos import DiskProver from chia._tests.plotting.util import get_test_plots from chia._tests.util.misc import boolean_datacases from chia._tests.util.time_out_assert import time_out_assert from chia.plotting.cache import CURRENT_VERSION, CacheDataV1 from chia.plotting.manager import Cache, PlotManager +from chia.plotting.prover import V1Prover from chia.plotting.util import ( PlotInfo, PlotRefreshEvents, @@ -743,6 +745,20 @@ async def test_recursive_plot_scan(environment: Environment) -> None: await env.refresh_tester.run(expected_result) +@pytest.mark.anyio +async def test_disk_prover_from_bytes(environment: Environment): + env: Environment = environment + expected_result = PlotRefreshResult() + expected_result.loaded = env.dir_1.plot_info_list() # type: ignore[assignment] + expected_result.processed = len(env.dir_1) + add_plot_directory(env.root_path, str(env.dir_1.path)) + await env.refresh_tester.run(expected_result) + _, plot_info = next(iter(env.refresh_tester.plot_manager.plots.items())) + recreated_prover = V1Prover(DiskProver.from_bytes(bytes(plot_info.prover))) + assert recreated_prover.get_id() == plot_info.prover.get_id() + assert recreated_prover.get_filename() == plot_info.prover.get_filename() + + @boolean_datacases(name="follow_links", false="no_follow", true="follow") @pytest.mark.anyio async def test_recursive_plot_scan_symlinks(environment: Environment, follow_links: bool) -> None: