Skip to content

Commit 16102ba

Browse files
author
WeatherBenchX authors
committed
Add kwarg to override REV probability thresholds
PiperOrigin-RevId: 824491762
1 parent aa39226 commit 16102ba

File tree

1 file changed

+47
-11
lines changed

1 file changed

+47
-11
lines changed

weatherbenchX/metrics/probabilistic.py

Lines changed: 47 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -788,22 +788,43 @@ class RelativeEconomicValue(base.PerVariableMetric):
788788
"""Relative economic value.
789789
790790
This metric assumes that the targets are binary and the predictions
791-
are probabilities between in [0, 1]. It computes REV across all possible
792-
decision thresholds for a given ensemble size.
791+
are probabilities in [0, 1].
792+
793+
Defaults to 50 cost-loss ratios spaced uniformly on a log scale from 0.005 up
794+
to (but excluding) 1, unless overridden by setting cost_loss_ratios directly.
795+
796+
By default, REV is computed across all possible decision thresholds for a
797+
given ensemble size. However, it is also possible to specify the desired
798+
probability thresholds. This is useful if you only want to compute REV for a
799+
subset of all possible probability thresholds.
800+
801+
TODO(tomandersson, matthjw): Add a helper method to ensure overridden
802+
probability thresholds are a subset of the between-each-ensemble-member-count
803+
thresholds, e.g. by snapping-to-nearest then deduplicating.
793804
"""
794805

795806
def __init__(
796-
self, ensemble_size: int, cost_loss_ratios: Optional[np.ndarray] = None
807+
self,
808+
ensemble_size: Optional[int] = None,
809+
cost_loss_ratios: Optional[np.ndarray] = None,
810+
probability_thresholds: Optional[np.ndarray] = None,
811+
statistic_suffix: Optional[str] = None,
797812
):
798813

799-
thresholds = (np.arange(ensemble_size) + 0.5) / ensemble_size
800-
801-
self._thresholds = xr.DataArray(
802-
thresholds, dims=['threshold'], coords={'threshold': thresholds}
803-
)
804-
# TODO(tomandersson): Make this configurable when thresholds themselves are
805-
# configurable.
806-
self._unique_name_suffix = 'all_thresholds_for_ensemble_size'
814+
if ensemble_size is None and probability_thresholds is None:
815+
raise ValueError(
816+
'Either ensemble_size or probability_thresholds must be specified.'
817+
)
818+
if probability_thresholds is not None and ensemble_size is not None:
819+
raise ValueError(
820+
'Only one of ensemble_size or probability_thresholds must be'
821+
' specified.'
822+
)
823+
if probability_thresholds is not None and statistic_suffix is None:
824+
raise ValueError(
825+
'If probability_thresholds is specified, statistic_suffix must be'
826+
' specified.'
827+
)
807828

808829
if cost_loss_ratios is None:
809830
cost_loss_ratios = np.geomspace(0.005, 1, 51)[:-1]
@@ -814,6 +835,21 @@ def __init__(
814835
coords={'cost_loss_ratio': cost_loss_ratios},
815836
)
816837

838+
self._thresholds = probability_thresholds
839+
if self._thresholds is None:
840+
self._thresholds = (np.arange(ensemble_size) + 0.5) / ensemble_size
841+
if statistic_suffix is None:
842+
statistic_suffix = 'all_thresholds_for_ensemble_size'
843+
if not np.all(self._thresholds >= 0.0) or not np.all(
844+
self._thresholds <= 1.0
845+
):
846+
raise ValueError(
847+
'Probability thresholds must be in [0, 1], got'
848+
f' {self._thresholds=}.'
849+
)
850+
851+
self._unique_name_suffix = statistic_suffix or ''
852+
817853
@property
818854
def statistics(self) -> Mapping[str, base.Statistic]:
819855
binarize = wrappers.ContinuousToBinary(

0 commit comments

Comments
 (0)