@@ -10,7 +10,7 @@ def root_mean_squared_error(
1010 targets : Mapping [str , np .ndarray ] | np .ndarray ,
1111 variable_keys : Sequence [str ] = None ,
1212 variable_names : Sequence [str ] = None ,
13- normalize : bool = True ,
13+ normalize : str | None = "range" ,
1414 aggregation : Callable = np .median ,
1515) -> dict [str , any ]:
1616 """
@@ -28,8 +28,9 @@ def root_mean_squared_error(
2828 By default, select all keys.
2929 variable_names : Sequence[str], optional (default = None)
3030 Optional variable names to show in the output.
31- normalize : bool, optional (default = True)
32- Whether to normalize the RMSE using the range of the prior samples.
31+ normalize : str or None, optional (default = "range")
32+ Whether to normalize the RMSE using statistics of the prior samples.
33+ Possible options are ("mean", "range", "median", "iqr", "std", None)
3334 aggregation : callable, optional (default = np.median)
3435 Function to aggregate the RMSE across draws. Typically `np.mean` or `np.median`.
3536
@@ -59,13 +60,40 @@ def root_mean_squared_error(
5960 )
6061
6162 rmse = np .sqrt (np .mean ((samples ["estimates" ] - samples ["targets" ][:, None , :]) ** 2 , axis = 0 ))
63+ targets = samples ["targets" ]
6264
63- if normalize :
64- rmse /= (samples ["targets" ].max (axis = 0 ) - samples ["targets" ].min (axis = 0 ))[None , :]
65- metric_name = "NRMSE"
66- else :
67- metric_name = "RMSE"
65+ match normalize :
66+ case None | False :
67+ normalizer = np .array (1.0 )
68+ metric_name = "RMSE"
6869
70+ case "mean" :
71+ normalizer = np .mean (targets , axis = 0 )
72+ metric_name = "NRMSE"
73+
74+ case "median" :
75+ normalizer = np .median (targets , axis = 0 )
76+ metric_name = "NRMSE"
77+
78+ case "range" :
79+ normalizer = targets .max (axis = 0 ) - targets .min (axis = 0 )
80+ metric_name = "NRMSE"
81+
82+ case "std" :
83+ normalizer = np .std (targets , axis = 0 , ddof = 0 )
84+ metric_name = "NRMSE"
85+
86+ case "iqr" :
87+ q75 = np .percentile (targets , 75 , axis = 0 )
88+ q25 = np .percentile (targets , 25 , axis = 0 )
89+ normalizer = q75 - q25
90+ metric_name = "NRMSE"
91+
92+ case _:
93+ raise ValueError (f"Unknown normalization mode: { normalize } " )
94+
95+ rmse /= normalizer [None , ...]
6996 rmse = aggregation (rmse , axis = 0 )
97+
7098 variable_names = samples ["estimates" ].variable_names
7199 return {"values" : rmse , "metric_name" : metric_name , "variable_names" : variable_names }
0 commit comments