
Commit 10c9330

add beta-vae (#3071)
* add beta-VAE, get strange test fail
* add test backward compat
* docs
* improve docs, increment release notes
1 parent d1d2aa2 commit 10c9330

File tree: 6 files changed, +209 −12 lines

RELEASE-NOTES.md

Lines changed: 2 additions & 0 deletions
@@ -25,13 +25,15 @@
   final trace.
 - Add `model_to_graphviz` (which uses the optional dependency `graphviz`) to
   plot a directed graph of a PyMC3 model using plate notation.
+- Add beta-ELBO variational inference as in the beta-VAE model (Christopher P. Burgess et al., NIPS 2017)

 ### Fixes

 - Fixed `KeyError` raised when only a subset of variables are specified to be recorded in the trace.
 - Removed unused `repeat=None` arguments from all `random()` methods in distributions.
 - Deprecated the `sigma` argument in `MarginalSparse.marginal_likelihood` in favor of `noise`
 - Fixed unexpected behavior in `random`. Now the `random` functionality is more robust and will work better for `sample_prior` when that is implemented.
+- Fixed `scale_cost_to_minibatch` behaviour; previously it was not working and always `False`

 ## PyMC 3.4.1 (April 18 2018)

pymc3/model.py

Lines changed: 8 additions & 1 deletion
@@ -737,7 +737,14 @@ def varlogpt(self):
         """Theano scalar of log-probability of the unobserved random variables
         (excluding deterministic)."""
         with self:
-            factors = [var.logpt for var in self.vars]
+            factors = [var.logpt for var in self.free_RVs]
+            return tt.sum(factors)
+
+    @property
+    def datalogpt(self):
+        with self:
+            factors = [var.logpt for var in self.observed_RVs]
+            factors += [tt.sum(factor) for factor in self.potentials]
             return tt.sum(factors)

     @property
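The split is exhaustive: `free_RVs` go to `varlogpt`, and `observed_RVs` plus potentials go to the new `datalogpt`, so the two terms add back up to `model.logpt`. A quick sanity check of that identity — a minimal sketch assuming the Theano-era `Model.fn` point-function helper:

    import numpy as np
    import pymc3 as pm

    with pm.Model() as model:
        mu = pm.Normal('mu', mu=0., sd=1.)
        pm.Normal('y', mu=mu, sd=1., observed=np.array([0.1, -0.2]))

    point = model.test_point  # {'mu': array(0.)}
    # prior term (free RVs) plus data term (observed RVs and potentials)
    # should reproduce the full model log-probability
    prior_term = model.fn(model.varlogpt)(point)
    data_term = model.fn(model.datalogpt)(point)
    np.testing.assert_allclose(model.fn(model.logpt)(point),
                               prior_term + data_term)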

pymc3/tests/test_variational_inference.py

Lines changed: 89 additions & 0 deletions
@@ -459,6 +459,93 @@ def test_elbo():
     np.testing.assert_allclose(elbo_mc, elbo_true, rtol=0, atol=1e-1)


+@pytest.mark.parametrize(
+    'aux_total_size',
+    range(2, 10, 3)
+)
+def test_scale_cost_to_minibatch_works(aux_total_size):
+    mu0 = 1.5
+    sigma = 1.0
+    y_obs = np.array([1.6, 1.4])
+    beta = len(y_obs) / float(aux_total_size)
+    post_mu = np.array([1.88], dtype=theano.config.floatX)
+    post_sd = np.array([1], dtype=theano.config.floatX)
+
+    # TODO: theano_config
+    # with pm.Model(theano_config=dict(floatX='float64')):
+    # did not work as expected; there were some numeric problems,
+    # so float64 is forced
+    with pm.theanof.change_flags(floatX='float64', warn_float64='ignore'):
+        with pm.Model():
+            assert theano.config.floatX == 'float64'
+            assert theano.config.warn_float64 == 'ignore'
+            mu = pm.Normal('mu', mu=mu0, sd=sigma)
+            pm.Normal('y', mu=mu, sd=1, observed=y_obs, total_size=aux_total_size)
+            # Create variational gradient tensor
+            mean_field_1 = MeanField()
+            assert mean_field_1.scale_cost_to_minibatch
+            mean_field_1.shared_params['mu'].set_value(post_mu)
+            mean_field_1.shared_params['rho'].set_value(np.log(np.exp(post_sd) - 1))
+
+            with pm.theanof.change_flags(compute_test_value='off'):
+                elbo_via_total_size_scaled = -pm.operators.KL(mean_field_1)()(10000)
+
+        with pm.Model():
+            mu = pm.Normal('mu', mu=mu0, sd=sigma)
+            pm.Normal('y', mu=mu, sd=1, observed=y_obs, total_size=aux_total_size)
+            # Create variational gradient tensor
+            mean_field_2 = MeanField()
+            assert mean_field_2.scale_cost_to_minibatch
+            mean_field_2.scale_cost_to_minibatch = False
+            assert not mean_field_2.scale_cost_to_minibatch
+            mean_field_2.shared_params['mu'].set_value(post_mu)
+            mean_field_2.shared_params['rho'].set_value(np.log(np.exp(post_sd) - 1))
+
+            with pm.theanof.change_flags(compute_test_value='off'):
+                elbo_via_total_size_unscaled = -pm.operators.KL(mean_field_2)()(10000)
+
+        np.testing.assert_allclose(elbo_via_total_size_unscaled.eval(),
+                                   elbo_via_total_size_scaled.eval() * pm.floatX(1 / beta),
+                                   rtol=0.02, atol=1e-1)
+
+
+@pytest.mark.parametrize(
+    'aux_total_size',
+    range(2, 10, 3)
+)
+def test_elbo_beta_kl(aux_total_size):
+    mu0 = 1.5
+    sigma = 1.0
+    y_obs = np.array([1.6, 1.4])
+    beta = len(y_obs) / float(aux_total_size)
+    post_mu = np.array([1.88], dtype=theano.config.floatX)
+    post_sd = np.array([1], dtype=theano.config.floatX)
+    with pm.theanof.change_flags(floatX='float64', warn_float64='ignore'):
+        with pm.Model():
+            mu = pm.Normal('mu', mu=mu0, sd=sigma)
+            pm.Normal('y', mu=mu, sd=1, observed=y_obs, total_size=aux_total_size)
+            # Create variational gradient tensor
+            mean_field_1 = MeanField()
+            mean_field_1.scale_cost_to_minibatch = True
+            mean_field_1.shared_params['mu'].set_value(post_mu)
+            mean_field_1.shared_params['rho'].set_value(np.log(np.exp(post_sd) - 1))
+
+            with pm.theanof.change_flags(compute_test_value='off'):
+                elbo_via_total_size_scaled = -pm.operators.KL(mean_field_1)()(10000)
+
+        with pm.Model():
+            mu = pm.Normal('mu', mu=mu0, sd=sigma)
+            pm.Normal('y', mu=mu, sd=1, observed=y_obs)
+            # Create variational gradient tensor
+            mean_field_3 = MeanField()
+            mean_field_3.shared_params['mu'].set_value(post_mu)
+            mean_field_3.shared_params['rho'].set_value(np.log(np.exp(post_sd) - 1))
+
+            with pm.theanof.change_flags(compute_test_value='off'):
+                elbo_via_beta_kl = -pm.operators.KL(mean_field_3, beta=beta)()(10000)
+
+        np.testing.assert_allclose(elbo_via_total_size_scaled.eval(),
+                                   elbo_via_beta_kl.eval(), rtol=0, atol=1e-1)
+
+
 @pytest.fixture(
     'module',
     params=[True, False],
@@ -581,6 +668,8 @@ def fit_kwargs(inference, use_minibatch):
     }
     if use_minibatch:
         key = 'mini'
+        # backward compat for PR#3071
+        inference.approx.scale_cost_to_minibatch = False
     else:
         key = 'full'
     return _select[(type(inference), key)]
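The two tests pin down the same identity from opposite ends: declaring `total_size=N` on an n-point observed minibatch and scaling the cost to the minibatch should match an unscaled model whose KL term is multiplied by β = n/N. With n = 2 observations and the parametrized sizes N ∈ {2, 5, 8}, the tests exercise

    \beta = \frac{n}{N} = \frac{2}{N} \in \{\, 1,\ 0.4,\ 0.25 \,\}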

pymc3/variational/inference.py

Lines changed: 17 additions & 3 deletions
@@ -22,6 +22,7 @@
     'FullRankADVI',
     'SVGD',
     'ASVGD',
+    'NFVI',
     'Inference',
     'ImplicitGradient',
     'KLqp',
@@ -272,15 +273,28 @@ class KLqp(Inference):
     """**Kullback Leibler Divergence Inference**

     General approach to fit Approximations that define :math:`logq`
-    by maximizing ELBO (Evidence Lower Bound).
+    by maximizing ELBO (Evidence Lower Bound). In some cases
+    rescaling the regularization term KL may be beneficial
+
+    .. math::
+
+        ELBO_\beta = \log p(D|\theta) - \beta KL(q||p)

     Parameters
     ----------
     approx : :class:`Approximation`
         Approximation to fit, it is required to have `logQ`
+    beta : float
+        Scales the regularization term in ELBO (see Christopher P. Burgess et al., 2017)
+
+    References
+    ----------
+    -   Christopher P. Burgess et al. (NIPS, 2017)
+        Understanding disentangling in :math:`\beta`-VAE
+        arXiv preprint 1804.03599
     """
-    def __init__(self, approx):
-        super(KLqp, self).__init__(KL, approx, None)
+    def __init__(self, approx, beta=1.):
+        super(KLqp, self).__init__(KL, approx, None, beta=beta)


 class ADVI(KLqp):
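With that constructor change, β surfaces in the user-facing API. A minimal usage sketch, assuming the top-level `pm.MeanField`/`pm.KLqp` exports and an illustrative `beta=0.5`:

    import numpy as np
    import pymc3 as pm

    y_obs = np.array([1.6, 1.4])

    with pm.Model():
        mu = pm.Normal('mu', mu=1.5, sd=1.0)
        pm.Normal('y', mu=mu, sd=1.0, observed=y_obs)
        approx = pm.MeanField()

    # beta=1. recovers plain KLqp/ADVI; beta < 1 down-weights the KL
    # regularizer relative to the data term, as in beta-VAE
    inference = pm.KLqp(approx, beta=0.5)
    approx = inference.fit(10000)
    trace = approx.sample(500)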

pymc3/variational/operators.py

Lines changed: 23 additions & 1 deletion
@@ -14,13 +14,33 @@
 class KL(Operator):
     R"""**Operator based on Kullback Leibler Divergence**

+    This operator constructs the Evidence Lower Bound (ELBO) objective
+
+    .. math::
+
+        ELBO_\beta = \log p(D|\theta) - \beta KL(q||p)
+
+    where
+
     .. math::

         KL[q(v)||p(v)] = \int q(v)\log\frac{q(v)}{p(v)}dv
+
+
+    Parameters
+    ----------
+    approx : :class:`Approximation`
+        Approximation used for inference
+    beta : float
+        Beta parameter for KL divergence, scales the regularization term.
     """

+    def __init__(self, approx, beta=1.):
+        Operator.__init__(self, approx)
+        self.beta = pm.floatX(beta)
+
     def apply(self, f):
-        return self.logq_norm - self.logp_norm
+        return -self.datalogp_norm + self.beta * (self.logq_norm - self.varlogp_norm)

 # SVGD Implementation

@@ -76,6 +96,8 @@ class KSD(Operator):
     ----------
     approx : :class:`Approximation`
         Approximation used for inference
+    temperature: float
+        Temperature for Stein gradient

     References
     ----------
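For orientation, `apply` is exactly the negative β-ELBO from the docstring: since `logp_norm = varlogp_norm + datalogp_norm` (see the `opvi.py` changes below), the returned quantity is

    -\mathrm{ELBO}_\beta
      = -\,\mathbb{E}_q\!\left[\log p(D \mid \theta)\right]
      + \beta \left( \mathbb{E}_q\!\left[\log q(\theta)\right]
                   - \mathbb{E}_q\!\left[\log p(\theta)\right] \right)

and setting β = 1 collapses it to the previous `self.logq_norm - self.logp_norm`.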

pymc3/variational/opvi.py

Lines changed: 70 additions & 7 deletions
@@ -389,8 +389,12 @@ def __init__(self, approx):

     inputs = property(lambda self: self.approx.inputs)
     logp = property(lambda self: self.approx.logp)
+    varlogp = property(lambda self: self.approx.varlogp)
+    datalogp = property(lambda self: self.approx.datalogp)
     logq = property(lambda self: self.approx.logq)
     logp_norm = property(lambda self: self.approx.logp_norm)
+    varlogp_norm = property(lambda self: self.approx.varlogp_norm)
+    datalogp_norm = property(lambda self: self.approx.datalogp_norm)
     logq_norm = property(lambda self: self.approx.logq_norm)
     model = property(lambda self: self.approx.model)

@@ -1298,7 +1302,10 @@ def symbolic_normalizing_constant(self):
         """*Dev* - normalizing constant for `self.logq`, scales it to `minibatch_size` instead of `total_size`.
         Here the effect is controlled by `self.scale_cost_to_minibatch`
         """
-        t = tt.max(self.collect('symbolic_normalizing_constant'))
+        t = tt.max(
+            self.collect('symbolic_normalizing_constant') + [
+                var.scaling for var in self.model.observed_RVs
+            ])
         t = tt.switch(self._scale_cost_to_minibatch, t,
                       tt.constant(1, dtype=t.dtype))
         return pm.floatX(t)
@@ -1318,28 +1325,83 @@ def logq_norm(self):
         """*Dev* - collects `logQ` for all groups and normalizes it"""
         return self.logq / self.symbolic_normalizing_constant

+    @node_property
+    def _sized_symbolic_varlogp_and_datalogp(self):
+        """*Dev* - computes sampled prior and data terms from model via `theano.scan`"""
+        varlogp_s, datalogp_s = self.symbolic_sample_over_posterior(
+            [self.model.varlogpt, self.model.datalogpt])
+        return varlogp_s, datalogp_s  # both shape (s,)
+
+    @node_property
+    def sized_symbolic_varlogp(self):
+        """*Dev* - computes sampled prior term from model via `theano.scan`"""
+        return self._sized_symbolic_varlogp_and_datalogp[0]  # shape (s,)
+
+    @node_property
+    def sized_symbolic_datalogp(self):
+        """*Dev* - computes sampled data term from model via `theano.scan`"""
+        return self._sized_symbolic_varlogp_and_datalogp[1]  # shape (s,)
+
     @node_property
     def sized_symbolic_logp(self):
-        """*Dev* - computes sampled `logP` from model via `theano.scan`"""
-        free_logp_local = self.symbolic_sample_over_posterior(self.model.logpt)
-        return free_logp_local  # shape (s,)
+        """*Dev* - computes sampled logP from model via `theano.scan`"""
+        return self.sized_symbolic_varlogp + self.sized_symbolic_datalogp  # shape (s,)

     @node_property
     def logp(self):
         """*Dev* - computes :math:`E_{q}(logP)` from model via `theano.scan` that can be optimized later"""
-        return self.sized_symbolic_logp.mean(0)
+        return self.varlogp + self.datalogp
+
+    @node_property
+    def varlogp(self):
+        """*Dev* - computes :math:`E_{q}(prior term)` from model via `theano.scan` that can be optimized later"""
+        return self.sized_symbolic_varlogp.mean(0)
+
+    @node_property
+    def datalogp(self):
+        """*Dev* - computes :math:`E_{q}(data term)` from model via `theano.scan` that can be optimized later"""
+        return self.sized_symbolic_datalogp.mean(0)
+
+    @node_property
+    def _single_symbolic_varlogp_and_datalogp(self):
+        """*Dev* - computes single-sample prior and data terms from model without `theano.scan`"""
+        varlogp, datalogp = self.symbolic_single_sample(
+            [self.model.varlogpt, self.model.datalogpt])
+        return varlogp, datalogp
+
+    @node_property
+    def single_symbolic_varlogp(self):
+        """*Dev* - for single MC sample estimate of :math:`E_{q}(prior term)` `theano.scan`
+        is not needed and code can be optimized"""
+        return self._single_symbolic_varlogp_and_datalogp[0]
+
+    @node_property
+    def single_symbolic_datalogp(self):
+        """*Dev* - for single MC sample estimate of :math:`E_{q}(data term)` `theano.scan`
+        is not needed and code can be optimized"""
+        return self._single_symbolic_varlogp_and_datalogp[1]

     @node_property
     def single_symbolic_logp(self):
         """*Dev* - for single MC sample estimate of :math:`E_{q}(logP)` `theano.scan`
         is not needed and code can be optimized"""
-        return self.symbolic_single_sample(self.model.logpt)
+        return self.single_symbolic_datalogp + self.single_symbolic_varlogp

     @node_property
     def logp_norm(self):
         """*Dev* - normalized :math:`E_{q}(logP)`"""
         return self.logp / self.symbolic_normalizing_constant

+    @node_property
+    def varlogp_norm(self):
+        """*Dev* - normalized :math:`E_{q}(prior term)`"""
+        return self.varlogp / self.symbolic_normalizing_constant
+
+    @node_property
+    def datalogp_norm(self):
+        """*Dev* - normalized :math:`E_{q}(data term)`"""
+        return self.datalogp / self.symbolic_normalizing_constant
+
     @property
     def replacements(self):
         """*Dev* - all replacements from groups to replace PyMC random variables with approximation"""
@@ -1437,7 +1499,8 @@ def get_optimization_replacements(self, s, d):
         repl = collections.OrderedDict()
         # avoid scan if size is constant and equal to one
         if isinstance(s, int) and (s == 1) or s is None:
-            repl[self.logp] = self.single_symbolic_logp
+            repl[self.varlogp] = self.single_symbolic_varlogp
+            repl[self.datalogp] = self.single_symbolic_datalogp
         return repl

     @change_flags(compute_test_value='off')
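Putting the operator and the new opvi plumbing together, the β-scaled objective the tests build can be reproduced in a few lines. A condensed sketch distilled from `test_elbo_beta_kl` above (the total size of 5.0 is an arbitrary illustration):

    import numpy as np
    import pymc3 as pm

    y_obs = np.array([1.6, 1.4])
    beta = len(y_obs) / 5.0  # n / total_size = 0.4

    with pm.Model():
        mu = pm.Normal('mu', mu=1.5, sd=1.0)
        pm.Normal('y', mu=mu, sd=1.0, observed=y_obs)
        mean_field = pm.MeanField()

    with pm.theanof.change_flags(compute_test_value='off'):
        # KL(...) builds the -ELBO_beta operator; the first call binds the
        # (unused for KL) test function, the second the number of MC samples
        neg_elbo_beta = pm.operators.KL(mean_field, beta=beta)()(10000)

    print(-neg_elbo_beta.eval())  # MC estimate of ELBO_beta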
