Skip to content

Commit 17b0b34

Browse files
committed
refactor
1 parent bffc732 commit 17b0b34

File tree

7 files changed

+66
-85
lines changed

7 files changed

+66
-85
lines changed

chia/_tests/plot_sync/test_sync_simulated.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -69,17 +69,17 @@ async def run(
6969
initial: bool,
7070
) -> None:
7171
for plot_info in loaded:
72-
assert plot_info.prover.get_filepath() not in self.plots
72+
assert Path(plot_info.prover.get_filename()) not in self.plots
7373
for plot_info in removed:
74-
assert plot_info.prover.get_filepath() in self.plots
74+
assert Path(plot_info.prover.get_filename()) in self.plots
7575

7676
self.invalid = invalid
7777
self.keys_missing = keys_missing
7878
self.duplicates = duplicates
7979

80-
removed_paths: list[Path] = [p.prover.get_filepath() for p in removed] if removed is not None else []
81-
invalid_dict: dict[Path, int] = {p.prover.get_filepath(): 0 for p in self.invalid}
82-
keys_missing_set: set[Path] = {p.prover.get_filepath() for p in self.keys_missing}
80+
removed_paths: list[Path] = [Path(p.prover.get_filename()) for p in removed] if removed is not None else []
81+
invalid_dict: dict[Path, int] = {Path(p.prover.get_filename()): 0 for p in self.invalid}
82+
keys_missing_set: set[Path] = {Path(p.prover.get_filename()) for p in self.keys_missing}
8383
duplicates_set: set[str] = {p.prover.get_filename() for p in self.duplicates}
8484

