Skip to content

Commit 87ac644

Browse files
committed
[core] weak reference observers to prevent leaks
1 parent 36b2d15 commit 87ac644

File tree

4 files changed

+16
-11
lines changed

4 files changed

+16
-11
lines changed

smcpp/_smcpp.pyx

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,8 @@ cdef class _PyInferenceManager:
123123
cdef InferenceManager* _im
124124
cdef vector[int*] _obs_ptrs
125125

126+
cdef object __weakref__
127+
126128
def __my_cinit__(self, observations, hidden_states, im_id=None):
127129
self._im_id = im_id
128130
self.seed = 1

smcpp/analysis.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def _init_penalty(self):
5959
def _init_optimizer(self, outdir, algorithm, xtol, ftol):
6060
self._optimizer = self._OPTIMIZER_CLS(self, algorithm, xtol, ftol)
6161
if outdir:
62-
self._optimizer.register(analysis_saver.AnalysisSaver(outdir))
62+
self._optimizer.register_plugin(analysis_saver.AnalysisSaver(outdir))
6363

6464
def rescale(self, x):
6565
return x / (2. * self._N0)
@@ -361,7 +361,7 @@ def _init_optimizer(self, outdir, algorithm, xtol, ftol, learn_rho=False):
361361
super()._init_optimizer(outdir, algorithm, xtol, ftol)
362362
if learn_rho:
363363
rho_bounds = 2. * self._N0 * np.array([1e-10, 1e-5])
364-
self._optimizer.register(
364+
self._optimizer.register_plugin(
365365
parameter_optimizer.ParameterOptimizer("rho", tuple(rho_bounds)))
366366

367367

@@ -391,7 +391,7 @@ def _validate_data(self):
391391

392392
def _init_optimizer(self, outdir, algorithm, xtol, ftol):
393393
super()._init_optimizer(outdir, algorithm, xtol, ftol)
394-
self._optimizer.register(parameter_optimizer.ParameterOptimizer("split",
394+
self._optimizer.register_plugin(parameter_optimizer.ParameterOptimizer("split",
395395
(0., self._max_split),
396396
"model"))
397397

smcpp/observe.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import absolute_import
22
from abc import ABCMeta, abstractmethod
33
import wrapt
4+
import weakref
45

56

67
# Decorator to target specific messages.
@@ -27,19 +28,16 @@ def update(self, *args, **kwargs):
2728

2829
class Observable(object):
2930
def __init__(self):
30-
self.observers = []
31+
self.observers = weakref.WeakSet()
3132

3233
def register(self, observer):
33-
if not observer in self.observers:
34-
self.observers.append(observer)
34+
self.observers.add(observer)
3535

3636
def unregister(self, observer):
37-
if observer in self.observers:
38-
self.observers.remove(observer)
37+
self.observers.discard(observer)
3938

4039
def unregister_all(self):
41-
if self.observers:
42-
del self.observers[:]
40+
self.observers.clear()
4341

4442
def update_observers(self, *args, **kwargs):
4543
for observer in self.observers:

smcpp/optimize/optimizers.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ class AbstractOptimizer(Observable):
2121
'''
2222
def __init__(self, analysis, algorithm, xtol, ftol):
2323
Observable.__init__(self)
24+
self._plugins = []
2425
self._analysis = analysis
2526
self._algorithm = algorithm
2627
self._ftol = ftol
@@ -174,6 +175,10 @@ def _callback(self, xk):
174175
if self._delta < self._xtol:
175176
raise ConvergedException("delta=%f < xtol=%f" % (self._delta, self._xtol))
176177

178+
def register_plugin(self, p):
179+
self._plugins.append(p)
180+
self.register(p)
181+
177182
def update_observers(self, *args, **kwargs):
178183
kwargs.update({
179184
'optimizer': self,
@@ -190,7 +195,7 @@ def __init__(self, analysis, algorithm, xtol, ftol):
190195
for cls in OptimizerPlugin.__subclasses__():
191196
try:
192197
if not cls.DISABLED:
193-
self.register(cls())
198+
self.register_plugin(cls())
194199
except TypeError:
195200
# Only register listeners with null constructor
196201
pass

0 commit comments

Comments
 (0)