
Commit 4ea8dde

Add start_sigma to ADVI 2 (#6132)
* Add `start_sigma` to ADVI
* add test for `start` and `start_sigma` plus minor fixes
* inline _prepare_start_sigma and use expm1
* Undo changes to ASVGD so all variational inference tests pass
1 parent: 2b39a0c
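
In practice the new keyword is threaded through `pm.fit`; a minimal usage sketch (the model and the variable name `mu` are illustrative, only the `start`/`start_sigma` keywords come from this commit):

import pymc as pm

with pm.Model():
    pm.Normal("mu", 0.0, 10.0)
    # start sets the initial posterior means, start_sigma the initial
    # posterior standard deviations; start_sigma is only honored for ADVI.
    approx = pm.fit(
        n=10_000,
        method="advi",
        start={"mu": 17},
        start_sigma={"mu": 13},
    )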

File tree

3 files changed (+59, -10 lines):
  pymc/tests/test_variational_inference.py
  pymc/variational/approximations.py
  pymc/variational/inference.py


pymc/tests/test_variational_inference.py

Lines changed: 25 additions & 0 deletions
@@ -571,6 +571,31 @@ def test_fit_oo(inference, fit_kwargs, simple_model_data):
     np.testing.assert_allclose(np.std(trace.posterior["mu"]), np.sqrt(1.0 / d), rtol=0.2)
 
 
+def test_fit_start(inference_spec, simple_model):
+    mu_init = 17
+    mu_sigma_init = 13
+
+    with simple_model:
+        if type(inference_spec()) == ASVGD:
+            # ASVGD doesn't support the start argument
+            return
+        elif type(inference_spec()) == ADVI:
+            has_start_sigma = True
+        else:
+            has_start_sigma = False
+
+    kw = {"start": {"mu": mu_init}}
+    if has_start_sigma:
+        kw.update({"start_sigma": {"mu": mu_sigma_init}})
+
+    with simple_model:
+        inference = inference_spec(**kw)
+    trace = inference.fit(n=0).sample(10000)
+    np.testing.assert_allclose(np.mean(trace.posterior["mu"]), mu_init, rtol=0.05)
+    if has_start_sigma:
+        np.testing.assert_allclose(np.std(trace.posterior["mu"]), mu_sigma_init, rtol=0.05)
+
+
 def test_profile(inference):
     inference.run_profiling(n=100).summary()
 
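
The test exploits the fact that `fit(n=0)` performs no optimization steps (the `n == 0` logging branch added in `inference.py` below covers exactly this case), so sampling the untrained approximation recovers the initialization directly. A sketch of that pattern outside the test fixtures (the model and variable name are illustrative, not from the diff):

import numpy as np
import pymc as pm

with pm.Model():
    pm.Normal("mu", 0.0, 10.0)
    # Zero iterations: the variational parameters stay at their initial values.
    approx = pm.ADVI(start={"mu": 17}, start_sigma={"mu": 13}).fit(n=0)

trace = approx.sample(10_000)
print(np.mean(trace.posterior["mu"]))  # ~17
print(np.std(trace.posterior["mu"]))   # ~13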

pymc/variational/approximations.py

Lines changed: 18 additions & 3 deletions
@@ -67,12 +67,27 @@ def std(self):
     def __init_group__(self, group):
         super().__init_group__(group)
         if not self._check_user_params():
-            self.shared_params = self.create_shared_params(self._kwargs.get("start", None))
+            self.shared_params = self.create_shared_params(
+                self._kwargs.get("start", None), self._kwargs.get("start_sigma", None)
+            )
         self._finalize_init()
 
-    def create_shared_params(self, start=None):
+    def create_shared_params(self, start=None, start_sigma=None):
+        # NOTE: `Group._prepare_start` uses `self.model.free_RVs` to identify free variables and
+        # `DictToArrayBijection` to turn them into a flat array, while `Approximation.rslice` assumes that the free
+        # variables are given by `self.group` and that the mapping between original variables and flat array is given
+        # by `self.ordering`. In the cases I looked into these turn out to be the same, but there may be edge cases or
+        # future code changes that break this assumption.
         start = self._prepare_start(start)
-        rho = np.zeros((self.ddim,))
+        rho1 = np.zeros((self.ddim,))
+
+        if start_sigma is not None:
+            for name, slice_, *_ in self.ordering.values():
+                sigma = start_sigma.get(name)
+                if sigma is not None:
+                    rho1[slice_] = np.log(np.expm1(np.abs(sigma)))
+        rho = rho1
+
         return {
             "mu": aesara.shared(pm.floatX(start), "mu"),
             "rho": aesara.shared(pm.floatX(rho), "rho"),

pymc/variational/inference.py

Lines changed: 16 additions & 7 deletions
@@ -257,7 +257,9 @@ def _infmean(input_array):
                 )
             )
         else:
-            if n < 10:
+            if n == 0:
+                logger.info(f"Initialization only")
+            elif n < 10:
                 logger.info(f"Finished [100%]: Loss = {scores[-1]:,.5g}")
             else:
                 avg_loss = _infmean(scores[max(0, i - 1000) : i + 1])
@@ -433,8 +435,10 @@ class ADVI(KLqp):
     random_seed: None or int
         leave None to use package global RandomStream or other
         valid value to create instance specific one
-    start: `Point`
+    start: `dict[str, np.ndarray]` or `StartDict`
         starting point for inference
+    start_sigma: `dict[str, np.ndarray]`
+        starting standard deviation for inference, only available for method 'advi'
 
     References
     ----------
@@ -464,7 +468,7 @@ class FullRankADVI(KLqp):
     random_seed: None or int
         leave None to use package global RandomStream or other
         valid value to create instance specific one
-    start: `Point`
+    start: `dict[str, np.ndarray]` or `StartDict`
         starting point for inference
 
     References
@@ -532,13 +536,11 @@ class SVGD(ImplicitGradient):
         kernel function for KSD :math:`f(histogram) -> (k(x,.), \nabla_x k(x,.))`
     temperature: float
         parameter responsible for exploration, higher temperature gives more broad posterior estimate
-    start: `dict`
+    start: `dict[str, np.ndarray]` or `StartDict`
         initial point for inference
     random_seed: None or int
         leave None to use package global RandomStream or other
         valid value to create instance specific one
-    start: `Point`
-        starting point for inference
     kwargs: other keyword arguments passed to estimator
 
     References
@@ -660,6 +662,7 @@ def fit(
     model=None,
     random_seed=None,
     start=None,
+    start_sigma=None,
     inf_kwargs=None,
     **kwargs,
 ):
@@ -684,8 +687,10 @@ def fit(
         valid value to create instance specific one
     inf_kwargs: dict
         additional kwargs passed to :class:`Inference`
-    start: `Point`
+    start: `dict[str, np.ndarray]` or `StartDict`
         starting point for inference
+    start_sigma: `dict[str, np.ndarray]`
+        starting standard deviation for inference, only available for method 'advi'
 
     Other Parameters
     ----------------
@@ -728,6 +733,10 @@ def fit(
         inf_kwargs["random_seed"] = random_seed
     if start is not None:
         inf_kwargs["start"] = start
+    if start_sigma is not None:
+        if method != "advi":
+            raise NotImplementedError("start_sigma is only available for method advi")
+        inf_kwargs["start_sigma"] = start_sigma
     if model is None:
         model = pm.modelcontext(model)
     _select = dict(advi=ADVI, fullrank_advi=FullRankADVI, svgd=SVGD, asvgd=ASVGD)
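
The guard means a misplaced `start_sigma` fails fast instead of being silently ignored; a sketch of the expected behavior (the model is hypothetical):

import pymc as pm

with pm.Model():
    pm.Normal("mu", 0.0, 1.0)
    try:
        pm.fit(method="fullrank_advi", start_sigma={"mu": 2.0})
    except NotImplementedError as err:
        print(err)  # start_sigma is only available for method advi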