8585
# Inject invalid plots into `PlotManager` of the harvester so that the callback calls below can use them
@@ -91,7 +91,7 @@ async def run(
9191
# Inject duplicated plots into `PlotManager` of the harvester so that the callback calls below can use them
9292
# to sync them to the farmer.
9393
for plot_info in loaded:
94-
plot_path = Path(plot_info.prover.get_filepath())
94+
plot_path = Path(plot_info.prover.get_filename())
9595
self.harvester.plot_manager.plot_filename_paths[plot_path.name] = (str(plot_path.parent), set())
9696
for duplicate in duplicates_set:
9797
plot_path = Path(duplicate)
@@ -123,9 +123,9 @@ async def sync_done() -> bool:
123123
await time_out_assert(60, sync_done)
124124

125125
for plot_info in loaded:
126-
self.plots[plot_info.prover.get_filepath()] = plot_info
126+
self.plots[Path(plot_info.prover.get_filename())] = plot_info
127127
for plot_info in removed:
128-
del self.plots[plot_info.prover.get_filepath()]
128+
del self.plots[Path(plot_info.prover.get_filename())]
129129

130130
def validate_plot_sync(self) -> None:
131131
assert len(self.plots) == len(self.plot_sync_receiver.plots())
@@ -417,16 +417,16 @@ async def test_sync_reset_cases(
417417
# Inject some data into `PlotManager` of the harvester so that we can validate the reset worked and triggered a
418418
# fresh sync of all available data of the plot manager
419419
for plot_info in plots[0:10]:
420-
test_data.plots[plot_info.prover.get_filepath()] = plot_info
420+
test_data.plots[Path(plot_info.prover.get_filename())] = plot_info
421421
plot_manager.plots = test_data.plots
422422
test_data.invalid = plots[10:20]
423423
test_data.keys_missing = plots[20:30]
424424
test_data.plot_sync_receiver.simulate_error = simulate_error # type: ignore[attr-defined]
425425
sender: Sender = test_runner.test_data[0].plot_sync_sender
426426
started_sync_id: uint64 = uint64(0)
427427

428-
plot_manager.failed_to_open_filenames = {p.prover.get_filepath(): 0 for p in test_data.invalid}
429-
plot_manager.no_key_filenames = {p.prover.get_filepath() for p in test_data.keys_missing}
428+
plot_manager.failed_to_open_filenames = {Path(p.prover.get_filename()): 0 for p in test_data.invalid}
429+
plot_manager.no_key_filenames = {Path(p.prover.get_filename()) for p in test_data.keys_missing}
430430

431431
async def wait_for_reset() -> bool:
432432
assert started_sync_id != 0

chia/_tests/plotting/test_plot_manager.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,6 @@
4242
class MockDiskProver:
4343
filename: str
4444

45-
def get_filepath(self) -> Path:
46-
return Path(self.filename)
47-
4845
def get_filename(self) -> str:
4946
return self.filename
5047

@@ -115,7 +112,7 @@ def refresh_callback(self, event: PlotRefreshEvents, refresh_result: PlotRefresh
115112
for value in actual_value:
116113
if type(value) is PlotInfo:
117114
for plot_info in expected_list:
118-
if plot_info.prover.get_filename() == value.prover.get_filepath():
115+
if plot_info.prover.get_filename() == value.prover.get_filename():
119116
values_found += 1
120117
continue
121118
else:
@@ -507,7 +504,7 @@ async def test_plot_info_caching(environment, bt):
507504
await refresh_tester.run(expected_result)
508505
for path, plot_info in env.refresh_tester.plot_manager.plots.items():
509506
assert path in plot_manager.plots
510-
assert plot_manager.plots[path].prover.get_filepath() == plot_info.prover.get_filepath()
507+
assert plot_manager.plots[path].prover.get_filename() == plot_info.prover.get_filename()
511508
assert plot_manager.plots[path].prover.get_id() == plot_info.prover.get_id()
512509
assert plot_manager.plots[path].prover.get_memo() == plot_info.prover.get_memo()
513510
assert plot_manager.plots[path].prover.get_size() == plot_info.prover.get_size()
@@ -598,7 +595,7 @@ def assert_cache(expected: list[MockPlotInfo]) -> None:
598595
test_cache.load()
599596
assert len(test_cache) == len(expected)
600597
for plot_info in expected:
601-
assert test_cache.get(Path(plot_info.prover.get_filepath())) is not None
598+
assert test_cache.get(Path(plot_info.prover.get_filename())) is not None
602599

603600
# Modify two entries, with and without memo modification, they both should remain in the cache after load
604601
modify_cache_entry(0, 1500, modify_memo=False)
@@ -617,7 +614,7 @@ def assert_cache(expected: list[MockPlotInfo]) -> None:
617614
# Write the modified cache entries to the file
618615
cache_path.write_bytes(bytes(VersionedBlob(uint16(CURRENT_VERSION), bytes(cache_data))))
619616
# And now test that plots in invalid_entries are not longer loaded
620-
assert_cache([plot_info for plot_info in plot_infos if str(plot_info.prover.get_filepath()) not in invalid_entries])
617+
assert_cache([plot_info for plot_info in plot_infos if plot_info.prover.get_filename() not in invalid_entries])
621618

622619

623620
@pytest.mark.anyio

chia/_tests/plotting/test_prover.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,17 @@
22

33
import tempfile
44
from pathlib import Path
5+
from unittest.mock import MagicMock, patch
56

67
import pytest
78

8-
from chia.plotting.prover import PlotVersion, V2Prover, get_prover_from_file
9+
from chia.plotting.prover import PlotVersion, V1Prover, V2Prover, get_prover_from_bytes, get_prover_from_file
910

1011

1112
class TestProver:
1213
def test_v2_prover_init_with_nonexistent_file(self) -> None:
1314
prover = V2Prover("/nonexistent/path/test.plot2")
1415
assert prover.get_version() == PlotVersion.V2
15-
assert prover.get_filepath() == Path("/nonexistent/path/test.plot2")
1616
assert prover.get_filename() == "/nonexistent/path/test.plot2"
1717

1818
def test_v2_prover_get_size_raises_error(self) -> None:
@@ -71,6 +71,25 @@ def test_get_prover_from_file_with_plot1_still_works(self) -> None:
7171
Path(temp_path).unlink()
7272

7373
def test_unsupported_file_extension_raises_value_error(self) -> None:
74-
"""Test that unsupported file extensions raise ValueError"""
7574
with pytest.raises(ValueError, match="Unsupported plot file"):
7675
get_prover_from_file("/nonexistent/path/test.txt")
76+
77+
78+
class TestGetProverFromBytes:
79+
def test_get_prover_from_bytes_v2_plot(self) -> None:
80+
with patch("chia.plotting.prover.V2Prover.from_bytes") as mock_v2_from_bytes:
81+
mock_prover = MagicMock()
82+
mock_v2_from_bytes.return_value = mock_prover
83+
result = get_prover_from_bytes("test.plot2", b"test_data")
84+
assert result == mock_prover
85+
86+
def test_get_prover_from_bytes_v1_plot(self) -> None:
87+
with patch("chia.plotting.prover.DiskProver") as mock_disk_prover_class:
88+
mock_disk_prover = MagicMock()
89+
mock_disk_prover_class.from_bytes.return_value = mock_disk_prover
90+
result = get_prover_from_bytes("test.plot", b"test_data")
91+
assert isinstance(result, V1Prover)
92+
93+
def test_get_prover_from_bytes_unsupported_extension(self) -> None:
94+
with pytest.raises(ValueError, match="Unsupported plot file"):
95+
get_prover_from_bytes("test.txt", b"test_data")

chia/harvester/harvester_api.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ async def new_signage_point_harvester(
9797
loop = asyncio.get_running_loop()
9898

9999
def blocking_lookup(filename: Path, plot_info: PlotInfo) -> list[tuple[bytes32, ProofOfSpace]]:
100-
# Uses the DiskProver object to lookup qualities. This is a blocking call,
100+
# Uses the Prover object to lookup qualities. This is a blocking call,
101101
# so it should be run in a thread pool.
102102
try:
103103
plot_id = plot_info.prover.get_id()
@@ -218,7 +218,7 @@ def blocking_lookup(filename: Path, plot_info: PlotInfo) -> list[tuple[bytes32,
218218
async def lookup_challenge(
219219
filename: Path, plot_info: PlotInfo
220220
) -> tuple[Path, list[harvester_protocol.NewProofOfSpace]]:
221-
# Executes a DiskProverLookup in a thread pool, and returns responses
221+
# Executes a ProverLookup in a thread pool, and returns responses
222222
all_responses: list[harvester_protocol.NewProofOfSpace] = []
223223
if self.harvester._shut_down:
224224
return filename, []

chia/plotting/cache.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,6 @@ class CacheEntry:
5555

5656
@classmethod
5757
def from_prover(cls, prover: ProverProtocol) -> CacheEntry:
58-
"""Create CacheEntry from prover"""
5958
(
6059
pool_public_key_or_puzzle_hash,
6160
farmer_public_key,

chia/plotting/manager.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -386,10 +386,10 @@ def process_file(file_path: Path) -> Optional[PlotInfo]:
386386
with self.plot_filename_paths_lock:
387387
paths: Optional[tuple[str, set[str]]] = self.plot_filename_paths.get(file_path.name)
388388
if paths is None:
389-
paths = (str(Path(cache_entry.prover.get_filepath()).parent), set())
389+
paths = (str(Path(cache_entry.prover.get_filename()).parent), set())
390390
self.plot_filename_paths[file_path.name] = paths
391391
else:
392-
paths[1].add(str(Path(cache_entry.prover.get_filepath()).parent))
392+
paths[1].add(str(Path(cache_entry.prover.get_filename()).parent))
393393
log.warning(f"Have multiple copies of the plot {file_path.name} in {[paths[0], *paths[1]]}.")
394394
return None
395395

@@ -423,7 +423,7 @@ def process_file(file_path: Path) -> Optional[PlotInfo]:
423423
plots_refreshed: dict[Path, PlotInfo] = {}
424424
for new_plot in executor.map(process_file, plot_paths):
425425
if new_plot is not None:
426-
plots_refreshed[Path(new_plot.prover.get_filepath())] = new_plot
426+
plots_refreshed[Path(new_plot.prover.get_filename())] = new_plot
427427
self.plots.update(plots_refreshed)
428428

429429
result.duration = time.time() - start_time

chia/plotting/prover.py

Lines changed: 24 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
from __future__ import annotations
22

3-
from abc import ABC, abstractmethod
43
from enum import IntEnum
5-
from pathlib import Path
6-
from typing import TYPE_CHECKING
4+
from typing import TYPE_CHECKING, ClassVar, Protocol, cast
75

86
from chia_rs.sized_bytes import bytes32
97
from chia_rs.sized_ints import uint8
@@ -20,62 +18,30 @@ class PlotVersion(IntEnum):
2018
V2 = 2
2119

2220

23-
class ProverProtocol(ABC):
24-
@abstractmethod
25-
def get_filepath(self) -> Path:
26-
"""Returns the filename for the plot"""
27-
28-
@abstractmethod
29-
def get_filename(self) -> str:
30-
"""Returns the filename string for the plot"""
31-
32-
@abstractmethod
33-
def get_size(self) -> uint8:
34-
"""Returns the k size of the plot"""
35-
36-
@abstractmethod
37-
def get_memo(self) -> bytes:
38-
"""Returns the memo"""
39-
40-
@abstractmethod
41-
def get_compression_level(self) -> uint8:
42-
"""Returns the compression level"""
43-
44-
@abstractmethod
45-
def get_version(self) -> PlotVersion:
46-
"""Returns the plot version"""
47-
48-
@abstractmethod
49-
def __bytes__(self) -> bytes:
50-
"""Returns the prover bytes"""
51-
52-
@abstractmethod
53-
def get_id(self) -> bytes32:
54-
"""Returns the plot ID"""
55-
56-
@abstractmethod
57-
def get_qualities_for_challenge(self, challenge: bytes32) -> list[bytes32]:
58-
"""Returns the qualities for a given challenge"""
59-
60-
@abstractmethod
61-
def get_full_proof(self, challenge: bytes, index: int, parallel_read: bool = True) -> bytes:
62-
"""Returns the full proof for a given challenge and index"""
21+
class ProverProtocol(Protocol):
22+
def get_filename(self) -> str: ...
23+
def get_size(self) -> uint8: ...
24+
def get_memo(self) -> bytes: ...
25+
def get_compression_level(self) -> uint8: ...
26+
def get_version(self) -> PlotVersion: ...
27+
def __bytes__(self) -> bytes: ...
28+
def get_id(self) -> bytes32: ...
29+
def get_qualities_for_challenge(self, challenge: bytes32) -> list[bytes32]: ...
30+
def get_full_proof(self, challenge: bytes, index: int, parallel_read: bool = True) -> bytes: ...
6331

6432
@classmethod
65-
@abstractmethod
66-
def from_bytes(cls, data: bytes) -> ProverProtocol:
67-
"""Create a prover from serialized bytes"""
33+
def from_bytes(cls, data: bytes) -> ProverProtocol: ...
6834

6935

70-
class V2Prover(ProverProtocol):
71-
"""V2 Plot Prover stubb"""
36+
class V2Prover:
37+
"""Placeholder for future V2 plot format support"""
38+
39+
if TYPE_CHECKING:
40+
_protocol_check: ClassVar[ProverProtocol] = cast("V2Prover", None)
7241

7342
def __init__(self, filename: str):
7443
self._filename = filename
7544

76-
def get_filepath(self) -> Path:
77-
return Path(self._filename)
78-
7945
def get_filename(self) -> str:
8046
return str(self._filename)
8147

@@ -102,29 +68,29 @@ def get_id(self) -> bytes32:
10268
# TODO: Extract plot ID from V2 plot file
10369
raise NotImplementedError("V2 plot format is not yet implemented")
10470

105-
def get_qualities_for_challenge(self, _challenge: bytes) -> list[bytes32]:
71+
def get_qualities_for_challenge(self, challenge: bytes) -> list[bytes32]:
10672
# TODO: todo_v2_plots Implement plot quality lookup
10773
raise NotImplementedError("V2 plot format is not yet implemented")
10874

109-
def get_full_proof(self, _challenge: bytes, _index: int, _parallel_read: bool = True) -> bytes:
75+
def get_full_proof(self, challenge: bytes, index: int, parallel_read: bool = True) -> bytes:
11076
# TODO: todo_v2_plots Implement plot proof generation
11177
raise NotImplementedError("V2 plot format is not yet implemented")
11278

11379
@classmethod
114-
def from_bytes(cls, _data: bytes) -> V2Prover:
80+
def from_bytes(cls, data: bytes) -> V2Prover:
11581
# TODO: todo_v2_plots Implement prover deserialization from cache
11682
raise NotImplementedError("V2 plot format is not yet implemented")
11783

11884

119-
class V1Prover(ProverProtocol):
85+
class V1Prover:
12086
"""Wrapper for existing DiskProver to implement ProverProtocol"""
12187

88+
if TYPE_CHECKING:
89+
_protocol_check: ClassVar[ProverProtocol] = cast("V1Prover", None)
90+
12291
def __init__(self, disk_prover: DiskProver) -> None:
12392
self._disk_prover = disk_prover
12493

125-
def get_filepath(self) -> Path:
126-
return Path(self._disk_prover.get_filename())
127-
12894
def get_filename(self) -> str:
12995
return str(self._disk_prover.get_filename())
13096

0 commit comments

Comments
 (0)