diff --git a/chia/_tests/plot_sync/test_sync_simulated.py b/chia/_tests/plot_sync/test_sync_simulated.py index 1be1b8a51591..7063bf3b74ff 100644 --- a/chia/_tests/plot_sync/test_sync_simulated.py +++ b/chia/_tests/plot_sync/test_sync_simulated.py @@ -25,6 +25,7 @@ from chia.plot_sync.sender import Sender from chia.plot_sync.util import Constants from chia.plotting.manager import PlotManager +from chia.plotting.prover import V1Prover from chia.plotting.util import PlotInfo from chia.protocols.harvester_protocol import PlotSyncError, PlotSyncResponse from chia.protocols.outbound_message import make_msg @@ -68,17 +69,17 @@ async def run( initial: bool, ) -> None: for plot_info in loaded: - assert plot_info.prover.get_filename() not in self.plots + assert Path(plot_info.prover.get_filename()) not in self.plots for plot_info in removed: - assert plot_info.prover.get_filename() in self.plots + assert Path(plot_info.prover.get_filename()) in self.plots self.invalid = invalid self.keys_missing = keys_missing self.duplicates = duplicates - removed_paths: list[Path] = [p.prover.get_filename() for p in removed] if removed is not None else [] - invalid_dict: dict[Path, int] = {p.prover.get_filename(): 0 for p in self.invalid} - keys_missing_set: set[Path] = {p.prover.get_filename() for p in self.keys_missing} + removed_paths: list[Path] = [Path(p.prover.get_filename()) for p in removed] if removed is not None else [] + invalid_dict: dict[Path, int] = {Path(p.prover.get_filename()): 0 for p in self.invalid} + keys_missing_set: set[Path] = {Path(p.prover.get_filename()) for p in self.keys_missing} duplicates_set: set[str] = {p.prover.get_filename() for p in self.duplicates} # Inject invalid plots into `PlotManager` of the harvester so that the callback calls below can use them @@ -122,9 +123,9 @@ async def sync_done() -> bool: await time_out_assert(60, sync_done) for plot_info in loaded: - self.plots[plot_info.prover.get_filename()] = plot_info + self.plots[Path(plot_info.prover.get_filename())] = plot_info for plot_info in removed: - del self.plots[plot_info.prover.get_filename()] + del self.plots[Path(plot_info.prover.get_filename())] def validate_plot_sync(self) -> None: assert len(self.plots) == len(self.plot_sync_receiver.plots()) @@ -284,7 +285,7 @@ def get_compression_level(self) -> uint8: return [ PlotInfo( - prover=DiskProver(f"{x}", bytes32.random(seeded_random), 25 + x % 26), + prover=V1Prover(DiskProver(f"{x}", bytes32.random(seeded_random), 25 + x % 26)), pool_public_key=None, pool_contract_puzzle_hash=None, plot_public_key=G1Element(), @@ -416,7 +417,7 @@ async def test_sync_reset_cases( # Inject some data into `PlotManager` of the harvester so that we can validate the reset worked and triggered a # fresh sync of all available data of the plot manager for plot_info in plots[0:10]: - test_data.plots[plot_info.prover.get_filename()] = plot_info + test_data.plots[Path(plot_info.prover.get_filename())] = plot_info plot_manager.plots = test_data.plots test_data.invalid = plots[10:20] test_data.keys_missing = plots[20:30] @@ -424,8 +425,8 @@ async def test_sync_reset_cases( sender: Sender = test_runner.test_data[0].plot_sync_sender started_sync_id: uint64 = uint64(0) - plot_manager.failed_to_open_filenames = {p.prover.get_filename(): 0 for p in test_data.invalid} - plot_manager.no_key_filenames = {p.prover.get_filename() for p in test_data.keys_missing} + plot_manager.failed_to_open_filenames = {Path(p.prover.get_filename()): 0 for p in test_data.invalid} + plot_manager.no_key_filenames = {Path(p.prover.get_filename()) for p in test_data.keys_missing} async def wait_for_reset() -> bool: assert started_sync_id != 0 diff --git a/chia/_tests/plotting/test_plot_manager.py b/chia/_tests/plotting/test_plot_manager.py index 16fcb50bf66e..3108da756943 100644 --- a/chia/_tests/plotting/test_plot_manager.py +++ b/chia/_tests/plotting/test_plot_manager.py @@ -13,12 +13,14 @@ import pytest from chia_rs import G1Element from chia_rs.sized_ints import uint16, uint32 +from chiapos import DiskProver from chia._tests.plotting.util import get_test_plots from chia._tests.util.misc import boolean_datacases from chia._tests.util.time_out_assert import time_out_assert from chia.plotting.cache import CURRENT_VERSION, CacheDataV1 from chia.plotting.manager import Cache, PlotManager +from chia.plotting.prover import V1Prover from chia.plotting.util import ( PlotInfo, PlotRefreshEvents, @@ -743,6 +745,20 @@ async def test_recursive_plot_scan(environment: Environment) -> None: await env.refresh_tester.run(expected_result) +@pytest.mark.anyio +async def test_disk_prover_from_bytes(environment: Environment): + env: Environment = environment + expected_result = PlotRefreshResult() + expected_result.loaded = env.dir_1.plot_info_list() # type: ignore[assignment] + expected_result.processed = len(env.dir_1) + add_plot_directory(env.root_path, str(env.dir_1.path)) + await env.refresh_tester.run(expected_result) + _, plot_info = next(iter(env.refresh_tester.plot_manager.plots.items())) + recreated_prover = V1Prover(DiskProver.from_bytes(bytes(plot_info.prover))) + assert recreated_prover.get_id() == plot_info.prover.get_id() + assert recreated_prover.get_filename() == plot_info.prover.get_filename() + + @boolean_datacases(name="follow_links", false="no_follow", true="follow") @pytest.mark.anyio async def test_recursive_plot_scan_symlinks(environment: Environment, follow_links: bool) -> None: diff --git a/chia/_tests/plotting/test_prover.py b/chia/_tests/plotting/test_prover.py new file mode 100644 index 000000000000..592280d2df52 --- /dev/null +++ b/chia/_tests/plotting/test_prover.py @@ -0,0 +1,103 @@ +from __future__ import annotations + +import tempfile +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + +from chia.plotting.prover import PlotVersion, V1Prover, V2Prover, get_prover_from_bytes, get_prover_from_file + + +class TestProver: + def test_v2_prover_init_with_nonexistent_file(self) -> None: + prover = V2Prover("/nonexistent/path/test.plot2") + assert prover.get_version() == PlotVersion.V2 + assert prover.get_filename() == "/nonexistent/path/test.plot2" + + def test_v2_prover_get_size_raises_error(self) -> None: + prover = V2Prover("/nonexistent/path/test.plot2") + with pytest.raises(NotImplementedError, match="V2 plot format is not yet implemented"): + prover.get_size() + + def test_v2_prover_get_memo_raises_error(self) -> None: + prover = V2Prover("/nonexistent/path/test.plot2") + with pytest.raises(NotImplementedError, match="V2 plot format is not yet implemented"): + prover.get_memo() + + def test_v2_prover_get_compression_level_raises_error(self) -> None: + prover = V2Prover("/nonexistent/path/test.plot2") + with pytest.raises(NotImplementedError, match="V2 plot format is not yet implemented"): + prover.get_compression_level() + + def test_v2_prover_get_id_raises_error(self) -> None: + prover = V2Prover("/nonexistent/path/test.plot2") + with pytest.raises(NotImplementedError, match="V2 plot format is not yet implemented"): + prover.get_id() + + def test_v2_prover_get_qualities_for_challenge_raises_error(self) -> None: + prover = V2Prover("/nonexistent/path/test.plot2") + with pytest.raises(NotImplementedError, match="V2 plot format is not yet implemented"): + prover.get_qualities_for_challenge(b"challenge") + + def test_v2_prover_get_full_proof_raises_error(self) -> None: + prover = V2Prover("/nonexistent/path/test.plot2") + with pytest.raises(NotImplementedError, match="V2 plot format is not yet implemented"): + prover.get_full_proof(b"challenge", 0) + + def test_v2_prover_bytes_raises_error(self) -> None: + prover = V2Prover("/nonexistent/path/test.plot2") + with pytest.raises(NotImplementedError, match="V2 plot format is not yet implemented"): + bytes(prover) + + def test_v2_prover_from_bytes_raises_error(self) -> None: + with pytest.raises(NotImplementedError, match="V2 plot format is not yet implemented"): + V2Prover.from_bytes(b"test_data") + + def test_get_prover_from_file(self) -> None: + prover = get_prover_from_file("/nonexistent/path/test.plot2") + assert prover.get_version() == PlotVersion.V2 + with pytest.raises(NotImplementedError, match="V2 plot format is not yet implemented"): + prover.get_size() + + def test_get_prover_from_file_with_plot1_still_works(self) -> None: + with tempfile.NamedTemporaryFile(suffix=".plot", delete=False) as f: + temp_path = f.name + try: + with pytest.raises(Exception) as exc_info: + get_prover_from_file(temp_path) + assert not isinstance(exc_info.value, NotImplementedError) + finally: + Path(temp_path).unlink() + + def test_unsupported_file_extension_raises_value_error(self) -> None: + with pytest.raises(ValueError, match="Unsupported plot file"): + get_prover_from_file("/nonexistent/path/test.txt") + + +class TestV1Prover: + def test_v1_prover_get_version(self) -> None: + """Test that V1Prover.get_version() returns PlotVersion.V1""" + mock_disk_prover = MagicMock() + prover = V1Prover(mock_disk_prover) + assert prover.get_version() == PlotVersion.V1 + + +class TestGetProverFromBytes: + def test_get_prover_from_bytes_v2_plot(self) -> None: + with patch("chia.plotting.prover.V2Prover.from_bytes") as mock_v2_from_bytes: + mock_prover = MagicMock() + mock_v2_from_bytes.return_value = mock_prover + result = get_prover_from_bytes("test.plot2", b"test_data") + assert result == mock_prover + + def test_get_prover_from_bytes_v1_plot(self) -> None: + with patch("chia.plotting.prover.DiskProver") as mock_disk_prover_class: + mock_disk_prover = MagicMock() + mock_disk_prover_class.from_bytes.return_value = mock_disk_prover + result = get_prover_from_bytes("test.plot", b"test_data") + assert isinstance(result, V1Prover) + + def test_get_prover_from_bytes_unsupported_extension(self) -> None: + with pytest.raises(ValueError, match="Unsupported plot file"): + get_prover_from_bytes("test.txt", b"test_data") diff --git a/chia/harvester/harvester_api.py b/chia/harvester/harvester_api.py index b547102b4354..d238a98634eb 100644 --- a/chia/harvester/harvester_api.py +++ b/chia/harvester/harvester_api.py @@ -97,7 +97,7 @@ async def new_signage_point_harvester( loop = asyncio.get_running_loop() def blocking_lookup(filename: Path, plot_info: PlotInfo) -> list[tuple[bytes32, ProofOfSpace]]: - # Uses the DiskProver object to lookup qualities. This is a blocking call, + # Uses the Prover object to lookup qualities. This is a blocking call, # so it should be run in a thread pool. try: plot_id = plot_info.prover.get_id() @@ -218,7 +218,7 @@ def blocking_lookup(filename: Path, plot_info: PlotInfo) -> list[tuple[bytes32, async def lookup_challenge( filename: Path, plot_info: PlotInfo ) -> tuple[Path, list[harvester_protocol.NewProofOfSpace]]: - # Executes a DiskProverLookup in a thread pool, and returns responses + # Executes a ProverLookup in a thread pool, and returns responses all_responses: list[harvester_protocol.NewProofOfSpace] = [] if self.harvester._shut_down: return filename, [] diff --git a/chia/plotting/cache.py b/chia/plotting/cache.py index c70504d4c31f..2c5dfbdd6a72 100644 --- a/chia/plotting/cache.py +++ b/chia/plotting/cache.py @@ -7,13 +7,16 @@ from dataclasses import dataclass, field from math import ceil from pathlib import Path -from typing import Optional +from typing import TYPE_CHECKING, Optional + +if TYPE_CHECKING: + from chia.plotting.prover import ProverProtocol from chia_rs import G1Element from chia_rs.sized_bytes import bytes32 from chia_rs.sized_ints import uint16, uint64 -from chiapos import DiskProver +from chia.plotting.prover import get_prover_from_bytes from chia.plotting.util import parse_plot_info from chia.types.blockchain_format.proof_of_space import generate_plot_public_key from chia.util.streamable import Streamable, VersionedBlob, streamable @@ -43,7 +46,7 @@ class CacheDataV1(Streamable): @dataclass class CacheEntry: - prover: DiskProver + prover: ProverProtocol farmer_public_key: G1Element pool_public_key: Optional[G1Element] pool_contract_puzzle_hash: Optional[bytes32] @@ -51,7 +54,7 @@ class CacheEntry: last_use: float @classmethod - def from_disk_prover(cls, prover: DiskProver) -> CacheEntry: + def from_prover(cls, prover: ProverProtocol) -> CacheEntry: ( pool_public_key_or_puzzle_hash, farmer_public_key, @@ -149,8 +152,9 @@ def load(self) -> None: 39: 44367, } for path, cache_entry in cache_data.entries: + prover: ProverProtocol = get_prover_from_bytes(path, cache_entry.prover_data) new_entry = CacheEntry( - DiskProver.from_bytes(cache_entry.prover_data), + prover, cache_entry.farmer_public_key, cache_entry.pool_public_key, cache_entry.pool_contract_puzzle_hash, diff --git a/chia/plotting/check_plots.py b/chia/plotting/check_plots.py index fc4f1197bd5f..0fe7c2567fa5 100644 --- a/chia/plotting/check_plots.py +++ b/chia/plotting/check_plots.py @@ -9,7 +9,7 @@ from typing import Optional from chia_rs import G1Element -from chia_rs.sized_ints import uint32 +from chia_rs.sized_ints import uint8, uint32 from chiapos import Verifier from chia.plotting.manager import PlotManager @@ -133,7 +133,7 @@ def check_plots( log.info("") log.info("") log.info(f"Starting to test each plot with {num} challenges each\n") - total_good_plots: Counter[str] = Counter() + total_good_plots: Counter[uint8] = Counter() total_size = 0 bad_plots_list: list[Path] = [] diff --git a/chia/plotting/manager.py b/chia/plotting/manager.py index f2a9ab8565e5..5cd9acb6eb08 100644 --- a/chia/plotting/manager.py +++ b/chia/plotting/manager.py @@ -9,10 +9,11 @@ from typing import Any, Callable, Optional from chia_rs import G1Element -from chiapos import DiskProver, decompressor_context_queue +from chiapos import decompressor_context_queue from chia.consensus.pos_quality import UI_ACTUAL_SPACE_CONSTANT_FACTOR, _expected_plot_size from chia.plotting.cache import Cache, CacheEntry +from chia.plotting.prover import get_prover_from_file from chia.plotting.util import ( HarvestingMode, PlotInfo, @@ -323,7 +324,7 @@ def process_file(file_path: Path) -> Optional[PlotInfo]: cache_entry = self.cache.get(file_path) cache_hit = cache_entry is not None if not cache_hit: - prover = DiskProver(str(file_path)) + prover = get_prover_from_file(str(file_path)) log.debug(f"process_file {file_path!s}") @@ -343,7 +344,7 @@ def process_file(file_path: Path) -> Optional[PlotInfo]: ) return None - cache_entry = CacheEntry.from_disk_prover(prover) + cache_entry = CacheEntry.from_prover(prover) self.cache.update(file_path, cache_entry) assert cache_entry is not None diff --git a/chia/plotting/prover.py b/chia/plotting/prover.py new file mode 100644 index 000000000000..cd9474b0d9cd --- /dev/null +++ b/chia/plotting/prover.py @@ -0,0 +1,141 @@ +from __future__ import annotations + +from enum import IntEnum +from typing import TYPE_CHECKING, ClassVar, Protocol, cast + +from chia_rs.sized_bytes import bytes32 +from chia_rs.sized_ints import uint8 +from chiapos import DiskProver + +if TYPE_CHECKING: + from chiapos import DiskProver + + +class PlotVersion(IntEnum): + """Enum for plot format versions""" + + V1 = 1 + V2 = 2 + + +class ProverProtocol(Protocol): + def get_filename(self) -> str: ... + def get_size(self) -> uint8: ... + def get_memo(self) -> bytes: ... + def get_compression_level(self) -> uint8: ... + def get_version(self) -> PlotVersion: ... + def __bytes__(self) -> bytes: ... + def get_id(self) -> bytes32: ... + def get_qualities_for_challenge(self, challenge: bytes32) -> list[bytes32]: ... + def get_full_proof(self, challenge: bytes, index: int, parallel_read: bool = True) -> bytes: ... + + @classmethod + def from_bytes(cls, data: bytes) -> ProverProtocol: ... + + +class V2Prover: + """Placeholder for future V2 plot format support""" + + if TYPE_CHECKING: + _protocol_check: ClassVar[ProverProtocol] = cast("V2Prover", None) + + def __init__(self, filename: str): + self._filename = filename + + def get_filename(self) -> str: + return str(self._filename) + + def get_size(self) -> uint8: + # TODO: todo_v2_plots get k size from plot + raise NotImplementedError("V2 plot format is not yet implemented") + + def get_memo(self) -> bytes: + # TODO: todo_v2_plots + raise NotImplementedError("V2 plot format is not yet implemented") + + def get_compression_level(self) -> uint8: + # TODO: todo_v2_plots implement compression level retrieval + raise NotImplementedError("V2 plot format is not yet implemented") + + def get_version(self) -> PlotVersion: + return PlotVersion.V2 + + def __bytes__(self) -> bytes: + # TODO: todo_v2_plots Implement prover serialization for caching + raise NotImplementedError("V2 plot format is not yet implemented") + + def get_id(self) -> bytes32: + # TODO: Extract plot ID from V2 plot file + raise NotImplementedError("V2 plot format is not yet implemented") + + def get_qualities_for_challenge(self, challenge: bytes) -> list[bytes32]: + # TODO: todo_v2_plots Implement plot quality lookup + raise NotImplementedError("V2 plot format is not yet implemented") + + def get_full_proof(self, challenge: bytes, index: int, parallel_read: bool = True) -> bytes: + # TODO: todo_v2_plots Implement plot proof generation + raise NotImplementedError("V2 plot format is not yet implemented") + + @classmethod + def from_bytes(cls, data: bytes) -> V2Prover: + # TODO: todo_v2_plots Implement prover deserialization from cache + raise NotImplementedError("V2 plot format is not yet implemented") + + +class V1Prover: + """Wrapper for existing DiskProver to implement ProverProtocol""" + + if TYPE_CHECKING: + _protocol_check: ClassVar[ProverProtocol] = cast("V1Prover", None) + + def __init__(self, disk_prover: DiskProver) -> None: + self._disk_prover = disk_prover + + def get_filename(self) -> str: + return str(self._disk_prover.get_filename()) + + def get_size(self) -> uint8: + return uint8(self._disk_prover.get_size()) + + def get_memo(self) -> bytes: + return bytes(self._disk_prover.get_memo()) + + def get_compression_level(self) -> uint8: + return uint8(self._disk_prover.get_compression_level()) + + def get_version(self) -> PlotVersion: + return PlotVersion.V1 + + def __bytes__(self) -> bytes: + return bytes(self._disk_prover) + + def get_id(self) -> bytes32: + return bytes32(self._disk_prover.get_id()) + + def get_qualities_for_challenge(self, challenge: bytes32) -> list[bytes32]: + return [bytes32(quality) for quality in self._disk_prover.get_qualities_for_challenge(challenge)] + + def get_full_proof(self, challenge: bytes, index: int, parallel_read: bool = True) -> bytes: + return bytes(self._disk_prover.get_full_proof(challenge, index, parallel_read)) + + @classmethod + def from_bytes(cls, data: bytes) -> V1Prover: + return cls(DiskProver.from_bytes(data)) + + +def get_prover_from_bytes(filename: str, prover_data: bytes) -> ProverProtocol: + if filename.endswith(".plot2"): + return V2Prover.from_bytes(prover_data) + elif filename.endswith(".plot"): + return V1Prover(DiskProver.from_bytes(prover_data)) + else: + raise ValueError(f"Unsupported plot file: {filename}") + + +def get_prover_from_file(filename: str) -> ProverProtocol: + if filename.endswith(".plot2"): + return V2Prover(filename) + elif filename.endswith(".plot"): + return V1Prover(DiskProver(filename)) + else: + raise ValueError(f"Unsupported plot file: {filename}") diff --git a/chia/plotting/util.py b/chia/plotting/util.py index c2ad2a136e05..a7aa314dc44f 100644 --- a/chia/plotting/util.py +++ b/chia/plotting/util.py @@ -4,12 +4,14 @@ from dataclasses import dataclass, field from enum import Enum, IntEnum from pathlib import Path -from typing import Any, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union + +if TYPE_CHECKING: + from chia.plotting.prover import ProverProtocol from chia_rs import G1Element, PrivateKey from chia_rs.sized_bytes import bytes32 from chia_rs.sized_ints import uint32 -from chiapos import DiskProver from typing_extensions import final from chia.util.config import load_config, lock_and_load_config, save_config @@ -39,7 +41,7 @@ class PlotsRefreshParameter(Streamable): @dataclass class PlotInfo: - prover: DiskProver + prover: ProverProtocol pool_public_key: Optional[G1Element] pool_contract_puzzle_hash: Optional[bytes32] plot_public_key: G1Element @@ -233,16 +235,22 @@ def get_filenames(directory: Path, recursive: bool, follow_links: bool) -> list[ if follow_links and recursive: import glob - files = glob.glob(str(directory / "**" / "*.plot"), recursive=True) - for file in files: + v1_file_strs = glob.glob(str(directory / "**" / "*.plot"), recursive=True) + v2_file_strs = glob.glob(str(directory / "**" / "*.plot2"), recursive=True) + + for file in v1_file_strs + v2_file_strs: filepath = Path(file).resolve() if filepath.is_file() and not filepath.name.startswith("._"): all_files.append(filepath) else: glob_function = directory.rglob if recursive else directory.glob - all_files = [ + v1_files: list[Path] = [ child for child in glob_function("*.plot") if child.is_file() and not child.name.startswith("._") ] + v2_files: list[Path] = [ + child for child in glob_function("*.plot2") if child.is_file() and not child.name.startswith("._") + ] + all_files = v1_files + v2_files log.debug(f"get_filenames: {len(all_files)} files found in {directory}, recursive: {recursive}") except Exception as e: log.warning(f"Error reading directory {directory} {e}")