|
45 | 45 | ], |
46 | 46 | "source": [ |
47 | 47 | "sys.path.append(os.path.abspath(os.path.join('../../..')))\n", |
48 | | - "from bayesflow.forward_inference import Prior, Simulator, GenerativeModel\n", |
| 48 | + "from bayesflow.simulation import Prior, Simulator, GenerativeModel\n", |
49 | 49 | "from bayesflow.networks import InvertibleNetwork\n", |
50 | 50 | "from bayesflow.amortized_inference import AmortizedPosterior\n", |
51 | 51 | "from bayesflow.trainers import Trainer\n", |
|
460 | 460 | "metadata": {}, |
461 | 461 | "outputs": [], |
462 | 462 | "source": [ |
463 | | - "inference_net = InvertibleNetwork({\n", |
464 | | - " 'n_params' : 6,\n", |
465 | | - " 'n_coupling_layers' : 8,\n", |
466 | | - "})" |
| 463 | + "inference_net = InvertibleNetwork(num_params=6, num_coupling_layers=8)" |
467 | 464 | ] |
468 | 465 | }, |
469 | 466 | { |
|
505 | 502 | "sim_mean = np.mean(data['sim_data'])\n", |
506 | 503 | "sim_std = np.std(data['sim_data'])\n", |
507 | 504 | "\n", |
508 | | - "def preprocessing(forward_dict):\n", |
| 505 | + "def configure_input(forward_dict):\n", |
509 | 506 | " \"\"\"Configures dictionary of prior draws and simulated data into BayesFlow format.\"\"\"\n", |
510 | 507 | " \n", |
511 | 508 | " out_dict = {}\n", |
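The rest of this cell is unchanged and therefore not shown in the hunk. As orientation only, a configurator of this kind typically standardizes `sim_data` with the precomputed `sim_mean`/`sim_std` and casts prior draws to float32 before returning the keys the amortizer consumes. The body below is a hypothetical sketch under those assumptions, not the actual cell contents:

    def configure_input(forward_dict):
        """Configures dictionary of prior draws and simulated data into BayesFlow format."""
        out_dict = {}
        # Hypothetical sketch: standardize simulated data with the precomputed mean/std, cast to float32
        out_dict['summary_conditions'] = ((forward_dict['sim_data'] - sim_mean) / sim_std).astype(np.float32)
        # Prior draws act as the target parameters for the posterior network
        out_dict['parameters'] = forward_dict['prior_draws'].astype(np.float32)
        return out_dict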
|
560 | 557 | "trainer = Trainer(\n", |
561 | 558 | " amortizer=amortizer,\n", |
562 | 559 | " generative_model=model,\n", |
563 | | - " configurator=preprocessing,\n", |
564 | | - " learning_rate=PiecewiseConstantDecay([100000, 150000], [0.001, 0.0001, 0.00001]),\n", |
565 | | - " optional_stopping=False,\n", |
| 560 | + " configurator=configure_input\n", |
566 | 561 | ")" |
567 | 562 | ] |
568 | 563 | }, |
|
610 | 605 | "id": "178a1127", |
611 | 606 | "metadata": {}, |
612 | 607 | "source": [ |
613 | | - "Now, we can train our BayesFlow architecture using online learning:" |
| 608 | + "Now, we can train our BayesFlow architecture using online learning, this time showcasing a custom optimizer:" |
614 | 609 | ] |
615 | 610 | }, |
616 | 611 | { |
|
3421 | 3416 | } |
3422 | 3417 | ], |
3423 | 3418 | "source": [ |
3424 | | - "losses = trainer.train_online(epochs=200, iterations_per_epoch=1000, batch_size=32)" |
| 3419 | + "learning_rate = PiecewiseConstantDecay([100000, 150000], [0.001, 0.0001, 0.00001])\n", |
| 3420 | + "optimizer = tf.keras.optimizers.Adam(learning_rate)\n", |
| 3421 | + "losses = trainer.train_online(epochs=200, iterations_per_epoch=1000, batch_size=32, optimizer=optimizer)" |
3425 | 3422 | ] |
3426 | 3423 | }, |
3427 | 3424 | { |
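The optimizer cell above relies on TensorFlow utilities imported elsewhere in the notebook (outside this hunk). A minimal, self-contained sketch of that setup, assuming the standard Keras schedule and optimizer:

    import tensorflow as tf
    from tensorflow.keras.optimizers.schedules import PiecewiseConstantDecay

    # Learning rate drops after 100k and 150k training steps: 1e-3 -> 1e-4 -> 1e-5
    learning_rate = PiecewiseConstantDecay([100000, 150000], [0.001, 0.0001, 0.00001])
    optimizer = tf.keras.optimizers.Adam(learning_rate)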
|
3706 | 3703 | "ax.legend()\n", |
3707 | 3704 | "plt.show()" |
3708 | 3705 | ] |
3709 | | - }, |
3710 | | - { |
3711 | | - "cell_type": "code", |
3712 | | - "execution_count": 26, |
3713 | | - "id": "69c64757", |
3714 | | - "metadata": {}, |
3715 | | - "outputs": [ |
3716 | | - { |
3717 | | - "data": { |
3718 | | - "application/javascript": [ |
3719 | | - "IPython.notebook.save_notebook()\n" |
3720 | | - ], |
3721 | | - "text/plain": [ |
3722 | | - "<IPython.core.display.Javascript object>" |
3723 | | - ] |
3724 | | - }, |
3725 | | - "metadata": {}, |
3726 | | - "output_type": "display_data" |
3727 | | - } |
3728 | | - ], |
3729 | | - "source": [ |
3730 | | - "%%javascript\n", |
3731 | | - "IPython.notebook.save_notebook()" |
3732 | | - ] |
3733 | 3706 | } |
3734 | 3707 | ], |
3735 | 3708 | "metadata": { |
3736 | 3709 | "kernelspec": { |
3737 | | - "display_name": "Python 3 (ipykernel)", |
| 3710 | + "display_name": "Python 3", |
3738 | 3711 | "language": "python", |
3739 | 3712 | "name": "python3" |
3740 | 3713 | }, |
|
3748 | 3721 | "name": "python", |
3749 | 3722 | "nbconvert_exporter": "python", |
3750 | 3723 | "pygments_lexer": "ipython3", |
3751 | | - "version": "3.10.6" |
| 3724 | + "version": "3.9.13" |
3752 | 3725 | }, |
3753 | 3726 | "toc": { |
3754 | 3727 | "base_numbering": "1", |
|