99import math
1010import re
1111from collections .abc import Collection , Hashable , Iterator
12+ from contextlib import suppress
1213from typing import Any , Literal
1314
1415import numpy as np
@@ -105,6 +106,8 @@ def recursive_diff(
105106 abs_tol = abs_tol ,
106107 brief_dims = brief_dims ,
107108 path = [],
109+ seen_lhs = [],
110+ seen_rhs = [],
108111 suppress_type_diffs = False ,
109112 join = "inner" ,
110113 )
@@ -118,13 +121,19 @@ def _recursive_diff(
118121 abs_tol : float ,
119122 brief_dims : Collection [Hashable ] | str ,
120123 path : list [object ],
124+ seen_lhs : list [int ],
125+ seen_rhs : list [int ],
121126 suppress_type_diffs : bool ,
122127 join : Literal ["inner" , "outer" ],
123128) -> Iterator [str ]:
124129 """Recursive implementation of :func:`recursive_diff`
125130
126131 :param list path:
127132 list of nodes traversed so far, to be prepended to all error messages
133+ param list[int] seen_lhs:
134+ list of id() of all lhs objects traversed so far, to detect cycles
135+ param list[int] seen_rhs:
136+ list of id() of all rhs objects traversed so far, to detect cycles
128137 :param bool suppress_type_diffs:
129138 if True, don't print out messages about differences in type
130139 :param str join:
@@ -144,6 +153,34 @@ def diff(msg: str, print_path: list[object] = path) -> str:
144153 path_prefix += ": "
145154 return path_prefix + msg
146155
156+ # Detect recursion
157+ recursive_lhs = - 1
158+ recursive_rhs = - 1
159+ with suppress (ValueError ):
160+ recursive_lhs = seen_lhs .index (id (lhs ))
161+ with suppress (ValueError ):
162+ recursive_rhs = seen_rhs .index (id (rhs ))
163+
164+ if recursive_lhs >= 0 or recursive_rhs >= 0 :
165+ if recursive_lhs != recursive_rhs :
166+ if recursive_lhs == - 1 :
167+ msg_lhs = "is not recursive"
168+ else :
169+ msg_lhs = f"recurses to { path [: recursive_lhs + 1 ]} "
170+ if recursive_rhs == - 1 :
171+ msg_rhs = "is not recursive"
172+ else :
173+ msg_rhs = f"recurses to { path [: recursive_rhs + 1 ]} "
174+ yield diff (f"LHS { msg_lhs } ; RHS { msg_rhs } " )
175+ return
176+
177+ # Don't add internalized objects
178+ if not isinstance (lhs , (bool , int , float , type (None ), str , bytes )):
179+ seen_lhs = [* seen_lhs , id (lhs )]
180+ if not isinstance (rhs , (bool , int , float , type (None ), str , bytes )):
181+ seen_rhs = [* seen_rhs , id (rhs )]
182+ # End of recursion detection
183+
147184 # Build string representation of the two variables *before* casting
148185 lhs_repr = _str_trunc (lhs )
149186 rhs_repr = _str_trunc (rhs )
@@ -215,6 +252,8 @@ def diff(msg: str, print_path: list[object] = path) -> str:
215252 abs_tol = abs_tol ,
216253 brief_dims = brief_dims ,
217254 path = [* path , i ],
255+ seen_lhs = seen_lhs ,
256+ seen_rhs = seen_rhs ,
218257 suppress_type_diffs = suppress_type_diffs ,
219258 join = join ,
220259 )
@@ -271,6 +310,8 @@ def diff(msg: str, print_path: list[object] = path) -> str:
271310 abs_tol = abs_tol ,
272311 brief_dims = brief_dims ,
273312 path = [* path , key ],
313+ seen_lhs = seen_lhs ,
314+ seen_rhs = seen_rhs ,
274315 suppress_type_diffs = suppress_type_diffs ,
275316 join = join ,
276317 )
@@ -426,6 +467,8 @@ def diff(msg: str, print_path: list[object] = path) -> str:
426467 abs_tol = abs_tol ,
427468 brief_dims = brief_dims ,
428469 path = path ,
470+ seen_lhs = seen_lhs ,
471+ seen_rhs = seen_rhs ,
429472 suppress_type_diffs = True ,
430473 join = join ,
431474 )
@@ -441,6 +484,8 @@ def diff(msg: str, print_path: list[object] = path) -> str:
441484 abs_tol = abs_tol ,
442485 brief_dims = brief_dims ,
443486 path = path ,
487+ seen_lhs = seen_lhs ,
488+ seen_rhs = seen_rhs ,
444489 suppress_type_diffs = True ,
445490 join = join ,
446491 )
0 commit comments