|
4 | 4 | """Fixtures for wall clock timer tests.""" |
5 | 5 |
|
6 | 6 | import asyncio |
7 | | -from collections.abc import Iterator |
8 | | -from datetime import datetime |
| 7 | +from collections.abc import Callable, Iterator, Sequence |
| 8 | +from datetime import datetime, timedelta |
9 | 9 | from unittest.mock import AsyncMock, MagicMock, patch |
10 | 10 |
|
11 | 11 | import pytest |
12 | 12 | from frequenz.core.datetime import UNIX_EPOCH |
13 | 13 |
|
| 14 | +from frequenz.sdk.timeseries._resampling._wall_clock_timer import ClocksInfo, TickInfo |
14 | 15 |
|
15 | 16 | # Some of the utils do assertions and we want them to be rewritten by pytest for better |
16 | 17 | # error messages |
@@ -49,3 +50,106 @@ async def time_driver(datetime_mock: MagicMock) -> TimeDriver: |
49 | 50 | return TimeDriver( |
50 | 51 | datetime_mock=datetime_mock, |
51 | 52 | ) |
| 53 | + |
| 54 | + |
| 55 | +def pytest_assertrepr_compare(op: str, left: object, right: object) -> list[str] | None: |
| 56 | + """Provide custom, readable error reports for TickInfo comparisons.""" |
| 57 | + # We only care about == comparisons involving our TickInfo objects, returning None |
| 58 | + # makes pytest fall back to its default comparison behavior. |
| 59 | + if op != "==" or not isinstance(left, TickInfo) or not isinstance(right, TickInfo): |
| 60 | + return None |
| 61 | + |
| 62 | + # Helper function to format values for readability |
| 63 | + def format_val(val: object) -> str: |
| 64 | + # For our time-based types, use str() for readability instead of repr() |
| 65 | + if isinstance(val, (datetime, timedelta)): |
| 66 | + return str(val) |
| 67 | + # For our approx objects and others, the default repr is already good. |
| 68 | + return repr(val) |
| 69 | + |
| 70 | + errors = _compare_tick_info_objects(left, right, format_val) |
| 71 | + # If the comparison was actually successful (no errors), let pytest handle it |
| 72 | + if not errors: |
| 73 | + return None |
| 74 | + |
| 75 | + # Format the final error message |
| 76 | + report = ["Comparing TickInfo objects:"] |
| 77 | + report.append(" Differing attributes:") |
| 78 | + report.append(f" {list(errors.keys())!r}") |
| 79 | + report.append("") |
| 80 | + |
| 81 | + for field, diff in errors.items(): |
| 82 | + report.append(f" Drill down into differing attribute '{field}':") |
| 83 | + # The diff can be a simple tuple of (left, right) values or a list of |
| 84 | + # strings for nested diffs |
| 85 | + match diff: |
| 86 | + case list(): |
| 87 | + report.extend(f" {line}" for line in diff) |
| 88 | + case (left_val, right_val): |
| 89 | + report.append(f" - {format_val(left_val)}") |
| 90 | + report.append(f" + {format_val(right_val)}") |
| 91 | + case _: |
| 92 | + assert False, f"Unexpected diff type: {type(diff)}" |
| 93 | + |
| 94 | + return report |
| 95 | + |
| 96 | + |
| 97 | +# We are dealing these fields dynamically, so we make a sanity check here to make sure |
| 98 | +# if something changes, we can catch it early instead of getting some cryptic errors |
| 99 | +# later, deep in the code. |
| 100 | +assert set(TickInfo.__dataclass_fields__.keys()) == { |
| 101 | + "expected_tick_time", |
| 102 | + "sleep_infos", |
| 103 | +} |
| 104 | + |
| 105 | + |
| 106 | +def _compare_tick_info_objects( |
| 107 | + left: TickInfo, right: TickInfo, format_val: Callable[[object], str] |
| 108 | +) -> dict[str, object]: |
| 109 | + """Compare two TickInfo objects and return a dictionary of differences.""" |
| 110 | + errors: dict[str, object] = {} |
| 111 | + |
| 112 | + # 1. Compare top-level fields |
| 113 | + if left.expected_tick_time != right.expected_tick_time: |
| 114 | + errors["expected_tick_time"] = ( |
| 115 | + left.expected_tick_time, |
| 116 | + right.expected_tick_time, |
| 117 | + ) |
| 118 | + |
| 119 | + # 2. Compare the list of ClocksInfo objects |
| 120 | + sleeps_diff = _compare_sleep_infos_list( |
| 121 | + left.sleep_infos, right.sleep_infos, format_val |
| 122 | + ) |
| 123 | + if sleeps_diff: |
| 124 | + errors["sleep_infos"] = sleeps_diff |
| 125 | + |
| 126 | + return errors |
| 127 | + |
| 128 | + |
| 129 | +def _compare_sleep_infos_list( |
| 130 | + left: Sequence[ClocksInfo], |
| 131 | + right: Sequence[ClocksInfo], |
| 132 | + format_val: Callable[[object], str], |
| 133 | +) -> list[str]: |
| 134 | + """Compare two lists of ClocksInfo objects and return a list of error strings.""" |
| 135 | + if len(left) != len(right): |
| 136 | + return [ |
| 137 | + f"List lengths differ: {len(left)} != {len(right)}", |
| 138 | + f" {left!r}", |
| 139 | + " !=", |
| 140 | + f" {right!r}", |
| 141 | + ] |
| 142 | + |
| 143 | + diffs: list[str] = [] |
| 144 | + for i, (l_clock, r_clock) in enumerate(zip(left, right)): |
| 145 | + if l_clock != r_clock: |
| 146 | + diffs.append(f"Item at index [{i}] differs:") |
| 147 | + # Get detailed diffs for the fields inside the ClocksInfo object |
| 148 | + for field in l_clock.__dataclass_fields__: |
| 149 | + l_val = getattr(l_clock, field) |
| 150 | + r_val = getattr(r_clock, field) |
| 151 | + if l_val != r_val: |
| 152 | + diffs.append(f" Attribute '{field}':") |
| 153 | + diffs.append(f" - {format_val(l_val)}") |
| 154 | + diffs.append(f" + {format_val(r_val)}") |
| 155 | + return diffs |
0 commit comments