Skip to content

Commit d6c872f

Browse files
committed
move get transformed to util
1 parent da14207 commit d6c872f

File tree

3 files changed

+17
-15
lines changed

3 files changed

+17
-15
lines changed

pymc3/tests/test_variational_inference.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77

88
import pymc3 as pm
9+
import pymc3.util
910
from pymc3.theanof import change_flags
1011
from pymc3.variational.approximations import (
1112
MeanFieldGroup, FullRankGroup,
@@ -16,7 +17,7 @@
1617
ADVI, FullRankADVI, SVGD, NFVI, ASVGD,
1718
fit
1819
)
19-
from pymc3.variational import flows, opvi
20+
from pymc3.variational import flows
2021
from pymc3.variational.opvi import Approximation, Group
2122

2223
from . import models
@@ -109,7 +110,7 @@ def test_init_groups(three_var_model, raises, grouping):
109110
if g is None:
110111
pass
111112
else:
112-
assert set(g) == set(ig.group)
113+
assert set(pm.util.get_transformed(z) for z in g) == set(ig.group)
113114
else:
114115
assert approx.ndim == three_var_model.ndim
115116

@@ -143,7 +144,7 @@ def three_var_approx(three_var_model, three_var_groups):
143144

144145
def test_sample_simple(three_var_approx):
145146
trace = three_var_approx.sample(500)
146-
assert set(trace.varnames) == {'one', 'two', 'three'}
147+
assert set(trace.varnames) == {'one', 'one_log__', 'three', 'two'}
147148
assert len(trace) == 500
148149
assert trace[0]['one'].shape == (10, 2)
149150
assert trace[0]['two'].shape == (10, )
@@ -174,7 +175,7 @@ def parametric_grouped_approxes(request):
174175

175176
@pytest.fixture
176177
def three_var_aevb_groups(parametric_grouped_approxes, three_var_model, aevb_initial):
177-
dsize = np.prod(opvi.get_transformed(three_var_model.one).dshape[1:])
178+
dsize = np.prod(pymc3.util.get_transformed(three_var_model.one).dshape[1:])
178179
cls, kw = parametric_grouped_approxes
179180
spec = cls.get_param_spec_for(d=dsize, **kw)
180181
params = dict()

pymc3/util.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,3 +144,8 @@ def update_start_vals(a, b, model):
144144

145145
a.update({k: v for k, v in b.items() if k not in a})
146146

147+
148+
def get_transformed(z):
149+
if hasattr(z, 'transformed'):
150+
z = z.transformed
151+
return z

pymc3/variational/opvi.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -31,21 +31,23 @@
3131
https://arxiv.org/abs/1610.09033 (2016)
3232
"""
3333

34-
import warnings
35-
import itertools
3634
import collections
35+
import itertools
36+
import warnings
37+
3738
import numpy as np
3839
import theano
3940
import theano.tensor as tt
41+
4042
import pymc3 as pm
43+
from pymc3.util import get_transformed
4144
from .updates import adagrad_window
42-
from ..model import modelcontext
4345
from ..blocking import (
4446
ArrayOrdering, DictToArrayBijection, VarMap
4547
)
46-
from ..util import get_default_varnames
48+
from ..model import modelcontext
4749
from ..theanof import tt_rng, memoize, change_flags, identity
48-
50+
from ..util import get_default_varnames
4951

5052
__all__ = [
5153
'ObjectiveFunction',
@@ -112,12 +114,6 @@ def try_to_set_test_value(node_in, node_out, s):
112114
o.tag.test_value = tv
113115

114116

115-
def get_transformed(z):
116-
if hasattr(z, 'transformed'):
117-
z = z.transformed
118-
return z
119-
120-
121117
class ObjectiveUpdates(theano.OrderedUpdates):
122118
"""OrderedUpdates extension for storing loss
123119
"""

0 commit comments

Comments
 (0)