Skip to content

Commit c1e5d28

Browse files
committed
support snapshot updates
1 parent 9363b3c commit c1e5d28

File tree

2 files changed

+39
-6
lines changed

2 files changed

+39
-6
lines changed

src/common/test_tools/plugin.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,16 @@
55
from prometheus_client.metrics import MetricWrapperBase
66
from pyfakefs.fake_filesystem import FakeFilesystem
77

8-
from common.test_tools.types import AssertMetricFixture, SnapshotFixture
8+
from common.test_tools.types import AssertMetricFixture, Snapshot, SnapshotFixture
9+
10+
11+
def pytest_addoption(parser: pytest.Parser) -> None:
12+
group = parser.getgroup("snapshot")
13+
group.addoption(
14+
"--snapshot-update",
15+
action="store_true",
16+
help="Update snapshot files instead of testing against them.",
17+
)
918

1019

1120
def assert_metric_impl() -> Generator[AssertMetricFixture, None, None]:
@@ -72,10 +81,11 @@ def flagsmith_markers_marked(
7281

7382
@pytest.fixture
7483
def snapshot(request: pytest.FixtureRequest) -> SnapshotFixture:
75-
def _get_snapshot(name: str = "") -> str:
84+
for_update = request.config.getoption("--snapshot-update")
85+
86+
def _get_snapshot(name: str = "") -> Snapshot:
7687
snapshot_name = name or f"{request.node.name}.txt"
77-
return open(
78-
request.path.parent / f"snapshots/{snapshot_name}",
79-
).read()
88+
snapshot_path = request.path.parent / f"snapshots/{snapshot_name}"
89+
return Snapshot(snapshot_path, for_update=for_update)
8090

8191
return _get_snapshot

src/common/test_tools/types.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
1+
from pathlib import Path
12
from typing import Protocol
23

4+
import pytest
5+
36

47
class AssertMetricFixture(Protocol):
58
def __call__(
@@ -12,4 +15,24 @@ def __call__(
1215

1316

1417
class SnapshotFixture(Protocol):
15-
def __call__(self, name: str = "") -> str: ...
18+
def __call__(self, name: str = "") -> "Snapshot": ...
19+
20+
21+
class Snapshot:
22+
def __init__(self, path: Path, for_update: bool) -> None:
23+
self.path = path
24+
self.content = open(path).read()
25+
self.for_update = for_update
26+
27+
def __eq__(self, other: object) -> bool:
28+
eq = self.content == other
29+
if not eq and isinstance(other, str) and self.for_update:
30+
with open(self.path, "w") as f:
31+
f.write(other)
32+
pytest.xfail(reason=f"Snapshot updated: {self.path}")
33+
return eq
34+
35+
def __str__(self) -> str:
36+
return self.content
37+
38+
__repr__ = __str__

0 commit comments

Comments
 (0)