Skip to content

Commit 45da505

Browse files
committed
Merge branch 'Development' of https://github.com/stefanradev93/BayesFlow into Development
2 parents 4d415e6 + c76304a commit 45da505

File tree

3 files changed

+16
-7
lines changed

3 files changed

+16
-7
lines changed

bayesflow/diagnostics.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ def plot_recovery(
5151
color="#8f2727",
5252
n_col=None,
5353
n_row=None,
54+
xlabel="Ground truth",
55+
ylabel="Estimated",
5456
):
5557
"""Creates and plots publication-ready recovery plot with true vs. point estimate + uncertainty.
5658
The point estimate can be controlled with the ``point_agg`` argument, and the uncertainty estimate
@@ -96,7 +98,11 @@ def plot_recovery(
9698
A flag for adding R^2 between true and estimates to the plot
9799
color : str, optional, default: '#8f2727'
98100
The color for the true vs. estimated scatter points and error bars
99-
101+
xlabel : str, optional, default: 'Ground truth'
102+
The label on the x-axis of the plot
103+
ylabel : str, optional, default: 'Estimated'
104+
The label on the y-axis of the plot
105+
100106
Returns
101107
-------
102108
f : plt.Figure - the figure instance for optional saving
@@ -198,15 +204,15 @@ def plot_recovery(
198204
# Only add x-labels to the bottom row
199205
bottom_row = axarr if n_row == 1 else axarr[0] if n_col == 1 else axarr[n_row - 1, :]
200206
for _ax in bottom_row:
201-
_ax.set_xlabel("Ground truth", fontsize=label_fontsize)
207+
_ax.set_xlabel(xlabel, fontsize=label_fontsize)
202208

203209
# Only add y-labels to right left-most row
204210
if n_row == 1: # if there is only one row, the ax array is 1D
205-
axarr[0].set_ylabel("Estimated", fontsize=label_fontsize)
211+
axarr[0].set_ylabel(ylabel, fontsize=label_fontsize)
206212
# If there is more than one row, the ax array is 2D
207213
else:
208214
for _ax in axarr[:, 0]:
209-
_ax.set_ylabel("Estimated", fontsize=label_fontsize)
215+
_ax.set_ylabel(ylabel, fontsize=label_fontsize)
210216

211217
# Remove unused axes entirely
212218
for _ax in axarr_it[n_params:]:

bayesflow/summary_networks.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ def __init__(
148148

149149
# Final output reduces representation into a vector of length summary_dim
150150
self.output_layer = Dense(summary_dim)
151+
self.summary_dim = summary_dim
151152

152153
def call(self, x, **kwargs):
153154
"""Performs the forward pass through the transformer.
@@ -269,6 +270,8 @@ def __init__(
269270
summary_dim, attention_settings, num_dense_fc, dense_settings, use_layer_norm, num_seeds
270271
)
271272

273+
self.summary_dim = summary_dim
274+
272275
def call(self, x, **kwargs):
273276
"""Performs the forward pass through the set-transformer.
274277

examples/Model_Comparison_MPT.ipynb

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -699,7 +699,7 @@
699699
"cell_type": "markdown",
700700
"metadata": {},
701701
"source": [
702-
"Our simple simulators are extremly fast, so we can use online training (simulating the data on the fly during training). Here, we use 5 epochs with 500 iterations each and a batch size of 64 simulations. This means that we use $10 \\times 500 \\times 64 = 320,000$ unique simulations in total for training our neural network. We can do this, because the simulators are trivial to implement and thus very efficient to run. Training should take a couple of seconds to complete."
702+
"Our simple simulators are extremly fast, so we can use online training (simulating the data on the fly during training). Here, we use $10$ epochs with $500$ iterations each and a batch size of $64$ simulations. This means that we use $10 \\times 500 \\times 64 = 320,000$ unique simulations in total for training our neural network. We can do this because the simulators are trivial to implement and thus very efficient to run. Training should take a couple of seconds to complete."
703703
]
704704
},
705705
{
@@ -937,7 +937,7 @@
937937
],
938938
"source": [
939939
"# Way 1: Amortizer with dictionary input\n",
940-
"amortizer.posterior_probs({\"summary_conditions\" : fake_data})[0]"
940+
"amortizer.posterior_probs({\"summary_conditions\": fake_data})[0]"
941941
]
942942
},
943943
{
@@ -1016,7 +1016,7 @@
10161016
"name": "python",
10171017
"nbconvert_exporter": "python",
10181018
"pygments_lexer": "ipython3",
1019-
"version": "3.10.12"
1019+
"version": "3.10.10"
10201020
},
10211021
"toc": {
10221022
"base_numbering": 1,

0 commit comments

Comments
 (0)