Skip to content

Commit 5e55b03

Browse files
committed
move compute_test_quantities to dict_utils, add docs
1 parent b27583c commit 5e55b03

File tree

4 files changed

+69
-65
lines changed

4 files changed

+69
-65
lines changed

bayesflow/diagnostics/plots/calibration_ecdf.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
import numpy as np
44
import matplotlib.pyplot as plt
55

6-
from ...utils.plot_utils import prepare_plot_data, add_titles_and_labels, prettify_subplots, compute_test_quantities
6+
from ...utils.dict_utils import compute_test_quantities
7+
from ...utils.plot_utils import prepare_plot_data, add_titles_and_labels, prettify_subplots
78
from ...utils.ecdf import simultaneous_ecdf_bands
89
from ...utils.ecdf.ranks import fractional_ranks, distance_ranks
910

bayesflow/diagnostics/plots/plot_quantity.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,12 @@
33
import matplotlib.pyplot as plt
44
import numpy as np
55

6-
from bayesflow.utils.dict_utils import make_variable_array, dicts_to_arrays, filter_kwargs
6+
from bayesflow.utils.dict_utils import make_variable_array, dicts_to_arrays, filter_kwargs, compute_test_quantities
77
from bayesflow.utils.plot_utils import (
88
add_titles_and_labels,
99
make_figure,
1010
set_layout,
1111
prettify_subplots,
12-
compute_test_quantities,
1312
)
1413
from bayesflow.utils.validators import check_estimates_prior_shapes
1514

