Skip to content

Commit f61a023

Browse files
committed
rename 's' to 'markersize', add tests
1 parent 241105b commit f61a023

File tree

3 files changed

+120
-21
lines changed

3 files changed

+120
-21
lines changed

bayesflow/diagnostics/plots/pairs_quantity.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def pairs_quantity(
2222
height: float = 2.5,
2323
cmap: str | matplotlib.colors.Colormap = "viridis",
2424
alpha: float = 0.9,
25-
s: float = 8.0,
25+
markersize: float = 8.0,
2626
marker: str = "o",
2727
label: str = None,
2828
label_fontsize: int = 14,
@@ -81,6 +81,7 @@ def pairs_quantity(
8181
test_quantities : dict or None, optional, default: None
8282
A dict that maps plot titles to functions that compute
8383
test quantities based on estimate/target draws.
84+
Can only be supplied if `values` is a function.
8485
8586
The dict keys are automatically added to ``variable_keys``
8687
and ``variable_names``.
@@ -96,7 +97,7 @@ def pairs_quantity(
9697
The colormap for the plot.
9798
alpha : float in [0, 1], optional, default: 0.9
9899
The opacity of the plot
99-
s : float, optional, default: 8.0
100+
markersize : float, optional, default: 8.0
100101
The marker size in points**2 for the scatter plot.
101102
marker : str, optional, default: 'o'
102103
The marker for the scatter plot.
@@ -139,7 +140,13 @@ def pairs_quantity(
139140
"""
140141

141142
if isinstance(values, Callable) and estimates is None:
142-
raise ValueError("Supplied a callable as `values`, but not `estimates`.")
143+
raise ValueError("Supplied a callable as `values`, but no `estimates`.")
144+
if not isinstance(values, Callable) and test_quantities is not None:
145+
raise ValueError(
146+
"Supplied `test_quantities`, but `values` is not a function. "
147+
"As the values have to be calculated for the test quantities, "
148+
"passing a function is required."
149+
)
143150

144151
d = _prepare_values(
145152
values=values,
@@ -188,7 +195,7 @@ def pairs_quantity(
188195
values[:, i],
189196
c=row_values,
190197
cmap=cmap,
191-
s=s,
198+
s=markersize,
192199
marker=marker,
193200
vmin=vmin,
194201
vmax=vmax,
@@ -213,7 +220,7 @@ def pairs_quantity(
213220
targets[:, i],
214221
c=row_values,
215222
cmap=cmap,
216-
s=s,
223+
s=markersize,
217224
vmin=vmin,
218225
vmax=vmax,
219226
alpha=alpha,
@@ -252,4 +259,4 @@ def inches_to_figure(fig, values):
252259
g.axes[i, 0].set_ylabel(variable_names[i], fontsize=label_fontsize)
253260
g.axes[dim - 1, i].set_xlabel(variable_names[i], fontsize=label_fontsize)
254261

255-
return g.figure
262+
return g

bayesflow/diagnostics/plots/plot_quantity.py

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def plot_quantity(
2626
title_fontsize: int = 18,
2727
tick_fontsize: int = 12,
2828
color: str = "#132a70",
29-
s: float = 25.0,
29+
markersize: float = 25.0,
3030
marker: str = "o",
3131
alpha: float = 0.5,
3232
xlabel: str = "Ground truth",
@@ -74,6 +74,7 @@ def plot_quantity(
7474
test_quantities : dict or None, optional, default: None
7575
A dict that maps plot titles to functions that compute
7676
test quantities based on estimate/target draws.
77+
Can only be supplied if `values` is a function.
7778
7879
The dict keys are automatically added to ``variable_keys``
7980
and ``variable_names``.
@@ -93,7 +94,7 @@ def plot_quantity(
9394
The font size of the axis ticklabels
9495
color : str, optional, default: '#8f2727'
9596
The color for the true vs. estimated scatter points and error bars
96-
s : float, optional, default: 25.0
97+
markersize : float, optional, default: 25.0
9798
The marker size in points**2 for the scatter plot.
9899
marker : str, optional, default: 'o'
99100
The marker for the scatter plot.
@@ -117,7 +118,13 @@ def plot_quantity(
117118
"""
118119

119120
if isinstance(values, Callable) and estimates is None:
120-
raise ValueError("Supplied a callable as `values`, but not `estimates`.")
121+
raise ValueError("Supplied a callable as `values`, but no `estimates`.")
122+
if not isinstance(values, Callable) and test_quantities is not None:
123+
raise ValueError(
124+
"Supplied `test_quantities`, but `values` is not a function. "
125+
"As the values have to be calculated for the test quantities, "
126+
"passing a function is required."
127+
)
121128

122129
d = _prepare_values(
123130
values=values,
@@ -152,7 +159,7 @@ def plot_quantity(
152159
if i >= num_variables:
153160
break
154161

155-
ax.scatter(targets[:, i], values[:, i], color=color, alpha=alpha, s=s, marker=marker)
162+
ax.scatter(targets[:, i], values[:, i], color=color, alpha=alpha, s=markersize, marker=marker)
156163

157164
prettify_subplots(axes, num_subplots=num_variables, tick_fontsize=tick_fontsize)
158165

@@ -242,16 +249,10 @@ def _prepare_values(
242249
default_name=default_name,
243250
)
244251
except ValueError:
245-
if test_quantities is not None and not is_values_callable:
246-
raise ValueError(
247-
"`test_quantities` requires specifying `values` as callable and passing `estimates "
248-
"to enable the computation of the values for each test quantity."
249-
)
250252
raise ValueError(
251253
"Length of 'variable_names' and number of variables do not match. "
252254
"Did you forget to specify `variable_keys`?"
253255
)
254-
255256
variable_names = targets.variable_names
256257
variable_keys = targets.variable_keys
257258

@@ -266,11 +267,6 @@ def _prepare_values(
266267
default_name=default_name,
267268
)
268269
except ValueError:
269-
if test_quantities is not None and not is_values_callable:
270-
raise ValueError(
271-
"`test_quantities` requires specifying `values` as callable and passing `estimates "
272-
"to enable the computation of the values for each test quantity."
273-
)
274270
raise ValueError(
275271
"Length of 'variable_names' and number of variables do not match. "
276272
"Did you forget to specify `variable_keys`?"

tests/test_diagnostics/test_diagnostics_plots.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,102 @@ def test_pairs_posterior(random_estimates, random_targets, random_priors):
138138
)
139139

140140

141+
def test_pairs_quantity(random_estimates, random_targets, random_priors):
142+
# test test_quantities and label assignment
143+
key = next(iter(random_estimates.keys()))
144+
test_quantities = {
145+
"a": lambda data: np.sum(data[key], axis=-1),
146+
"b": lambda data: np.prod(data[key], axis=-1),
147+
}
148+
out = bf.diagnostics.plots.pairs_quantity(
149+
values=bf.diagnostics.posterior_contraction,
150+
estimates=random_estimates,
151+
targets=random_targets,
152+
test_quantities=test_quantities,
153+
)
154+
155+
num_vars = num_variables(random_estimates) + len(test_quantities)
156+
assert out.axes.shape == (num_vars, num_vars)
157+
assert out.axes[0, 0].get_ylabel() == "a"
158+
assert out.axes[2, 0].get_ylabel() == "beta_0"
159+
assert out.axes[4, 4].get_xlabel() == "sigma"
160+
161+
values = bf.diagnostics.posterior_contraction(estimates=random_estimates, targets=random_targets, aggregation=None)
162+
163+
bf.diagnostics.plots.pairs_quantity(
164+
values,
165+
targets=random_targets,
166+
)
167+
168+
raw_values = np.random.normal(size=values["values"].shape)
169+
out = bf.diagnostics.plots.pairs_quantity(raw_values, targets=random_targets, variable_keys=["beta", "sigma"])
170+
assert out.axes.shape == (3, 3)
171+
172+
with pytest.raises(ValueError):
173+
bf.diagnostics.plots.pairs_quantity(raw_values, targets=random_targets)
174+
175+
with pytest.raises(ValueError):
176+
bf.diagnostics.plots.pairs_quantity(
177+
values=values,
178+
estimates=random_estimates,
179+
targets=random_targets,
180+
test_quantities=test_quantities,
181+
)
182+
183+
with pytest.raises(ValueError):
184+
bf.diagnostics.plots.pairs_quantity(
185+
values=bf.diagnostics.posterior_contraction,
186+
targets=random_targets,
187+
)
188+
189+
190+
def test_plot_quantity(random_estimates, random_targets, random_priors):
191+
# test test_quantities and label assignment
192+
key = next(iter(random_estimates.keys()))
193+
test_quantities = {
194+
"a": lambda data: np.sum(data[key], axis=-1),
195+
"b": lambda data: np.prod(data[key], axis=-1),
196+
}
197+
out = bf.diagnostics.plots.plot_quantity(
198+
values=bf.diagnostics.posterior_contraction,
199+
estimates=random_estimates,
200+
targets=random_targets,
201+
test_quantities=test_quantities,
202+
)
203+
204+
num_vars = num_variables(random_estimates) + len(test_quantities)
205+
assert len(out.axes) == num_vars
206+
assert out.axes[0].title._text == "a"
207+
208+
values = bf.diagnostics.posterior_contraction(estimates=random_estimates, targets=random_targets, aggregation=None)
209+
210+
bf.diagnostics.plots.plot_quantity(
211+
values,
212+
targets=random_targets,
213+
)
214+
215+
raw_values = np.random.normal(size=values["values"].shape)
216+
out = bf.diagnostics.plots.plot_quantity(raw_values, targets=random_targets, variable_keys=["beta", "sigma"])
217+
assert len(out.axes) == 3
218+
219+
with pytest.raises(ValueError):
220+
bf.diagnostics.plots.plot_quantity(raw_values, targets=random_targets)
221+
222+
with pytest.raises(ValueError):
223+
bf.diagnostics.plots.plot_quantity(
224+
values=values,
225+
estimates=random_estimates,
226+
targets=random_targets,
227+
test_quantities=test_quantities,
228+
)
229+
230+
with pytest.raises(ValueError):
231+
bf.diagnostics.plots.plot_quantity(
232+
values=bf.diagnostics.posterior_contraction,
233+
targets=random_targets,
234+
)
235+
236+
141237
def test_mc_calibration(pred_models, true_models, model_names):
142238
out = bf.diagnostics.plots.mc_calibration(pred_models, true_models, model_names=model_names)
143239
assert len(out.axes) == pred_models.shape[-1]

0 commit comments

Comments
 (0)