Skip to content

Commit 59f0a9a

Browse files
committed
Override default syrupy finish
1 parent a6f8fae commit 59f0a9a

File tree

2 files changed

+169
-1
lines changed

2 files changed

+169
-1
lines changed

tests/conftest.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
import requests_mock
3737
import respx
3838
from syrupy.assertion import SnapshotAssertion
39+
from syrupy.session import SnapshotSession
3940

4041
from homeassistant import block_async_io
4142
from homeassistant.exceptions import ServiceNotFound
@@ -92,7 +93,7 @@
9293
from homeassistant.util.json import json_loads
9394

9495
from .ignore_uncaught_exceptions import IGNORE_UNCAUGHT_EXCEPTIONS
95-
from .syrupy import HomeAssistantSnapshotExtension
96+
from .syrupy import HomeAssistantSnapshotExtension, override_syrupy_finish
9697
from .typing import (
9798
ClientSessionGenerator,
9899
MockHAClientWebSocket,
@@ -149,6 +150,11 @@ def pytest_configure(config: pytest.Config) -> None:
149150
if config.getoption("verbose") > 0:
150151
logging.getLogger().setLevel(logging.DEBUG)
151152

153+
# Override default finish to detect unused snapshots despite xdist
154+
# Temporary workaround until it is finalised inside syrupy
155+
# See https://github.com/syrupy-project/syrupy/pull/901
156+
SnapshotSession.finish = override_syrupy_finish
157+
152158

153159
def pytest_runtest_setup() -> None:
154160
"""Prepare pytest_socket and freezegun.

tests/syrupy.py

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,22 @@
55
from contextlib import suppress
66
import dataclasses
77
from enum import IntFlag
8+
import json
9+
import os
810
from pathlib import Path
911
from typing import Any
1012

1113
import attr
1214
import attrs
15+
import pytest
16+
from syrupy.constants import EXIT_STATUS_FAIL_UNUSED
17+
from syrupy.data import Snapshot, SnapshotCollection, SnapshotCollections
1318
from syrupy.extensions.amber import AmberDataSerializer, AmberSnapshotExtension
1419
from syrupy.location import PyTestLocation
20+
from syrupy.report import SnapshotReport
21+
from syrupy.session import ItemStatus, SnapshotSession
1522
from syrupy.types import PropertyFilter, PropertyMatcher, PropertyPath, SerializableData
23+
from syrupy.utils import is_xdist_controller, is_xdist_worker
1624
import voluptuous as vol
1725
import voluptuous_serialize
1826

@@ -246,3 +254,157 @@ def dirname(cls, *, test_location: PyTestLocation) -> str:
246254
"""
247255
test_dir = Path(test_location.filepath).parent
248256
return str(test_dir.joinpath("snapshots"))
257+
258+
259+
# Classes and Methods to override default finish behavior in syrupy
260+
# This is needed to handle the xdist plugin in pytest
261+
# The default implementation does not handle the xdist plugin
262+
# and will not work correctly when running tests in parallel
263+
# with pytest-xdist.
264+
# Temporary workaround until it is finalised inside syrupy
265+
# See https://github.com/syrupy-project/syrupy/pull/901
266+
267+
268+
class _FakePytestObject:
269+
"""Fake object."""
270+
271+
def __init__(self, collected_item: dict[str, str]) -> None:
272+
"""Initialise fake object."""
273+
self.__module__ = collected_item["modulename"]
274+
self.__name__ = collected_item["methodname"]
275+
276+
277+
class _FakePytestItem:
278+
"""Fake pytest.Item object."""
279+
280+
def __init__(self, collected_item: dict[str, str]) -> None:
281+
"""Initialise fake pytest.Item object."""
282+
self.nodeid = collected_item["nodeid"]
283+
self.name = collected_item["name"]
284+
self.path = Path(collected_item["path"])
285+
self.obj = _FakePytestObject(collected_item)
286+
287+
288+
def _serialize_collections(collections: SnapshotCollections) -> dict[str, Any]:
289+
return {
290+
k: [c.name for c in v] for k, v in collections._snapshot_collections.items()
291+
}
292+
293+
294+
def _serialize_report(
295+
report: SnapshotReport,
296+
collected_items: set[pytest.Item],
297+
selected_items: dict[str, ItemStatus],
298+
) -> dict[str, Any]:
299+
return {
300+
"discovered": _serialize_collections(report.discovered),
301+
"created": _serialize_collections(report.created),
302+
"failed": _serialize_collections(report.failed),
303+
"matched": _serialize_collections(report.matched),
304+
"updated": _serialize_collections(report.updated),
305+
"used": _serialize_collections(report.used),
306+
"_collected_items": [
307+
{
308+
"nodeid": c.nodeid,
309+
"name": c.name,
310+
"path": str(c.path),
311+
"modulename": c.obj.__module__,
312+
"methodname": c.obj.__name__,
313+
}
314+
for c in list(collected_items)
315+
],
316+
"_selected_items": {
317+
key: status.value for key, status in selected_items.items()
318+
},
319+
}
320+
321+
322+
def _merge_serialized_collections(
323+
collections: SnapshotCollections, json_data: dict[str, list[str]]
324+
) -> None:
325+
if not json_data:
326+
return
327+
for location, names in json_data.items():
328+
snapshot_collection = SnapshotCollection(location=location)
329+
for name in names:
330+
snapshot_collection.add(Snapshot(name))
331+
collections.update(snapshot_collection)
332+
333+
334+
def _merge_serialized_report(report: SnapshotReport, json_data: dict[str, Any]) -> None:
335+
_merge_serialized_collections(report.discovered, json_data["discovered"])
336+
_merge_serialized_collections(report.created, json_data["created"])
337+
_merge_serialized_collections(report.failed, json_data["failed"])
338+
_merge_serialized_collections(report.matched, json_data["matched"])
339+
_merge_serialized_collections(report.updated, json_data["updated"])
340+
_merge_serialized_collections(report.used, json_data["used"])
341+
for collected_item in json_data["_collected_items"]:
342+
custom_item = _FakePytestItem(collected_item)
343+
if not any(
344+
t.nodeid == custom_item.nodeid and t.name == custom_item.nodeid
345+
for t in report.collected_items
346+
):
347+
report.collected_items.add(custom_item)
348+
for key, selected_item in json_data["_selected_items"].items():
349+
if key in report.selected_items:
350+
status = ItemStatus(selected_item)
351+
if status != ItemStatus.NOT_RUN:
352+
report.selected_items[key] = status
353+
else:
354+
report.selected_items[key] = ItemStatus(selected_item)
355+
356+
357+
def override_syrupy_finish(self: SnapshotSession) -> int:
358+
"""Override the finish method to allow for custom handling."""
359+
exitstatus = 0
360+
self.flush_snapshot_write_queue()
361+
self.report = SnapshotReport(
362+
base_dir=self.pytest_session.config.rootpath,
363+
collected_items=self._collected_items,
364+
selected_items=self._selected_items,
365+
assertions=self._assertions,
366+
options=self.pytest_session.config.option,
367+
)
368+
369+
if is_xdist_worker():
370+
with open(".pytest_syrupy_worker_count", "w", encoding="utf-8") as f:
371+
f.write(os.getenv("PYTEST_XDIST_WORKER_COUNT"))
372+
with open(
373+
f".pytest_syrupy_{os.getenv("PYTEST_XDIST_WORKER")}_result",
374+
"w",
375+
encoding="utf-8",
376+
) as f:
377+
json.dump(
378+
_serialize_report(
379+
self.report, self._collected_items, self._selected_items
380+
),
381+
f,
382+
indent=2,
383+
)
384+
return exitstatus
385+
if is_xdist_controller():
386+
return exitstatus
387+
388+
worker_count = None
389+
try:
390+
with open(".pytest_syrupy_worker_count", encoding="utf-8") as f:
391+
worker_count = f.read()
392+
os.remove(".pytest_syrupy_worker_count")
393+
except FileNotFoundError:
394+
pass
395+
396+
if worker_count:
397+
for i in range(int(worker_count)):
398+
with open(f".pytest_syrupy_gw{i}_result", encoding="utf-8") as f:
399+
_merge_serialized_report(self.report, json.load(f))
400+
os.remove(f".pytest_syrupy_gw{i}_result")
401+
402+
if self.report.num_unused:
403+
if self.update_snapshots:
404+
self.remove_unused_snapshots(
405+
unused_snapshot_collections=self.report.unused,
406+
used_snapshot_collections=self.report.used,
407+
)
408+
elif not self.warn_unused_snapshots:
409+
exitstatus |= EXIT_STATUS_FAIL_UNUSED
410+
return exitstatus

0 commit comments

Comments
 (0)