Skip to content

Commit 3f28f34

Browse files
committed
Merge remote-tracking branch 'upstream/dev' into standardize_in_approx
2 parents cd45b85 + 449a79a commit 3f28f34

File tree

10 files changed

+160
-46
lines changed

10 files changed

+160
-46
lines changed

bayesflow/approximators/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""
1+
r"""
22
A collection of :py:class:`~bayesflow.approximators.Approximator`\ s, which embody the inference task and the
33
neural network components used to perform it.
44
"""

bayesflow/datasets/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""
1+
r"""
22
A collection of `keras.utils.PyDataset <https://keras.io/api/utils/python_utils/#pydataset-class>`__\ s, which
33
wrap your data-generating process (i.e., your :py:class:`~bayesflow.simulators.Simulator`) and thus determine the
44
effective training strategy (e.g., online or offline).

bayesflow/diagnostics/plots/pairs_posterior.py

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import seaborn as sns
88

99
from bayesflow.utils.dict_utils import dicts_to_arrays
10+
from bayesflow.utils.plot_utils import create_legends
1011

1112
from .pairs_samples import _pairs_samples
1213

@@ -21,6 +22,7 @@ def pairs_posterior(
2122
height: int = 3,
2223
post_color: str | tuple = "#132a70",
2324
prior_color: str | tuple = "gray",
25+
target_color: str | tuple = "red",
2426
alpha: float = 0.9,
2527
label_fontsize: int = 14,
2628
tick_fontsize: int = 12,
@@ -37,25 +39,27 @@ def pairs_posterior(
3739
Optional true parameter values that have generated the observed dataset.
3840
priors : np.ndarray of shape (n_prior_draws, n_params) or None, optional (default: None)
3941
Optional prior samples obtained from the prior.
40-
dataset_id: Optional ID of the dataset for whose posterior the pairs plot shall be generated.
41-
Should only be specified if estimates contains posterior draws from multiple datasets.
42+
dataset_id: Optional ID of the dataset for whose posterior the pair plots shall be generated.
43+
Should only be specified if estimates contain posterior draws from multiple datasets.
4244
variable_keys : list or None, optional, default: None
4345
Select keys from the dictionary provided in samples.
4446
By default, select all keys.
4547
variable_names : list or None, optional, default: None
4648
The parameter names for nice plot titles. Inferred if None
4749
height : float, optional, default: 3
48-
The height of the pairplot
50+
The height of the pair plots
4951
label_fontsize : int, optional, default: 14
5052
The font size of the x and y-label texts (parameter names)
5153
tick_fontsize : int, optional, default: 12
52-
The font size of the axis ticklabels
54+
The font size of the axis tick labels
5355
legend_fontsize : int, optional, default: 16
5456
The font size of the legend text
5557
post_color : str, optional, default: '#132a70'
5658
The color for the posterior histograms and KDEs
5759
prior_color : str, optional, default: gray
5860
The color for the optional prior histograms and KDEs
61+
target_color : str, optional, default: red
62+
The color for the optional true parameter lines and points
5963
alpha : float in [0, 1], optional, default: 0.9
6064
The opacity of the posterior plots
6165
@@ -81,7 +85,7 @@ def pairs_posterior(
8185
variable_names=variable_names,
8286
)
8387

84-
# dicts_to_arrays will keep dataset axis even if it is of length 1
88+
# dicts_to_arrays will keep the dataset axis even if it is of length 1
8589
# however, pairs plotting requires the dataset axis to be removed
8690
estimates_shape = plot_data["estimates"].shape
8791
if len(estimates_shape) == 3 and estimates_shape[0] == 1:
@@ -109,14 +113,30 @@ def pairs_posterior(
109113
# Create DataFrame with variable names as columns
110114
g.data = pd.DataFrame(targets, columns=targets.variable_names)
111115
g.data["_source"] = "True Parameter"
112-
g.map_diag(plot_true_params)
116+
g.map_diag(plot_true_params_as_lines, color=target_color)
117+
g.map_offdiag(plot_true_params_as_points, color=target_color)
118+
119+
create_legends(
120+
g,
121+
plot_data,
122+
color=post_color,
123+
color2=prior_color,
124+
legend_fontsize=legend_fontsize,
125+
show_single_legend=False,
126+
)
113127

114128
return g
115129

116130

117-
def plot_true_params(x, hue=None, **kwargs):
118-
"""Custom function to plot true parameters on the diagonal."""
131+
def plot_true_params_as_lines(x, hue=None, color=None, **kwargs):
132+
"""Custom function to plot true parameters on the diagonal as dashed lines."""
119133
# hue needs to be added to handle the case of plotting both posterior and prior
120134
param = x.iloc[0] # Get the single true value for the diagonal
121135
# only plot on the diagonal a vertical line for the true parameter
122-
plt.axvline(param, color="black", linestyle="--")
136+
plt.axvline(param, color=color, linestyle="--")
137+
138+
139+
def plot_true_params_as_points(x, y, color=None, marker="x", **kwargs):
140+
"""Custom function to plot true parameters on the off-diagonal as a single point."""
141+
if len(x) > 0 and len(y) > 0:
142+
plt.scatter(x.iloc[0], y.iloc[0], color=color, marker=marker, **kwargs)

bayesflow/diagnostics/plots/pairs_samples.py

Lines changed: 55 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from bayesflow.utils import logging
1010
from bayesflow.utils.dict_utils import dicts_to_arrays
11+
from bayesflow.utils.plot_utils import create_legends
1112

1213

1314
def pairs_samples(
@@ -17,8 +18,10 @@ def pairs_samples(
1718
height: float = 2.5,
1819
color: str | tuple = "#132a70",
1920
alpha: float = 0.9,
21+
label: str = "Posterior",
2022
label_fontsize: int = 14,
2123
tick_fontsize: int = 12,
24+
show_single_legend: bool = False,
2225
**kwargs,
2326
) -> sns.PairGrid:
2427
"""
@@ -37,13 +40,18 @@ def pairs_samples(
3740
height : float, optional, default: 2.5
3841
The height of the pair plot
3942
color : str, optional, default : '#132a70'
40-
The color of the plot
43+
The primary color of the plot
4144
alpha : float in [0, 1], optional, default: 0.9
4245
The opacity of the plot
46+
label : str, optional, default: "Posterior"
47+
Label for the dataset to plot
4348
label_fontsize : int, optional, default: 14
4449
The font size of the x and y-label texts (parameter names)
4550
tick_fontsize : int, optional, default: 12
46-
The font size of the axis ticklabels
51+
The font size of the axis tick labels
52+
show_single_legend : bool, optional, default: False
53+
Optional toggle for the user to choose whether a single dataset
54+
should also display a legend
4755
**kwargs : dict, optional
4856
Additional keyword arguments passed to the sns.PairGrid constructor
4957
"""
@@ -59,8 +67,11 @@ def pairs_samples(
5967
height=height,
6068
color=color,
6169
alpha=alpha,
70+
label=label,
6271
label_fontsize=label_fontsize,
6372
tick_fontsize=tick_fontsize,
73+
show_single_legend=show_single_legend,
74+
**kwargs,
6475
)
6576

6677
return g
@@ -72,17 +83,27 @@ def _pairs_samples(
7283
color: str | tuple = "#132a70",
7384
color2: str | tuple = "gray",
7485
alpha: float = 0.9,
86+
label: str = "Posterior",
7587
label_fontsize: int = 14,
7688
tick_fontsize: int = 12,
7789
legend_fontsize: int = 14,
90+
show_single_legend: bool = False,
7891
**kwargs,
7992
) -> sns.PairGrid:
80-
# internal version of pairs_samples creating the seaborn plot
93+
"""
94+
Internal version of pairs_samples creating the seaborn PairPlot
95+
for both a single dataset and multiple datasets.
8196
82-
# Parameters
83-
# ----------
84-
# plot_data : output of bayesflow.utils.dict_utils.dicts_to_arrays
85-
# other arguments are documented in pairs_samples
97+
Parameters
98+
----------
99+
plot_data : output of bayesflow.utils.dict_utils.dicts_to_arrays
100+
Formatted data to plot from the sample dataset
101+
color2 : str, optional, default: 'gray'
102+
Secondary color for the pair plots.
103+
This is the color used for the prior draws.
104+
105+
Other arguments are documented in pairs_samples
106+
"""
86107

87108
estimates_shape = plot_data["estimates"].shape
88109
if len(estimates_shape) != 2:
@@ -136,7 +157,7 @@ def _pairs_samples(
136157
common_norm=False,
137158
)
138159

139-
# add scatterplots to the upper diagonal
160+
# add scatter plots to the upper diagonal
140161
g.map_upper(sns.scatterplot, alpha=0.6, s=40, edgecolor="k", color=color, lw=0)
141162

142163
# add KDEs to the lower diagonal
@@ -146,11 +167,6 @@ def _pairs_samples(
146167
logging.exception("KDE failed due to the following exception:\n" + repr(e) + "\nSubstituting scatter plot.")
147168
g.map_lower(sns.scatterplot, alpha=0.6, s=40, edgecolor="k", color=color, lw=0)
148169

149-
# need to add legend here such that colors are recognized
150-
if plot_data["priors"] is not None:
151-
g.add_legend(fontsize=legend_fontsize, loc="center right")
152-
g._legend.set_title(None)
153-
154170
# Generate grids
155171
dim = g.axes.shape[0]
156172
for i in range(dim):
@@ -165,32 +181,48 @@ def _pairs_samples(
165181
g.axes[i, j].tick_params(axis="both", which="major", labelsize=tick_fontsize)
166182
g.axes[i, j].tick_params(axis="both", which="minor", labelsize=tick_fontsize)
167183

168-
# adjust font size of labels
184+
# adjust the font size of labels
169185
# the labels themselves remain the same as before, i.e., variable_names
170186
g.axes[i, 0].set_ylabel(variable_names[i], fontsize=label_fontsize)
171187
g.axes[dim - 1, i].set_xlabel(variable_names[i], fontsize=label_fontsize)
172188

189+
# need to add legend here such that colors are recognized
190+
# if plot_data["priors"] is not None:
191+
# g.add_legend(fontsize=legend_fontsize, loc="center right")
192+
# g._legend.set_title(None)
193+
194+
create_legends(
195+
g,
196+
plot_data,
197+
color=color,
198+
color2=color2,
199+
legend_fontsize=legend_fontsize,
200+
label=label,
201+
show_single_legend=show_single_legend,
202+
)
203+
173204
# Return figure
174205
g.tight_layout()
175206

176207
return g
177208

178209

179-
# create a histogram plot on a twin y axis
180-
# this ensures that the y scaling of the diagonal plots
181-
# in independent of the y scaling of the off-diagonal plots
182210
def histplot_twinx(x, **kwargs):
183-
# Create a twin axis
184-
ax2 = plt.gca().twinx()
211+
"""
212+
# create a histogram plot on a twin y-axis
213+
# this ensures that the y scaling of the diagonal plots
214+
# is independent of the y scaling of the off-diagonal plots
185215
216+
Parameters
217+
----------
218+
x : np.ndarray
219+
Data to be plotted.
220+
"""
186221
# create a histogram on the twin axis
187-
sns.histplot(x, **kwargs, ax=ax2)
222+
sns.histplot(x, legend=False, **kwargs)
188223

189224
# make the twin axis invisible
190225
plt.gca().spines["right"].set_visible(False)
191226
plt.gca().spines["top"].set_visible(False)
192-
ax2.set_ylabel("")
193-
ax2.set_yticks([])
194-
ax2.set_yticklabels([])
195227

196228
return None

bayesflow/distributions/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""
1+
r"""
22
A collection of :py:class:`~bayesflow.distributions.Distribution`\ s,
33
which represent the latent space for :py:class:`~bayesflow.networks.InferenceNetwork`\ s
44
or the summary space of :py:class:`~bayesflow.networks.SummaryNetwork`\ s.

bayesflow/experimental/free_form_flow/free_form_flow.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ class FreeFormFlow(InferenceNetwork):
3737
"activation": "mish",
3838
"kernel_initializer": "he_normal",
3939
"residual": True,
40-
"dropout": 0.05,
40+
"dropout": 0.0,
4141
"spectral_normalization": False,
4242
}
4343

