Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 12 additions & 11 deletions chia/_tests/plot_sync/test_sync_simulated.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from chia.plot_sync.sender import Sender
from chia.plot_sync.util import Constants
from chia.plotting.manager import PlotManager
from chia.plotting.prover import V1Prover
from chia.plotting.util import PlotInfo
from chia.protocols.harvester_protocol import PlotSyncError, PlotSyncResponse
from chia.protocols.outbound_message import make_msg
Expand Down Expand Up @@ -68,17 +69,17 @@ async def run(
initial: bool,
) -> None:
for plot_info in loaded:
assert plot_info.prover.get_filename() not in self.plots
assert Path(plot_info.prover.get_filename()) not in self.plots
for plot_info in removed:
assert plot_info.prover.get_filename() in self.plots
assert Path(plot_info.prover.get_filename()) in self.plots

self.invalid = invalid
self.keys_missing = keys_missing
self.duplicates = duplicates

removed_paths: list[Path] = [p.prover.get_filename() for p in removed] if removed is not None else []
invalid_dict: dict[Path, int] = {p.prover.get_filename(): 0 for p in self.invalid}
keys_missing_set: set[Path] = {p.prover.get_filename() for p in self.keys_missing}
removed_paths: list[Path] = [Path(p.prover.get_filename()) for p in removed] if removed is not None else []
invalid_dict: dict[Path, int] = {Path(p.prover.get_filename()): 0 for p in self.invalid}
keys_missing_set: set[Path] = {Path(p.prover.get_filename()) for p in self.keys_missing}
duplicates_set: set[str] = {p.prover.get_filename() for p in self.duplicates}

# Inject invalid plots into `PlotManager` of the harvester so that the callback calls below can use them
Expand Down Expand Up @@ -122,9 +123,9 @@ async def sync_done() -> bool:
await time_out_assert(60, sync_done)

for plot_info in loaded:
self.plots[plot_info.prover.get_filename()] = plot_info
self.plots[Path(plot_info.prover.get_filename())] = plot_info
for plot_info in removed:
del self.plots[plot_info.prover.get_filename()]
del self.plots[Path(plot_info.prover.get_filename())]

def validate_plot_sync(self) -> None:
assert len(self.plots) == len(self.plot_sync_receiver.plots())
Expand Down Expand Up @@ -284,7 +285,7 @@ def get_compression_level(self) -> uint8:

