Skip to content

Commit c58f10b

Browse files
authored
Fix rewrite weakref leak (#1660)
Calling `lru_cache` on instance methods causes a leak. Fixed as suggested in https://rednafi.com/python/lru-cache-on-methods/
1 parent a99feb5 commit c58f10b

File tree

2 files changed

+48
-2
lines changed

2 files changed

+48
-2
lines changed

pytensor/graph/rewriting/basic.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1073,13 +1073,15 @@ def __init__(self):
10731073
defaultdict(lambda: defaultdict(list))
10741074
)
10751075
self.untracked_rewrites: list[NodeRewriter] = []
1076+
self.get_trackers = functools.cache(self._get_trackers)
10761077
self._cached_composed_mro = None
10771078

10781079
def add_tracker(self, rw: NodeRewriter):
10791080
"""Add a `NodeRewriter` to be keyed by its `NodeRewriter.tracks` or applied generally."""
10801081
if self._cached_composed_mro is not None:
10811082
# We shouldn't actually add_trackers after the first call to get_trackers
10821083
# But just to be safe we kill the cache here
1084+
self.get_trackers = functools.cache(self._get_trackers)
10831085
self._cached_composed_mro = None
10841086

10851087
tracks = rw.tracks()
@@ -1107,8 +1109,7 @@ def add_tracker(self, rw: NodeRewriter):
11071109
else:
11081110
self.tracked_instances[c].append(rw)
11091111

1110-
@functools.cache
1111-
def get_trackers(self, op: Op) -> list[NodeRewriter]:
1112+
def _get_trackers(self, op: Op) -> list[NodeRewriter]:
11121113
"""Get all the rewrites applicable to an `Op`."""
11131114

11141115
if self._cached_composed_mro is None:

tests/graph/rewriting/test_basic.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
1+
import gc
2+
import operator
3+
14
import pytest
25

36
from pytensor.configdefaults import config
7+
from pytensor.graph import rewrite_graph
48
from pytensor.graph.basic import Apply, Constant, equal_computations
59
from pytensor.graph.features import Feature
610
from pytensor.graph.fg import FunctionGraph
@@ -930,3 +934,44 @@ def perform(self, *args):
930934
local_rewriter_2,
931935
local_rewriter_1,
932936
]
937+
938+
939+
def test_rewrite_weakref_leak():
940+
"""Check we don't have weakref leak on our rewrites"""
941+
942+
def _growth(limit=10, peak_stats={}):
943+
"""Vendoring of objgraph.growth
944+
945+
Source: https://github.com/mgedmin/objgraph/blob/94b1ca61a11109547442701800292dcfc7f59fc8/objgraph.py#L253
946+
"""
947+
gc.collect()
948+
objects = gc.get_objects()
949+
950+
stats = {}
951+
for o in objects:
952+
n = type(o).__name__
953+
stats[n] = stats.get(n, 0) + 1
954+
955+
deltas = {}
956+
for name, count in stats.items():
957+
old_count = peak_stats.get(name, 0)
958+
if count > old_count:
959+
deltas[name] = count - old_count
960+
peak_stats[name] = count
961+
962+
deltas = sorted(deltas.items(), key=operator.itemgetter(1), reverse=True)
963+
964+
if limit:
965+
deltas = deltas[:limit]
966+
967+
return [(name, stats[name], delta) for name, delta in deltas]
968+
969+
x = vector("x")
970+
y = exp(x)
971+
972+
for i in range(20):
973+
rewrite_graph(y, clone=False)
974+
res = _growth()
975+
# Only start checking after warmup
976+
if i > 15:
977+
assert not res, "Object counts are still growing"

0 commit comments

Comments
 (0)