Skip to content

Commit 0bb5211

Browse files
ferrinetwiecki
authored andcommitted
Make Approximation serializeable
* fix #2615 * better solution * fix recursion error
1 parent c438309 commit 0bb5211

File tree

3 files changed

+33
-1
lines changed

3 files changed

+33
-1
lines changed

pymc3/tests/test_variational_inference.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,11 @@ def three_var_approx(three_var_model, three_var_groups):
142142
return approx
143143

144144

145+
@pytest.fixture
146+
def three_var_approx_single_group_mf(three_var_model):
147+
return MeanField(model=three_var_model)
148+
149+
145150
def test_sample_simple(three_var_approx):
146151
trace = three_var_approx.sample(500)
147152
assert set(trace.varnames) == {'one', 'one_log__', 'three', 'two'}
@@ -701,6 +706,27 @@ def test_rowwise_approx(three_var_model, parametric_grouped_approxes):
701706
pytest.skip('Does not support rowwise grouping')
702707

703708

709+
def test_pickle_approx(three_var_approx):
710+
import pickle
711+
dump = pickle.dumps(three_var_approx)
712+
new = pickle.loads(dump)
713+
assert new.sample(1)
714+
715+
716+
def test_pickle_single_group(three_var_approx_single_group_mf):
717+
import pickle
718+
dump = pickle.dumps(three_var_approx_single_group_mf)
719+
new = pickle.loads(dump)
720+
assert new.sample(1)
721+
722+
723+
def test_pickle_approx_aevb(three_var_aevb_approx):
724+
import pickle
725+
dump = pickle.dumps(three_var_aevb_approx)
726+
new = pickle.loads(dump)
727+
assert new.sample(1000)
728+
729+
704730
@pytest.fixture('module')
705731
def binomial_model():
706732
n_samples = 100

pymc3/variational/approximations.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -539,6 +539,12 @@ def __init__(self, *args, **kwargs):
539539
def __getattr__(self, item):
540540
return getattr(self.groups[0], item)
541541

542+
def __getstate__(self):
543+
return self.__dict__.copy()
544+
545+
def __setstate__(self, state):
546+
self.__dict__.update(state)
547+
542548

543549
class MeanField(SingleGroupApproximation):
544550
__doc__ = """**Single Group Mean Field Approximation**

pymc3/variational/opvi.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -721,7 +721,7 @@ def group_for_short_name(cls, name):
721721
.format(name, cls.__name_registry))
722722
return cls.__name_registry[name.lower()]
723723

724-
def __new__(cls, group, vfam=None, params=None, *args, **kwargs):
724+
def __new__(cls, group=None, vfam=None, params=None, *args, **kwargs):
725725
if cls is Group:
726726
if vfam is not None and params is not None:
727727
raise TypeError('Cannot call Group with both `vfam` and `params` provided')

0 commit comments

Comments
 (0)