return [
PlotInfo(
prover=DiskProver(f"{x}", bytes32.random(seeded_random), 25 + x % 26),
prover=V1Prover(DiskProver(f"{x}", bytes32.random(seeded_random), 25 + x % 26)),
pool_public_key=None,
pool_contract_puzzle_hash=None,
plot_public_key=G1Element(),
Expand Down Expand Up @@ -416,16 +417,16 @@ async def test_sync_reset_cases(
# Inject some data into `PlotManager` of the harvester so that we can validate the reset worked and triggered a
# fresh sync of all available data of the plot manager
for plot_info in plots[0:10]:
test_data.plots[plot_info.prover.get_filename()] = plot_info
test_data.plots[Path(plot_info.prover.get_filename())] = plot_info
plot_manager.plots = test_data.plots
test_data.invalid = plots[10:20]
test_data.keys_missing = plots[20:30]
test_data.plot_sync_receiver.simulate_error = simulate_error # type: ignore[attr-defined]
sender: Sender = test_runner.test_data[0].plot_sync_sender
started_sync_id: uint64 = uint64(0)

plot_manager.failed_to_open_filenames = {p.prover.get_filename(): 0 for p in test_data.invalid}
plot_manager.no_key_filenames = {p.prover.get_filename() for p in test_data.keys_missing}
plot_manager.failed_to_open_filenames = {Path(p.prover.get_filename()): 0 for p in test_data.invalid}
plot_manager.no_key_filenames = {Path(p.prover.get_filename()) for p in test_data.keys_missing}

async def wait_for_reset() -> bool:
assert started_sync_id != 0
Expand Down
16 changes: 16 additions & 0 deletions chia/_tests/plotting/test_plot_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,14 @@
import pytest
from chia_rs import G1Element
from chia_rs.sized_ints import uint16, uint32
from chiapos import DiskProver

from chia._tests.plotting.util import get_test_plots
from chia._tests.util.misc import boolean_datacases
from chia._tests.util.time_out_assert import time_out_assert
from chia.plotting.cache import CURRENT_VERSION, CacheDataV1
from chia.plotting.manager import Cache, PlotManager
from chia.plotting.prover import V1Prover
from chia.plotting.util import (
PlotInfo,
PlotRefreshEvents,
Expand Down Expand Up @@ -743,6 +745,20 @@ async def test_recursive_plot_scan(environment: Environment) -> None:
await env.refresh_tester.run(expected_result)


@pytest.mark.anyio
async def test_disk_prover_from_bytes(environment: Environment):
env: Environment = environment
expected_result = PlotRefreshResult()
expected_result.loaded = env.dir_1.plot_info_list() # type: ignore[assignment]
expected_result.processed = len(env.dir_1)
add_plot_directory(env.root_path, str(env.dir_1.path))
await env.refresh_tester.run(expected_result)
_, plot_info = next(iter(env.refresh_tester.plot_manager.plots.items()))
recreated_prover = V1Prover(DiskProver.from_bytes(bytes(plot_info.prover)))
assert recreated_prover.get_id() == plot_info.prover.get_id()
assert recreated_prover.get_filename() == plot_info.prover.get_filename()


@boolean_datacases(name="follow_links", false="no_follow", true="follow")
@pytest.mark.anyio
async def test_recursive_plot_scan_symlinks(environment: Environment, follow_links: bool) -> None:
Expand Down
103 changes: 103 additions & 0 deletions chia/_tests/plotting/test_prover.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
from __future__ import annotations

import tempfile
from pathlib import Path
from unittest.mock import MagicMock, patch

import pytest

from chia.plotting.prover import PlotVersion, V1Prover, V2Prover, get_prover_from_bytes, get_prover_from_file


class TestProver:
def test_v2_prover_init_with_nonexistent_file(self) -> None:
prover = V2Prover("/nonexistent/path/test.plot2")
assert prover.get_version() == PlotVersion.V2
assert prover.get_filename() == "/nonexistent/path/test.plot2"

def test_v2_prover_get_size_raises_error(self) -> None:
prover = V2Prover("/nonexistent/path/test.plot2")
with pytest.raises(NotImplementedError, match="V2 plot format is not yet implemented"):
prover.get_size()

def test_v2_prover_get_memo_raises_error(self) -> None:
prover = V2Prover("/nonexistent/path/test.plot2")
with pytest.raises(NotImplementedError, match="V2 plot format is not yet implemented"):
prover.get_memo()

def test_v2_prover_get_compression_level_raises_error(self) -> None:
prover = V2Prover("/nonexistent/path/test.plot2")
with pytest.raises(NotImplementedError, match="V2 plot format is not yet implemented"):
prover.get_compression_level()

def test_v2_prover_get_id_raises_error(self) -> None:
prover = V2Prover("/nonexistent/path/test.plot2")
with pytest.raises(NotImplementedError, match="V2 plot format is not yet implemented"):
prover.get_id()

def test_v2_prover_get_qualities_for_challenge_raises_error(self) -> None:
prover = V2Prover("/nonexistent/path/test.plot2")
with pytest.raises(NotImplementedError, match="V2 plot format is not yet implemented"):
prover.get_qualities_for_challenge(b"challenge")

def test_v2_prover_get_full_proof_raises_error(self) -> None:
prover = V2Prover("/nonexistent/path/test.plot2")
with pytest.raises(NotImplementedError, match="V2 plot format is not yet implemented"):
prover.get_full_proof(b"challenge", 0)

def test_v2_prover_bytes_raises_error(self) -> None:
prover = V2Prover("/nonexistent/path/test.plot2")
with pytest.raises(NotImplementedError, match="V2 plot format is not yet implemented"):
bytes(prover)

def test_v2_prover_from_bytes_raises_error(self) -> None:
with pytest.raises(NotImplementedError, match="V2 plot format is not yet implemented"):
V2Prover.from_bytes(b"test_data")

def test_get_prover_from_file(self) -> None:
prover = get_prover_from_file("/nonexistent/path/test.plot2")
assert prover.get_version() == PlotVersion.V2
with pytest.raises(NotImplementedError, match="V2 plot format is not yet implemented"):
prover.get_size()

def test_get_prover_from_file_with_plot1_still_works(self) -> None:
with tempfile.NamedTemporaryFile(suffix=".plot", delete=False) as f:
temp_path = f.name
try:
with pytest.raises(Exception) as exc_info:
get_prover_from_file(temp_path)
assert not isinstance(exc_info.value, NotImplementedError)
finally:
Path(temp_path).unlink()

def test_unsupported_file_extension_raises_value_error(self) -> None:
with pytest.raises(ValueError, match="Unsupported plot file"):
get_prover_from_file("/nonexistent/path/test.txt")


class TestV1Prover:
def test_v1_prover_get_version(self) -> None:
"""Test that V1Prover.get_version() returns PlotVersion.V1"""
mock_disk_prover = MagicMock()
prover = V1Prover(mock_disk_prover)
assert prover.get_version() == PlotVersion.V1


class TestGetProverFromBytes:
def test_get_prover_from_bytes_v2_plot(self) -> None:
with patch("chia.plotting.prover.V2Prover.from_bytes") as mock_v2_from_bytes:
mock_prover = MagicMock()
mock_v2_from_bytes.return_value = mock_prover
result = get_prover_from_bytes("test.plot2", b"test_data")
assert result == mock_prover

def test_get_prover_from_bytes_v1_plot(self) -> None:
with patch("chia.plotting.prover.DiskProver") as mock_disk_prover_class:
mock_disk_prover = MagicMock()
mock_disk_prover_class.from_bytes.return_value = mock_disk_prover
result = get_prover_from_bytes("test.plot", b"test_data")
assert isinstance(result, V1Prover)

def test_get_prover_from_bytes_unsupported_extension(self) -> None:
with pytest.raises(ValueError, match="Unsupported plot file"):
get_prover_from_bytes("test.txt", b"test_data")
4 changes: 2 additions & 2 deletions chia/harvester/harvester_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ async def new_signage_point_harvester(
loop = asyncio.get_running_loop()

def blocking_lookup(filename: Path, plot_info: PlotInfo) -> list[tuple[bytes32, ProofOfSpace]]:
# Uses the DiskProver object to lookup qualities. This is a blocking call,
# Uses the Prover object to lookup qualities. This is a blocking call,
# so it should be run in a thread pool.
try:
plot_id = plot_info.prover.get_id()
Expand Down Expand Up @@ -218,7 +218,7 @@ def blocking_lookup(filename: Path, plot_info: PlotInfo) -> list[tuple[bytes32,
async def lookup_challenge(
filename: Path, plot_info: PlotInfo
) -> tuple[Path, list[harvester_protocol.NewProofOfSpace]]:
# Executes a DiskProverLookup in a thread pool, and returns responses
# Executes a ProverLookup in a thread pool, and returns responses
all_responses: list[harvester_protocol.NewProofOfSpace] = []
if self.harvester._shut_down:
return filename, []
Expand Down
14 changes: 9 additions & 5 deletions chia/plotting/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,16 @@
from dataclasses import dataclass, field
from math import ceil
from pathlib import Path
from typing import Optional
from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
from chia.plotting.prover import ProverProtocol

from chia_rs import G1Element
from chia_rs.sized_bytes import bytes32
from chia_rs.sized_ints import uint16, uint64
from chiapos import DiskProver

from chia.plotting.prover import get_prover_from_bytes
from chia.plotting.util import parse_plot_info
from chia.types.blockchain_format.proof_of_space import generate_plot_public_key
from chia.util.streamable import Streamable, VersionedBlob, streamable
Expand Down Expand Up @@ -43,15 +46,15 @@ class CacheDataV1(Streamable):

@dataclass
class CacheEntry:
prover: DiskProver
prover: ProverProtocol
farmer_public_key: G1Element
pool_public_key: Optional[G1Element]
pool_contract_puzzle_hash: Optional[bytes32]
plot_public_key: G1Element
last_use: float

@classmethod
def from_disk_prover(cls, prover: DiskProver) -> CacheEntry:
def from_prover(cls, prover: ProverProtocol) -> CacheEntry:
(
pool_public_key_or_puzzle_hash,
farmer_public_key,
Expand Down Expand Up @@ -149,8 +152,9 @@ def load(self) -> None:
39: 44367,
}
for path, cache_entry in cache_data.entries:
prover: ProverProtocol = get_prover_from_bytes(path, cache_entry.prover_data)
new_entry = CacheEntry(
DiskProver.from_bytes(cache_entry.prover_data),
prover,
cache_entry.farmer_public_key,
cache_entry.pool_public_key,
cache_entry.pool_contract_puzzle_hash,
Expand Down
4 changes: 2 additions & 2 deletions chia/plotting/check_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from typing import Optional

from chia_rs import G1Element
from chia_rs.sized_ints import uint32
from chia_rs.sized_ints import uint8, uint32
from chiapos import Verifier

from chia.plotting.manager import PlotManager
Expand Down Expand Up @@ -133,7 +133,7 @@ def check_plots(
log.info("")
log.info("")
log.info(f"Starting to test each plot with {num} challenges each\n")
total_good_plots: Counter[str] = Counter()
total_good_plots: Counter[uint8] = Counter()
total_size = 0
bad_plots_list: list[Path] = []

Expand Down
7 changes: 4 additions & 3 deletions chia/plotting/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@
from typing import Any, Callable, Optional

from chia_rs import G1Element
from chiapos import DiskProver, decompressor_context_queue
from chiapos import decompressor_context_queue

from chia.consensus.pos_quality import UI_ACTUAL_SPACE_CONSTANT_FACTOR, _expected_plot_size
from chia.plotting.cache import Cache, CacheEntry
from chia.plotting.prover import get_prover_from_file
from chia.plotting.util import (
HarvestingMode,
PlotInfo,
Expand Down Expand Up @@ -323,7 +324,7 @@ def process_file(file_path: Path) -> Optional[PlotInfo]:
cache_entry = self.cache.get(file_path)
cache_hit = cache_entry is not None
if not cache_hit:
prover = DiskProver(str(file_path))
prover = get_prover_from_file(str(file_path))

log.debug(f"process_file {file_path!s}")

Expand All @@ -343,7 +344,7 @@ def process_file(file_path: Path) -> Optional[PlotInfo]:
)
return None

cache_entry = CacheEntry.from_disk_prover(prover)
cache_entry = CacheEntry.from_prover(prover)
self.cache.update(file_path, cache_entry)

assert cache_entry is not None
Expand Down
Loading
Loading