Skip to content

Commit 6f21237

Browse files
committed
Allow different normalizations for NRMSE
1 parent 8ef46f6 commit 6f21237

File tree

1 file changed

+36
-8
lines changed

1 file changed

+36
-8
lines changed

bayesflow/diagnostics/metrics/root_mean_squared_error.py

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ def root_mean_squared_error(
1010
targets: Mapping[str, np.ndarray] | np.ndarray,
1111
variable_keys: Sequence[str] = None,
1212
variable_names: Sequence[str] = None,
13-
normalize: bool = True,
13+
normalize: str | None = "range",
1414
aggregation: Callable = np.median,
1515
) -> dict[str, any]:
1616
"""
@@ -28,8 +28,9 @@ def root_mean_squared_error(
2828
By default, select all keys.
2929
variable_names : Sequence[str], optional (default = None)
3030
Optional variable names to show in the output.
31-
normalize : bool, optional (default = True)
32-
Whether to normalize the RMSE using the range of the prior samples.
31+
normalize : str or None, optional (default = "range")
32+
Whether to normalize the RMSE using statistics of the prior samples.
33+
Possible options are ("mean", "range", "median", "iqr", "std", None)
3334
aggregation : callable, optional (default = np.median)
3435
Function to aggregate the RMSE across draws. Typically `np.mean` or `np.median`.
3536
@@ -59,13 +60,40 @@ def root_mean_squared_error(
5960
)
6061

6162
rmse = np.sqrt(np.mean((samples["estimates"] - samples["targets"][:, None, :]) ** 2, axis=0))
63+
targets = samples["targets"]
6264

63-
if normalize:
64-
rmse /= (samples["targets"].max(axis=0) - samples["targets"].min(axis=0))[None, :]
65-
metric_name = "NRMSE"
66-
else:
67-
metric_name = "RMSE"
65+
match normalize:
66+
case None | False:
67+
normalizer = np.array(1.0)
68+
metric_name = "RMSE"
6869

70+
case "mean":
71+
normalizer = np.mean(targets, axis=0)
72+
metric_name = "NRMSE"
73+
74+
case "median":
75+
normalizer = np.median(targets, axis=0)
76+
metric_name = "NRMSE"
77+
78+
case "range":
79+
normalizer = targets.max(axis=0) - targets.min(axis=0)
80+
metric_name = "NRMSE"
81+
82+
case "std":
83+
normalizer = np.std(targets, axis=0, ddof=0)
84+
metric_name = "NRMSE"
85+
86+
case "iqr":
87+
q75 = np.percentile(targets, 75, axis=0)
88+
q25 = np.percentile(targets, 25, axis=0)
89+
normalizer = q75 - q25
90+
metric_name = "NRMSE"
91+
92+
case _:
93+
raise ValueError(f"Unknown normalization mode: {normalize}")
94+
95+
rmse /= normalizer[None, ...]
6996
rmse = aggregation(rmse, axis=0)
97+
7098
variable_names = samples["estimates"].variable_names
7199
return {"values": rmse, "metric_name": metric_name, "variable_names": variable_names}

0 commit comments

Comments
 (0)