Skip to content

Commit 8edc54a

Browse files
committed
Update tutorial notebooks
1 parent c91c09d commit 8edc54a

File tree

3 files changed

+228
-150
lines changed

3 files changed

+228
-150
lines changed

docs/source/tutorial_notebooks/Covid19_Initial_Posterior_Estimation.ipynb

Lines changed: 218 additions & 102 deletions
Large diffs are not rendered by default.

docs/source/tutorial_notebooks/Intro_Amortized_Posterior_Estimation.ipynb

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -36,17 +36,6 @@
3636
"import bayesflow.diagnostics as diag"
3737
]
3838
},
39-
{
40-
"cell_type": "code",
41-
"execution_count": 4,
42-
"id": "racial-james",
43-
"metadata": {},
44-
"outputs": [],
45-
"source": [
46-
"%load_ext autoreload\n",
47-
"%autoreload 2"
48-
]
49-
},
5039
{
5140
"cell_type": "markdown",
5241
"id": "contemporary-arthritis",

docs/source/tutorial_notebooks/Linear ODE system.ipynb

Lines changed: 10 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545
],
4646
"source": [
4747
"sys.path.append(os.path.abspath(os.path.join('../../..')))\n",
48-
"from bayesflow.forward_inference import Prior, Simulator, GenerativeModel\n",
48+
"from bayesflow.simulation import Prior, Simulator, GenerativeModel\n",
4949
"from bayesflow.networks import InvertibleNetwork\n",
5050
"from bayesflow.amortized_inference import AmortizedPosterior\n",
5151
"from bayesflow.trainers import Trainer\n",
@@ -460,10 +460,7 @@
460460
"metadata": {},
461461
"outputs": [],
462462
"source": [
463-
"inference_net = InvertibleNetwork({\n",
464-
" 'n_params' : 6,\n",
465-
" 'n_coupling_layers' : 8,\n",
466-
"})"
463+
"inference_net = InvertibleNetwork(num_params=6, num_coupling_layers=8)"
467464
]
468465
},
469466
{
@@ -505,7 +502,7 @@
505502
"sim_mean = np.mean(data['sim_data'])\n",
506503
"sim_std = np.std(data['sim_data'])\n",
507504
"\n",
508-
"def preprocessing(forward_dict):\n",
505+
"def configure_input(forward_dict):\n",
509506
" \"\"\"Configures dictionary of prior draws and simulated data into BayesFlow format.\"\"\"\n",
510507
" \n",
511508
" out_dict = {}\n",
@@ -560,9 +557,7 @@
560557
"trainer = Trainer(\n",
561558
" amortizer=amortizer,\n",
562559
" generative_model=model,\n",
563-
" configurator=preprocessing,\n",
564-
" learning_rate=PiecewiseConstantDecay([100000, 150000], [0.001, 0.0001, 0.00001]),\n",
565-
" optional_stopping=False,\n",
560+
" configurator=configure_input\n",
566561
")"
567562
]
568563
},
@@ -610,7 +605,7 @@
610605
"id": "178a1127",
611606
"metadata": {},
612607
"source": [
613-
"Now, we can train our BayesFlow architecture using online learning:"
608+
"Now, we can train our BayesFlow architecture using online learning and showcasing a custom optimizer."
614609
]
615610
},
616611
{
@@ -3421,7 +3416,9 @@
34213416
}
34223417
],
34233418
"source": [
3424-
"losses = trainer.train_online(epochs=200, iterations_per_epoch=1000, batch_size=32)"
3419+
"learning_rate = PiecewiseConstantDecay([100000, 150000], [0.001, 0.0001, 0.00001])\n",
3420+
"optimizer = tf.keras.optimizers.Adam(learning_rate)\n",
3421+
"losses = trainer.train_online(epochs=200, iterations_per_epoch=1000, batch_size=32, optimizer=optimizer)"
34253422
]
34263423
},
34273424
{
@@ -3706,35 +3703,11 @@
37063703
"ax.legend()\n",
37073704
"plt.show()"
37083705
]
3709-
},
3710-
{
3711-
"cell_type": "code",
3712-
"execution_count": 26,
3713-
"id": "69c64757",
3714-
"metadata": {},
3715-
"outputs": [
3716-
{
3717-
"data": {
3718-
"application/javascript": [
3719-
"IPython.notebook.save_notebook()\n"
3720-
],
3721-
"text/plain": [
3722-
"<IPython.core.display.Javascript object>"
3723-
]
3724-
},
3725-
"metadata": {},
3726-
"output_type": "display_data"
3727-
}
3728-
],
3729-
"source": [
3730-
"%%javascript\n",
3731-
"IPython.notebook.save_notebook()"
3732-
]
37333706
}
37343707
],
37353708
"metadata": {
37363709
"kernelspec": {
3737-
"display_name": "Python 3 (ipykernel)",
3710+
"display_name": "Python 3",
37383711
"language": "python",
37393712
"name": "python3"
37403713
},
@@ -3748,7 +3721,7 @@
37483721
"name": "python",
37493722
"nbconvert_exporter": "python",
37503723
"pygments_lexer": "ipython3",
3751-
"version": "3.10.6"
3724+
"version": "3.9.13"
37523725
},
37533726
"toc": {
37543727
"base_numbering": "1",

0 commit comments

Comments
 (0)