@@ -46,7 +46,7 @@ class FreeFormFlow(InferenceNetwork):
4646
"activation": "mish",
4747
"kernel_initializer": "he_normal",
4848
"residual": True,
49-
"dropout": 0.05,
49+
"dropout": 0.0,
5050
"spectral_normalization": False,
5151
}
5252

@@ -219,7 +219,7 @@ def decode(z):
219219

220220
# VJP computation
221221
z, vjp_fn = vjp(encode, x, return_output=True)
222-
v1 = vjp_fn(v)[0]
222+
v1 = vjp_fn(v)
223223
# JVP computation
224224
x_pred, v2 = jvp(decode, (z,), (v,), return_output=True)
225225

bayesflow/metrics/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""
1+
r"""
22
A collection of `keras.Metric <https://keras.io/api/metrics/base_metric/#metric-class>`__\ s for evaluating the
33
performance of models.
44
"""

bayesflow/networks/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""
1+
r"""
22
A rich collection of neural network architectures for use in :py:class:`~bayesflow.approximators.Approximator`\ s.
33
44
The module features inference networks (IN), summary networks (SN), as well as general purpose networks.

bayesflow/scores/scoring_rule.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ def get_head(self, key: str, output_shape: Shape) -> keras.Sequential:
165165
return keras.Sequential([subnet, dense, reshape, link])
166166

167167
def score(self, estimates: dict[str, Tensor], targets: Tensor, weights: Tensor) -> Tensor:
168-
"""Scores a batch of probabilistic estimates of distributions based on samples
168+
r"""Scores a batch of probabilistic estimates of distributions based on samples
169169
of the corresponding distributions.
170170
171171
Parameters

0 commit comments

Comments
 (0)