Skip to content

Commit e5396fd

Browse files
authored
Recursion detection (#31)
1 parent e15e921 commit e5396fd

File tree

4 files changed

+94
-3
lines changed

4 files changed

+94
-3
lines changed

doc/conf.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,8 @@
3737
]
3838

3939
extlinks = {
40-
"issue": ("https://github.com/crusaderky/recursive_diff/issues/%s", "#"),
41-
"pull": ("https://github.com/crusaderky/recursive_diff/pull/%s", "#"),
40+
"issue": ("https://github.com/crusaderky/recursive_diff/issues/%s", "#%s"),
41+
"pull": ("https://github.com/crusaderky/recursive_diff/pull/%s", "#%s"),
4242
}
4343

4444
autosummary_generate = True

doc/whats-new.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@ v1.3.0 (unreleased)
88
-------------------
99
- Test against Python 3.13 and 3.14
1010
- Test against recent Pandas versions (tested up to 3.0 beta)
11-
- Fixed warnings in recent Pandas versions
11+
- Detect and handle recursion in data structures (:issue:`24`)
12+
- Fixed warnings in recent Pandas versions (:issue:`27`)
1213
- Bumped up minimum versions for all dependencies:
1314

1415
========== ====== ========

recursive_diff/recursive_diff.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import math
1010
import re
1111
from collections.abc import Collection, Hashable, Iterator
12+
from contextlib import suppress
1213
from typing import Any, Literal
1314

1415
import numpy as np
@@ -105,6 +106,8 @@ def recursive_diff(
105106
abs_tol=abs_tol,
106107
brief_dims=brief_dims,
107108
path=[],
109+
seen_lhs=[],
110+
seen_rhs=[],
108111
suppress_type_diffs=False,
109112
join="inner",
110113
)
@@ -118,13 +121,19 @@ def _recursive_diff(
118121
abs_tol: float,
119122
brief_dims: Collection[Hashable] | str,
120123
path: list[object],
124+
seen_lhs: list[int],
125+
seen_rhs: list[int],
121126
suppress_type_diffs: bool,
122127
join: Literal["inner", "outer"],
123128
) -> Iterator[str]:
124129
"""Recursive implementation of :func:`recursive_diff`
125130
126131
:param list path:
127132
list of nodes traversed so far, to be prepended to all error messages
133+
param list[int] seen_lhs:
134+
list of id() of all lhs objects traversed so far, to detect cycles
135+
param list[int] seen_rhs:
136+
list of id() of all rhs objects traversed so far, to detect cycles
128137
:param bool suppress_type_diffs:
129138
if True, don't print out messages about differences in type
130139
:param str join:
@@ -144,6 +153,34 @@ def diff(msg: str, print_path: list[object] = path) -> str:
144153
path_prefix += ": "
145154
return path_prefix + msg
146155

156+
# Detect recursion
157+
recursive_lhs = -1
158+
recursive_rhs = -1
159+
with suppress(ValueError):
160+
recursive_lhs = seen_lhs.index(id(lhs))
161+
with suppress(ValueError):
162+
recursive_rhs = seen_rhs.index(id(rhs))
163+
164+
if recursive_lhs >= 0 or recursive_rhs >= 0:
165+
if recursive_lhs != recursive_rhs:
166+
if recursive_lhs == -1:
167+
msg_lhs = "is not recursive"
168+
else:
169+
msg_lhs = f"recurses to {path[: recursive_lhs + 1]}"
170+
if recursive_rhs == -1:
171+
msg_rhs = "is not recursive"
172+
else:
173+
msg_rhs = f"recurses to {path[: recursive_rhs + 1]}"
174+
yield diff(f"LHS {msg_lhs}; RHS {msg_rhs}")
175+
return
176+
177+
# Don't add internalized objects
178+
if not isinstance(lhs, (bool, int, float, type(None), str, bytes)):
179+
seen_lhs = [*seen_lhs, id(lhs)]
180+
if not isinstance(rhs, (bool, int, float, type(None), str, bytes)):
181+
seen_rhs = [*seen_rhs, id(rhs)]
182+
# End of recursion detection
183+
147184
# Build string representation of the two variables *before* casting
148185
lhs_repr = _str_trunc(lhs)
149186
rhs_repr = _str_trunc(rhs)
@@ -215,6 +252,8 @@ def diff(msg: str, print_path: list[object] = path) -> str:
215252
abs_tol=abs_tol,
216253
brief_dims=brief_dims,
217254
path=[*path, i],
255+
seen_lhs=seen_lhs,
256+
seen_rhs=seen_rhs,
218257
suppress_type_diffs=suppress_type_diffs,
219258
join=join,
220259
)
@@ -271,6 +310,8 @@ def diff(msg: str, print_path: list[object] = path) -> str:
271310
abs_tol=abs_tol,
272311
brief_dims=brief_dims,
273312
path=[*path, key],
313+
seen_lhs=seen_lhs,
314+
seen_rhs=seen_rhs,
274315
suppress_type_diffs=suppress_type_diffs,
275316
join=join,
276317
)
@@ -426,6 +467,8 @@ def diff(msg: str, print_path: list[object] = path) -> str:
426467
abs_tol=abs_tol,
427468
brief_dims=brief_dims,
428469
path=path,
470+
seen_lhs=seen_lhs,
471+
seen_rhs=seen_rhs,
429472
suppress_type_diffs=True,
430473
join=join,
431474
)
@@ -441,6 +484,8 @@ def diff(msg: str, print_path: list[object] = path) -> str:
441484
abs_tol=abs_tol,
442485
brief_dims=brief_dims,
443486
path=path,
487+
seen_lhs=seen_lhs,
488+
seen_rhs=seen_rhs,
444489
suppress_type_diffs=True,
445490
join=join,
446491
)

recursive_diff/tests/test_recursive_diff.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -857,3 +857,48 @@ def test_dask(chunk_lhs, chunk_rhs):
857857
rhs = rhs.chunk(chunk_rhs)
858858

859859
check(lhs, rhs, "[data][x=2]: c != d")
860+
861+
862+
def test_recursion():
863+
lhs = []
864+
lhs.append(lhs)
865+
rhs = [1]
866+
check(lhs, rhs, "[0]: LHS recurses to [0]; RHS is not recursive")
867+
check(rhs, lhs, "[0]: LHS is not recursive; RHS recurses to [0]")
868+
869+
rhs = []
870+
rhs.append(rhs)
871+
check(lhs, rhs)
872+
873+
874+
def test_recursion_different_target_different():
875+
lhs = [[1, 2], [3, 4]]
876+
lhs.append(lhs[0])
877+
rhs = [[1, 2], [3, 4]]
878+
rhs.append(rhs[1])
879+
880+
check(
881+
lhs,
882+
rhs,
883+
"[2][0]: 1 != 3 (abs: 2.0e+00, rel: 2.0e+00)",
884+
"[2][1]: 2 != 4 (abs: 2.0e+00, rel: 1.0e+00)",
885+
)
886+
887+
888+
def test_recursion_different_target_identical():
889+
lhs = [[1, 2], [1, 2]]
890+
lhs.append(lhs[0])
891+
rhs = [[1, 2], [1, 2]]
892+
rhs.append(rhs[1])
893+
check(lhs, rhs)
894+
895+
896+
def test_repetition_is_not_recursion():
897+
class C:
898+
def __eq__(self, other):
899+
return isinstance(other, C)
900+
901+
c1 = C()
902+
lhs = [c1, c1]
903+
rhs = [c1, C()]
904+
check(lhs, rhs)

0 commit comments

Comments
 (0)