| 
5 | 5 | from contextlib import suppress  | 
6 | 6 | import dataclasses  | 
7 | 7 | from enum import IntFlag  | 
8 |  | -import json  | 
9 |  | -import os  | 
10 | 8 | from pathlib import Path  | 
11 | 9 | from typing import Any  | 
12 | 10 | 
 
  | 
13 | 11 | import attr  | 
14 | 12 | import attrs  | 
15 |  | -import pytest  | 
16 |  | -from syrupy.constants import EXIT_STATUS_FAIL_UNUSED  | 
17 |  | -from syrupy.data import Snapshot, SnapshotCollection, SnapshotCollections  | 
18 | 13 | from syrupy.extensions.amber import AmberDataSerializer, AmberSnapshotExtension  | 
19 | 14 | from syrupy.location import PyTestLocation  | 
20 |  | -from syrupy.report import SnapshotReport  | 
21 |  | -from syrupy.session import ItemStatus, SnapshotSession  | 
22 | 15 | from syrupy.types import PropertyFilter, PropertyMatcher, PropertyPath, SerializableData  | 
23 |  | -from syrupy.utils import is_xdist_controller, is_xdist_worker  | 
24 | 16 | import voluptuous as vol  | 
25 | 17 | import voluptuous_serialize  | 
26 | 18 | 
 
  | 
