Skip to content

Commit 404e3bf

Browse files
committed
Override default syrupy finish
1 parent e5a07da commit 404e3bf

File tree

2 files changed

+177
-1
lines changed

2 files changed

+177
-1
lines changed

tests/conftest.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
import requests_mock
3737
import respx
3838
from syrupy.assertion import SnapshotAssertion
39+
from syrupy.session import SnapshotSession
3940

4041
from homeassistant import block_async_io
4142
from homeassistant.exceptions import ServiceNotFound
@@ -92,7 +93,7 @@
9293
from homeassistant.util.json import json_loads
9394

9495
from .ignore_uncaught_exceptions import IGNORE_UNCAUGHT_EXCEPTIONS
95-
from .syrupy import HomeAssistantSnapshotExtension
96+
from .syrupy import HomeAssistantSnapshotExtension, override_syrupy_finish
9697
from .typing import (
9798
ClientSessionGenerator,
9899
MockHAClientWebSocket,
@@ -1887,6 +1888,10 @@ async def _async_call(
18871888
yield calls
18881889

18891890

1891+
# Override default finish to detect unused snapshots despite xdist
1892+
SnapshotSession.finish = override_syrupy_finish
1893+
1894+
18901895
@pytest.fixture
18911896
def snapshot(snapshot: SnapshotAssertion) -> SnapshotAssertion:
18921897
"""Return snapshot assertion fixture with the Home Assistant extension."""

tests/syrupy.py

Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,22 @@
55
from contextlib import suppress
66
import dataclasses
77
from enum import IntFlag
8+
import json
9+
import os
810
from pathlib import Path
911
from typing import Any
1012

1113
import attr
1214
import attrs
15+
import pytest
16+
from syrupy.constants import EXIT_STATUS_FAIL_UNUSED
17+
from syrupy.data import Snapshot, SnapshotCollection, SnapshotCollections
1318
from syrupy.extensions.amber import AmberDataSerializer, AmberSnapshotExtension
1419
from syrupy.location import PyTestLocation
20+
from syrupy.report import SnapshotReport
21+
from syrupy.session import ItemStatus, SnapshotSession
1522
from syrupy.types import PropertyFilter, PropertyMatcher, PropertyPath, SerializableData
23+
from syrupy.utils import is_xdist_controller, is_xdist_worker
1624
import voluptuous as vol
1725
import voluptuous_serialize
1826

@@ -246,3 +254,166 @@ def dirname(cls, *, test_location: PyTestLocation) -> str:
246254
"""
247255
test_dir = Path(test_location.filepath).parent
248256
return str(test_dir.joinpath("snapshots"))
257+
258+
259+
class _customObject:
260+
"""Fake object."""
261+
262+
def __init__(self, collected_item: dict[str, str]) -> None:
263+
"""Initialise fake object."""
264+
self.__module__ = collected_item["modulename"]
265+
self.__name__ = collected_item["methodname"]
266+
267+
268+
class _customItem:
269+
"""Fake pytest.Item object."""
270+
271+
def __init__(self, collected_item: dict[str, str]) -> None:
272+
"""Initialise fake pytest.Item object."""
273+
self.nodeid = collected_item["nodeid"]
274+
self.name = collected_item["name"]
275+
self.path = Path(collected_item["path"])
276+
self.obj = _customObject(collected_item)
277+
278+
279+
def _dump_collections(collections: SnapshotCollections) -> dict[str, Any]:
280+
return {
281+
k: [c.name for c in v] for k, v in collections._snapshot_collections.items()
282+
}
283+
284+
285+
def _dump_report(
286+
report: SnapshotReport,
287+
collected_items: set[pytest.Item],
288+
selected_items: dict[str, ItemStatus],
289+
) -> dict[str, Any]:
290+
return {
291+
"discovered": _dump_collections(report.discovered),
292+
"created": _dump_collections(report.created),
293+
"failed": _dump_collections(report.failed),
294+
"matched": _dump_collections(report.matched),
295+
"updated": _dump_collections(report.updated),
296+
"used": _dump_collections(report.used),
297+
"_collected_items": [
298+
{
299+
"nodeid": c.nodeid,
300+
"name": c.name,
301+
"path": str(c.path),
302+
"modulename": c.obj.__module__,
303+
"methodname": c.obj.__name__,
304+
}
305+
for c in list(collected_items)
306+
],
307+
"_selected_items": {
308+
key: status.value for key, status in selected_items.items()
309+
},
310+
}
311+
312+
313+
def _update_collections(
314+
collections: SnapshotCollections, json_data: dict[str, list[str]]
315+
) -> None:
316+
if not json_data:
317+
return
318+
for location, names in json_data.items():
319+
snapshot_collection = SnapshotCollection(location=location)
320+
for name in names:
321+
snapshot_collection.add(Snapshot(name))
322+
collections.update(snapshot_collection)
323+
324+
325+
def _update_report(report: SnapshotReport, json_data: dict[str, Any]) -> None:
326+
_update_collections(report.discovered, json_data["discovered"])
327+
_update_collections(report.created, json_data["created"])
328+
_update_collections(report.failed, json_data["failed"])
329+
_update_collections(report.matched, json_data["matched"])
330+
_update_collections(report.updated, json_data["updated"])
331+
_update_collections(report.used, json_data["used"])
332+
for collected_item in json_data["_collected_items"]:
333+
custom_item = _customItem(collected_item)
334+
if not any(
335+
t.nodeid == custom_item.nodeid and t.name == custom_item.nodeid
336+
for t in report.collected_items
337+
):
338+
report.collected_items.add(custom_item)
339+
for key, selected_item in json_data["_selected_items"].items():
340+
if key in report.selected_items:
341+
status = ItemStatus(selected_item)
342+
if status != ItemStatus.NOT_RUN:
343+
report.selected_items[key] = status
344+
else:
345+
report.selected_items[key] = ItemStatus(selected_item)
346+
347+
348+
def override_syrupy_finish(self: SnapshotSession) -> int:
349+
"""Override the finish method to allow for custom handling."""
350+
exitstatus = 0
351+
self.flush_snapshot_write_queue()
352+
self.report = SnapshotReport(
353+
base_dir=self.pytest_session.config.rootpath,
354+
collected_items=self._collected_items,
355+
selected_items=self._selected_items,
356+
assertions=self._assertions,
357+
options=self.pytest_session.config.option,
358+
)
359+
360+
if is_xdist_worker():
361+
xdist_worker = os.getenv("PYTEST_XDIST_WORKER")
362+
xdist_worker_count = os.getenv("PYTEST_XDIST_WORKER_COUNT")
363+
dump = _dump_report(self.report, self._collected_items, self._selected_items)
364+
# {
365+
# "_collected_items": [{"nodeid": c.nodeid, "name": c.name} for c in list(self._collected_items)],
366+
# "_selected_items": {key: status.value for key, status in self._selected_items.items()},
367+
# "_assertions": [{
368+
# "location": assertion.test_location.filepath,
369+
# "name": assertion.name,
370+
# "num_executions":assertion.num_executions,
371+
# } for assertion in self._assertions]
372+
# "_collected_items": self._dump(self.report[{"nodeid": c.nodeid, "name": c.name} for c in list(self._collected_items)],
373+
# }
374+
with open(
375+
"/workspaces/home-assistant-core/PYTEST_XDIST_WORKER_COUNT.txt",
376+
"w",
377+
encoding="utf-8",
378+
) as f:
379+
f.write(xdist_worker_count)
380+
with open(
381+
f"/workspaces/home-assistant-core/xdist_{xdist_worker}.txt",
382+
"w",
383+
encoding="utf-8",
384+
) as f:
385+
json.dump(dump, f, indent=2)
386+
return exitstatus
387+
if is_xdist_controller():
388+
return exitstatus
389+
390+
worker_count = None
391+
try:
392+
with open(
393+
"/workspaces/home-assistant-core/PYTEST_XDIST_WORKER_COUNT.txt",
394+
encoding="utf-8",
395+
) as f:
396+
worker_count = f.read()
397+
os.remove("/workspaces/home-assistant-core/PYTEST_XDIST_WORKER_COUNT.txt")
398+
except FileNotFoundError:
399+
pass
400+
401+
if worker_count:
402+
for i in range(int(worker_count)):
403+
with open(
404+
f"/workspaces/home-assistant-core/xdist_gw{i}.txt",
405+
encoding="utf-8",
406+
) as f:
407+
json_data = json.load(f)
408+
_update_report(self.report, json_data)
409+
os.remove(f"/workspaces/home-assistant-core/xdist_gw{i}.txt")
410+
411+
if self.report.num_unused:
412+
if self.update_snapshots:
413+
self.remove_unused_snapshots(
414+
unused_snapshot_collections=self.report.unused,
415+
used_snapshot_collections=self.report.used,
416+
)
417+
elif not self.warn_unused_snapshots:
418+
exitstatus |= EXIT_STATUS_FAIL_UNUSED
419+
return exitstatus

0 commit comments

Comments
 (0)