
Commit 3a724c8

ferrine authored and twiecki committed
Update notebook
1 parent 7999a64 commit 3a724c8

File tree

1 file changed: +135 −7 lines changed


docs/source/notebooks/variational_api_quickstart.ipynb

Lines changed: 135 additions & 7 deletions
Original file line number · Diff line number · Diff line change
@@ -37,7 +37,7 @@
3737
"metadata": {},
3838
"source": [
3939
"## Basic setup\n",
40-
"We do not need comples models to play with VI API, instead we'll have a simple mixture model"
40+
"We do not need complex models to play with VI API, instead we'll have a simple mixture model"
4141
]
4242
},
4343
{
@@ -61,7 +61,7 @@
6161
"cell_type": "markdown",
6262
"metadata": {},
6363
"source": [
64-
"We can't compute analitical expectations quickly here. Instead we can get approximations for it with MC methods. Lets use NUTS first. It required these variables to be in deterministic list"
64+
"We can't compute analytical expectations quickly here. Instead we can get approximations for it with MC methods. Lets use NUTS first. It required these variables to be in deterministic list"
6565
]
6666
},
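As a hedged illustration of that deterministic list (the notebook's actual mixture-model cell is not shown in this diff, so the variable `x` and the model block below are placeholders):

import pymc3 as pm

# Minimal sketch: register derived quantities as Deterministic nodes so they
# are stored in the MCMC trace alongside the free variable.
with pm.Model() as model:
    x = pm.Normal('x', mu=0., sd=1.)          # stand-in for the notebook's latent variable
    pm.Deterministic('x2', x ** 2)
    pm.Deterministic('sin_x', pm.math.sin(x))
    trace = pm.sample(1000)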
6767
{
@@ -124,9 +124,9 @@
124124
"cell_type": "markdown",
125125
"metadata": {},
126126
"source": [
127-
"Looks good, we can see multomodality matters. Moreover we have samples for $x^2$ and $sin(x)$. Mut running MCMC every time we need to calculate an expression is not the thing needed in many experiments. Moreover you should know in advance what exactly you want to see in trace and call `Deterministic(.)` on it. \n",
127+
"Looks good, we can see multomodality matters. Moreover we have samples for $x^2$ and $sin(x)$. There is one drawback, you should know in advance what exactly you want to see in trace and call `Deterministic(.)` on it.\n",
128128
"\n",
129-
"VI API is about the opposite approach. First You have a model, you don inference on it, then experiments comming. Let's do the same setup without deterministics"
129+
"VI API is about the opposite approach. You do inference on model, then experiments come after. Let's do the same setup without deterministics"
130130
]
131131
},
132132
{
@@ -405,6 +405,40 @@
405405
"advi.approx"
406406
]
407407
},
408+
{
409+
"cell_type": "markdown",
410+
"metadata": {},
411+
"source": [
412+
"Different approximations have different parameters. In MeanField case we use $\\rho$ and $\\mu$ (inspired by [Bayes by BackProp](https://arxiv.org/abs/1505.05424))"
413+
]
414+
},
415+
{
416+
"cell_type": "code",
417+
"execution_count": 59,
418+
"metadata": {},
419+
"outputs": [
420+
{
421+
"data": {
422+
"text/plain": [
423+
"{'mu': mu, 'rho': rho}"
424+
]
425+
},
426+
"execution_count": 59,
427+
"metadata": {},
428+
"output_type": "execute_result"
429+
}
430+
],
431+
"source": [
432+
"advi.approx.shared_params"
433+
]
434+
},
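Since these shared params are plain theano shared variables (in the MeanField case), their current values can be inspected directly. A minimal sketch, assuming the already-fitted `advi` object from above:

params = advi.approx.shared_params
print(params['mu'].get_value(), params['rho'].get_value())
# the approximation also exposes symbolic posterior mean/std expressions
print(advi.approx.mean.eval(), advi.approx.std.eval())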
435+
{
436+
"cell_type": "markdown",
437+
"metadata": {},
438+
"source": [
439+
"But having convinient shortcuts happens to be usefull sometimes. One of most frequent cases is specifying mass matrix for NUTS"
440+
]
441+
},
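A hedged sketch of one way such a shortcut might look (not necessarily the code in the cell that follows): seed NUTS with the covariance and a draw from the fitted approximation.

# Assumes `advi.approx` is an already-fitted MeanField approximation.
with model:
    cov = advi.approx.cov.eval()              # covariance estimate from the approximation
    start = advi.approx.sample(draws=1)[0]    # a single posterior draw as the start point
    step = pm.NUTS(scaling=cov, is_cov=True)
    trace = pm.sample(1000, step=step, start=start)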
408442
{
409443
"cell_type": "code",
410444
"execution_count": 16,
@@ -446,6 +480,55 @@
446480
")"
447481
]
448482
},
483+
{
484+
"cell_type": "code",
485+
"execution_count": 58,
486+
"metadata": {},
487+
"outputs": [
488+
{
489+
"name": "stdout",
490+
"output_type": "stream",
491+
"text": [
492+
"\n",
493+
" Helper class to record arbitrary stats during VI\n",
494+
"\n",
495+
" It is possible to pass a function that takes no arguments\n",
496+
" If call fails then (approx, hist, i) are passed\n",
497+
"\n",
498+
"\n",
499+
" Parameters\n",
500+
" ----------\n",
501+
" kwargs : key word arguments\n",
502+
" keys mapping statname to callable that records the stat\n",
503+
"\n",
504+
" Examples\n",
505+
" --------\n",
506+
" Consider we want time on each iteration \n",
507+
" >>> import time\n",
508+
" >>> tracker = Tracker(time=time.time)\n",
509+
" >>> with model:\n",
510+
" ... approx = pm.fit(callbacks=[tracker])\n",
511+
" \n",
512+
" Time can be accessed via :code:`tracker['time']` now\n",
513+
" For more complex summary one can use callable that takes\n",
514+
" (approx, hist, i) as arguments\n",
515+
" >>> with model:\n",
516+
" ... my_callable = lambda ap, h, i: h[-1]\n",
517+
" ... tracker = Tracker(some_stat=my_callable)\n",
518+
" ... approx = pm.fit(callbacks=[tracker])\n",
519+
" \n",
520+
" Multiple stats are valid too\n",
521+
" >>> with model:\n",
522+
" ... tracker = Tracker(some_stat=my_callable, time=time.time)\n",
523+
" ... approx = pm.fit(callbacks=[tracker])\n",
524+
" \n"
525+
]
526+
}
527+
],
528+
"source": [
529+
"print(pm.callbacks.Tracker.__doc__)"
530+
]
531+
},
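A hedged usage sketch of `Tracker` with the OO API, recording the MeanField mean and std on every iteration (the callables passed here are an assumption, not the notebook's own cell):

tracker = pm.callbacks.Tracker(
    mean=advi.approx.mean.eval,  # callables that take no arguments
    std=advi.approx.std.eval,
)
approx = advi.fit(20000, callbacks=[tracker])
# tracker['mean'] and tracker['std'] now hold one recorded value per iteration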
449532
{
450533
"cell_type": "markdown",
451534
"metadata": {},
@@ -511,7 +594,7 @@
511594
"cell_type": "markdown",
512595
"metadata": {},
513596
"source": [
514-
"That picture is very informative. We can see how poor mean converges and thar different value for it do not change elbo significantly. As we are using OO API, we can continue inference to get some visual convergence"
597+
"That picture is very informative. We can see how poor mean converges and that different values for it do not change elbo significantly. As we are using OO API, we can continue inference to get some visual convergence"
515598
]
516599
},
517600
{
@@ -848,7 +931,7 @@
848931
"cell_type": "markdown",
849932
"metadata": {},
850933
"source": [
851-
"Every time we get different value for the same theano node. That is because it is stochastic"
934+
"Every time we get different value for the same theano node. That is because it is stochastic. After replacements we are free and do not depend on pymc3 model. We now depend on approximation. Changing it will change the distribution for stochastic nodes"
852935
]
853936
},
854937
{
@@ -1006,12 +1089,57 @@
10061089
"ass_.mean(0).eval() # mean"
10071090
]
10081091
},
1092+
{
1093+
"cell_type": "markdown",
1094+
"metadata": {},
1095+
"source": [
1096+
"Symbolic sample size is OK too"
1097+
]
1098+
},
1099+
{
1100+
"cell_type": "code",
1101+
"execution_count": 61,
1102+
"metadata": {},
1103+
"outputs": [],
1104+
"source": [
1105+
"i = theano.tensor.iscalar('i')\n",
1106+
"i.tag.test_value = 1\n",
1107+
"ass2_ = svgd_approx.sample_node(a, size=i)"
1108+
]
1109+
},
1110+
{
1111+
"cell_type": "code",
1112+
"execution_count": 63,
1113+
"metadata": {},
1114+
"outputs": [
1115+
{
1116+
"data": {
1117+
"text/plain": [
1118+
"((100,), (10000,))"
1119+
]
1120+
},
1121+
"execution_count": 63,
1122+
"metadata": {},
1123+
"output_type": "execute_result"
1124+
}
1125+
],
1126+
"source": [
1127+
"ass2_.eval({i: 100}).shape, ass2_.eval({i: 10000}).shape"
1128+
]
1129+
},
1130+
{
1131+
"cell_type": "markdown",
1132+
"metadata": {},
1133+
"source": [
1134+
"But unfortunately only scalar size is supported."
1135+
]
1136+
},
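One reason a symbolic size is handy: the node can be compiled once into a theano function and then evaluated with any sample size. A minimal sketch, reusing `i` and `ass2_` from the cells above:

# Compile once, then draw however many samples are needed per call.
sample_a = theano.function([i], ass2_)
print(sample_a(100).shape, sample_a(10000).shape)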
10091137
{
10101138
"cell_type": "markdown",
10111139
"metadata": {},
10121140
"source": [
10131141
"### creating custom fuction\n",
1014-
"What about mode_replacements argument, it can be very usefull sometimes. Suppose you have th following setup:"
1142+
"What about mode_replacements argument, it can be very usefull sometimes. Suppose you have the following setup:"
10151143
]
10161144
},
10171145
{

0 commit comments