@@ -166,15 +165,21 @@ def plot_quantity(
166165

167166
def _prepare_values(
168167
*,
169-
values,
170-
targets,
171-
estimates,
172-
variable_keys,
173-
variable_names,
174-
test_quantities,
175-
label,
176-
default_name,
168+
values: Mapping[str, np.ndarray] | np.ndarray | Callable,
169+
targets: Mapping[str, np.ndarray] | np.ndarray,
170+
estimates: Mapping[str, np.ndarray] | np.ndarray | None,
171+
variable_keys: Sequence[str],
172+
variable_names: Sequence[str],
173+
test_quantities: dict[str, Callable],
174+
label: str | None,
175+
default_name: str,
177176
):
177+
"""
178+
Provate helper function to compute/extract the values required for plotting
179+
a quantity.
180+
181+
Refer to pairs_quantity and plot_quantity for details.
182+
"""
178183
is_values_callable = isinstance(values, Callable)
179184
# Optionally, compute and prepend test quantities from draws
180185
if test_quantities is not None:
@@ -190,7 +195,6 @@ def _prepare_values(
190195
estimates = updated_data["estimates"]
191196
targets = updated_data["targets"]
192197

193-
# input option 3
194198
if estimates is not None:
195199
if is_values_callable:
196200
values = values(estimates=estimates, targets=targets, **filter_kwargs({"aggregation": None}, values))
@@ -210,7 +214,6 @@ def _prepare_values(
210214
if test_quantities is None:
211215
variable_names = variable_names or estimates.variable_names
212216

213-
# input option 2
214217
if all([key in values for key in ["values", "metric_name", "variable_names"]]):
215218
# output of a metric function
216219
label = values["metric_name"] if label is None else label

bayesflow/utils/dict_utils.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -344,3 +344,55 @@ def squeeze_inner_estimates_dict(estimates):
344344
return estimates["value"]
345345
else:
346346
return estimates
347+
348+
349+
def compute_test_quantities(
350+
targets: Mapping[str, np.ndarray] | np.ndarray,
351+
estimates: Mapping[str, np.ndarray] | np.ndarray | None = None,
352+
variable_keys: Sequence[str] = None,
353+
variable_names: Sequence[str] = None,
354+
test_quantities: dict[str, Callable] = None,
355+
):
356+
"""Compute additional test quantities for given targets and estimates."""
357+
import keras
358+
359+
test_quantities_estimates = {} if estimates is not None else None
360+
test_quantities_targets = {}
361+
362+
for key, test_quantity_fn in test_quantities.items():
363+
# Apply test_quantity_func to ground-truths
364+
tq_targets = test_quantity_fn(data=targets)
365+
test_quantities_targets[key] = np.expand_dims(tq_targets, axis=1)
366+
367+
if estimates is not None:
368+
# Flatten estimates for batch processing in test_quantity_fn, apply function, and restore shape
369+
num_conditions, num_samples = next(iter(estimates.values())).shape[:2]
370+
flattened_estimates = keras.tree.map_structure(
371+
lambda t: np.reshape(t, (num_conditions * num_samples, *t.shape[2:]))
372+
if isinstance(t, np.ndarray)
373+
else t,
374+
estimates,
375+
)
376+
flat_tq_estimates = test_quantity_fn(data=flattened_estimates)
377+
test_quantities_estimates[key] = np.reshape(flat_tq_estimates, (num_conditions, num_samples, 1))
378+
379+
# Add custom test quantities to variable keys and names for plotting
380+
# keys and names are set to the test_quantities dict keys
381+
test_quantities_names = list(test_quantities.keys())
382+
383+
if variable_keys is None:
384+
variable_keys = list(estimates.keys() if estimates is not None else targets.keys())
385+
if isinstance(variable_names, list):
386+
variable_names = test_quantities_names + variable_names
387+
388+
variable_keys = test_quantities_names + variable_keys
389+
if estimates is not None:
390+
estimates = test_quantities_estimates | estimates
391+
targets = test_quantities_targets | targets
392+
393+
return {
394+
"variable_keys": variable_keys,
395+
"estimates": estimates,
396+
"targets": targets,
397+
"variable_names": variable_names,
398+
}

bayesflow/utils/plot_utils.py

Lines changed: 0 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from typing import Sequence, Any, Mapping
2-
from collections.abc import Callable
32

43
import numpy as np
54
import matplotlib.pyplot as plt
@@ -427,54 +426,3 @@ def create_legends(
427426
frameon=False,
428427
fontsize=legend_fontsize,
429428
)
430-
431-
432-
def compute_test_quantities(
433-
targets: Mapping[str, np.ndarray] | np.ndarray,
434-
estimates: Mapping[str, np.ndarray] | np.ndarray | None = None,
435-
variable_keys: Sequence[str] = None,
436-
variable_names: Sequence[str] = None,
437-
test_quantities: dict[str, Callable] = None,
438-
):
439-
import keras
440-
441-
test_quantities_estimates = {} if estimates is not None else None
442-
test_quantities_targets = {}
443-
444-
for key, test_quantity_fn in test_quantities.items():
445-
# Apply test_quantity_func to ground-truths
446-
tq_targets = test_quantity_fn(data=targets)
447-
test_quantities_targets[key] = np.expand_dims(tq_targets, axis=1)
448-
449-
if estimates is not None:
450-
# Flatten estimates for batch processing in test_quantity_fn, apply function, and restore shape
451-
num_conditions, num_samples = next(iter(estimates.values())).shape[:2]
452-
flattened_estimates = keras.tree.map_structure(
453-
lambda t: np.reshape(t, (num_conditions * num_samples, *t.shape[2:]))
454-
if isinstance(t, np.ndarray)
455-
else t,
456-
estimates,
457-
)
458-
flat_tq_estimates = test_quantity_fn(data=flattened_estimates)
459-
test_quantities_estimates[key] = np.reshape(flat_tq_estimates, (num_conditions, num_samples, 1))
460-
461-
# Add custom test quantities to variable keys and names for plotting
462-
# keys and names are set to the test_quantities dict keys
463-
test_quantities_names = list(test_quantities.keys())
464-
465-
if variable_keys is None:
466-
variable_keys = list(estimates.keys() if estimates is not None else targets.keys())
467-
if isinstance(variable_names, list):
468-
variable_names = test_quantities_names + variable_names
469-
470-
variable_keys = test_quantities_names + variable_keys
471-
if estimates is not None:
472-
estimates = test_quantities_estimates | estimates
473-
targets = test_quantities_targets | targets
474-
475-
return {
476-
"variable_keys": variable_keys,
477-
"estimates": estimates,
478-
"targets": targets,
479-
"variable_names": variable_names,
480-
}

0 commit comments

Comments
 (0)