@@ -254,157 +246,3 @@ def dirname(cls, *, test_location: PyTestLocation) -> str:  | 
254 | 246 |         """  | 
255 | 247 |         test_dir = Path(test_location.filepath).parent  | 
256 | 248 |         return str(test_dir.joinpath("snapshots"))  | 
257 |  | - | 
258 |  | - | 
259 |  | -# Classes and Methods to override default finish behavior in syrupy  | 
260 |  | -# This is needed to handle the xdist plugin in pytest  | 
261 |  | -# The default implementation does not handle the xdist plugin  | 
262 |  | -# and will not work correctly when running tests in parallel  | 
263 |  | -# with pytest-xdist.  | 
264 |  | -# Temporary workaround until it is finalised inside syrupy  | 
265 |  | -# See https://github.com/syrupy-project/syrupy/pull/901  | 
266 |  | - | 
267 |  | - | 
268 |  | -class _FakePytestObject:  | 
269 |  | -    """Fake object."""  | 
270 |  | - | 
271 |  | -    def __init__(self, collected_item: dict[str, str]) -> None:  | 
272 |  | -        """Initialise fake object."""  | 
273 |  | -        self.__module__ = collected_item["modulename"]  | 
274 |  | -        self.__name__ = collected_item["methodname"]  | 
275 |  | - | 
276 |  | - | 
277 |  | -class _FakePytestItem:  | 
278 |  | -    """Fake pytest.Item object."""  | 
279 |  | - | 
280 |  | -    def __init__(self, collected_item: dict[str, str]) -> None:  | 
281 |  | -        """Initialise fake pytest.Item object."""  | 
282 |  | -        self.nodeid = collected_item["nodeid"]  | 
283 |  | -        self.name = collected_item["name"]  | 
284 |  | -        self.path = Path(collected_item["path"])  | 
285 |  | -        self.obj = _FakePytestObject(collected_item)  | 
286 |  | - | 
287 |  | - | 
288 |  | -def _serialize_collections(collections: SnapshotCollections) -> dict[str, Any]:  | 
289 |  | -    return {  | 
290 |  | -        k: [c.name for c in v] for k, v in collections._snapshot_collections.items()  | 
291 |  | -    }  | 
292 |  | - | 
293 |  | - | 
294 |  | -def _serialize_report(  | 
295 |  | -    report: SnapshotReport,  | 
296 |  | -    collected_items: set[pytest.Item],  | 
297 |  | -    selected_items: dict[str, ItemStatus],  | 
298 |  | -) -> dict[str, Any]:  | 
299 |  | -    return {  | 
300 |  | -        "discovered": _serialize_collections(report.discovered),  | 
301 |  | -        "created": _serialize_collections(report.created),  | 
302 |  | -        "failed": _serialize_collections(report.failed),  | 
303 |  | -        "matched": _serialize_collections(report.matched),  | 
304 |  | -        "updated": _serialize_collections(report.updated),  | 
305 |  | -        "used": _serialize_collections(report.used),  | 
306 |  | -        "_collected_items": [  | 
307 |  | -            {  | 
308 |  | -                "nodeid": c.nodeid,  | 
309 |  | -                "name": c.name,  | 
310 |  | -                "path": str(c.path),  | 
311 |  | -                "modulename": c.obj.__module__,  | 
312 |  | -                "methodname": c.obj.__name__,  | 
313 |  | -            }  | 
314 |  | -            for c in list(collected_items)  | 
315 |  | -        ],  | 
316 |  | -        "_selected_items": {  | 
317 |  | -            key: status.value for key, status in selected_items.items()  | 
318 |  | -        },  | 
319 |  | -    }  | 
320 |  | - | 
321 |  | - | 
322 |  | -def _merge_serialized_collections(  | 
323 |  | -    collections: SnapshotCollections, json_data: dict[str, list[str]]  | 
324 |  | -) -> None:  | 
325 |  | -    if not json_data:  | 
326 |  | -        return  | 
327 |  | -    for location, names in json_data.items():  | 
328 |  | -        snapshot_collection = SnapshotCollection(location=location)  | 
329 |  | -        for name in names:  | 
330 |  | -            snapshot_collection.add(Snapshot(name))  | 
331 |  | -        collections.update(snapshot_collection)  | 
332 |  | - | 
333 |  | - | 
334 |  | -def _merge_serialized_report(report: SnapshotReport, json_data: dict[str, Any]) -> None:  | 
335 |  | -    _merge_serialized_collections(report.discovered, json_data["discovered"])  | 
336 |  | -    _merge_serialized_collections(report.created, json_data["created"])  | 
337 |  | -    _merge_serialized_collections(report.failed, json_data["failed"])  | 
338 |  | -    _merge_serialized_collections(report.matched, json_data["matched"])  | 
339 |  | -    _merge_serialized_collections(report.updated, json_data["updated"])  | 
340 |  | -    _merge_serialized_collections(report.used, json_data["used"])  | 
341 |  | -    for collected_item in json_data["_collected_items"]:  | 
342 |  | -        custom_item = _FakePytestItem(collected_item)  | 
343 |  | -        if not any(  | 
344 |  | -            t.nodeid == custom_item.nodeid and t.name == custom_item.nodeid  | 
345 |  | -            for t in report.collected_items  | 
346 |  | -        ):  | 
347 |  | -            report.collected_items.add(custom_item)  | 
348 |  | -    for key, selected_item in json_data["_selected_items"].items():  | 
349 |  | -        if key in report.selected_items:  | 
350 |  | -            status = ItemStatus(selected_item)  | 
351 |  | -            if status != ItemStatus.NOT_RUN:  | 
352 |  | -                report.selected_items[key] = status  | 
353 |  | -        else:  | 
354 |  | -            report.selected_items[key] = ItemStatus(selected_item)  | 
355 |  | - | 
356 |  | - | 
357 |  | -def override_syrupy_finish(self: SnapshotSession) -> int:  | 
358 |  | -    """Override the finish method to allow for custom handling."""  | 
359 |  | -    exitstatus = 0  | 
360 |  | -    self.flush_snapshot_write_queue()  | 
361 |  | -    self.report = SnapshotReport(  | 
362 |  | -        base_dir=self.pytest_session.config.rootpath,  | 
363 |  | -        collected_items=self._collected_items,  | 
364 |  | -        selected_items=self._selected_items,  | 
365 |  | -        assertions=self._assertions,  | 
366 |  | -        options=self.pytest_session.config.option,  | 
367 |  | -    )  | 
368 |  | - | 
369 |  | -    if is_xdist_worker():  | 
370 |  | -        with open(".pytest_syrupy_worker_count", "w", encoding="utf-8") as f:  | 
371 |  | -            f.write(os.getenv("PYTEST_XDIST_WORKER_COUNT"))  | 
372 |  | -        with open(  | 
373 |  | -            f".pytest_syrupy_{os.getenv("PYTEST_XDIST_WORKER")}_result",  | 
374 |  | -            "w",  | 
375 |  | -            encoding="utf-8",  | 
376 |  | -        ) as f:  | 
377 |  | -            json.dump(  | 
378 |  | -                _serialize_report(  | 
379 |  | -                    self.report, self._collected_items, self._selected_items  | 
380 |  | -                ),  | 
381 |  | -                f,  | 
382 |  | -                indent=2,  | 
383 |  | -            )  | 
384 |  | -        return exitstatus  | 
385 |  | -    if is_xdist_controller():  | 
386 |  | -        return exitstatus  | 
387 |  | - | 
388 |  | -    worker_count = None  | 
389 |  | -    try:  | 
390 |  | -        with open(".pytest_syrupy_worker_count", encoding="utf-8") as f:  | 
391 |  | -            worker_count = f.read()  | 
392 |  | -        os.remove(".pytest_syrupy_worker_count")  | 
393 |  | -    except FileNotFoundError:  | 
394 |  | -        pass  | 
395 |  | - | 
396 |  | -    if worker_count:  | 
397 |  | -        for i in range(int(worker_count)):  | 
398 |  | -            with open(f".pytest_syrupy_gw{i}_result", encoding="utf-8") as f:  | 
399 |  | -                _merge_serialized_report(self.report, json.load(f))  | 
400 |  | -            os.remove(f".pytest_syrupy_gw{i}_result")  | 
401 |  | - | 
402 |  | -    if self.report.num_unused:  | 
403 |  | -        if self.update_snapshots:  | 
404 |  | -            self.remove_unused_snapshots(  | 
405 |  | -                unused_snapshot_collections=self.report.unused,  | 
406 |  | -                used_snapshot_collections=self.report.used,  | 
407 |  | -            )  | 
408 |  | -        elif not self.warn_unused_snapshots:  | 
409 |  | -            exitstatus |= EXIT_STATUS_FAIL_UNUSED  | 
410 |  | -    return exitstatus  | 
0 commit comments